diff --git a/inference/infer_tool.py b/inference/infer_tool.py index 318f81a..75b6aca 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -150,7 +150,7 @@ class Svc(object): wav, sr = librosa.load(in_path, sr=self.target_sample) if F0_mean_pooling == True: - f0, uv = utils.compute_f0_uv_torchcrepe(torch.FloatTensor(wav), sampling_rate=self.target_sample, hop_length=self.hop_size) + f0, uv = utils.compute_f0_uv_torchcrepe(torch.FloatTensor(wav), sampling_rate=self.target_sample, hop_length=self.hop_size,device=self.dev) if f0_filter and sum(f0) == 0: raise F0FilterException("未检测到人声") f0 = torch.from_numpy(f0).float() diff --git a/modules/crepe.py b/modules/crepe.py index d2cb498..5a5076b 100644 --- a/modules/crepe.py +++ b/modules/crepe.py @@ -89,7 +89,7 @@ class BasePitchExtractor: if self.keep_zeros: return f0 - vuv_vector = np.zeros_like(f0) + vuv_vector = torch.zeros_like(f0) vuv_vector[f0 > 0.0] = 1.0 vuv_vector[f0 <= 0.0] = 0.0 @@ -107,6 +107,7 @@ class BasePitchExtractor: # 大概可以用 torch 重写? f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + vuv_vector = vuv_vector.cpu().numpy() vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) return f0,vuv_vector @@ -225,11 +226,11 @@ class MaskedMedianPool1d(nn.Module): mask = mask.unfold(2, self.kernel_size, self.stride) x = x.contiguous().view(x.size()[:3] + (-1,)) - mask = mask.contiguous().view(mask.size()[:3] + (-1,)) + mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device) # Combine the mask with the input tensor #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf"))) - x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")])) + x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device)) # Sort the masked tensor along the last dimension x_sorted, _ = torch.sort(x_masked, dim=-1) @@ -260,6 +261,7 @@ class CrepePitchExtractor(BasePitchExtractor): f0_max: float = 1100.0, threshold: float = 0.05, keep_zeros: bool = False, + device = None, model: Literal["full", "tiny"] = "full", use_fast_filters: bool = True, ): @@ -269,10 +271,13 @@ class CrepePitchExtractor(BasePitchExtractor): self.model = model self.use_fast_filters = use_fast_filters self.hop_length = hop_length - + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) if self.use_fast_filters: - self.median_filter = MaskedMedianPool1d(3, 1, 1) - self.mean_filter = MaskedAvgPool1d(3, 1, 1) + self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device) + self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device) def __call__(self, x, sampling_rate=44100, pad_to=None): """Extract pitch using crepe. @@ -290,7 +295,7 @@ class CrepePitchExtractor(BasePitchExtractor): assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor." assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels." - + x = x.to(self.dev) f0, pd = torchcrepe.predict( x, sampling_rate, diff --git a/utils.py b/utils.py index 926ea8c..eceed38 100644 --- a/utils.py +++ b/utils.py @@ -81,7 +81,7 @@ def normalize_f0(f0, x_mask, uv, random_scale=True): exit(0) return f0_norm * x_mask -def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): +def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None): x = wav_numpy if p_len is None: p_len = x.shape[0]//hop_length @@ -90,7 +90,7 @@ def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_len f0_min = 50 f0_max = 1100 - F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max) + F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device) f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len) return f0,uv