Updata Cuda F0 Filter

This commit is contained in:
ylzz1997 2023-04-04 16:48:48 +08:00
parent b5ea92f2af
commit 20a7cf5068
3 changed files with 15 additions and 10 deletions

View File

@ -150,7 +150,7 @@ class Svc(object):
wav, sr = librosa.load(in_path, sr=self.target_sample)
if F0_mean_pooling == True:
f0, uv = utils.compute_f0_uv_torchcrepe(torch.FloatTensor(wav), sampling_rate=self.target_sample, hop_length=self.hop_size)
f0, uv = utils.compute_f0_uv_torchcrepe(torch.FloatTensor(wav), sampling_rate=self.target_sample, hop_length=self.hop_size,device=self.dev)
if f0_filter and sum(f0) == 0:
raise F0FilterException("未检测到人声")
f0 = torch.from_numpy(f0).float()

View File

@ -89,7 +89,7 @@ class BasePitchExtractor:
if self.keep_zeros:
return f0
vuv_vector = np.zeros_like(f0)
vuv_vector = torch.zeros_like(f0)
vuv_vector[f0 > 0.0] = 1.0
vuv_vector[f0 <= 0.0] = 0.0
@ -107,6 +107,7 @@ class BasePitchExtractor:
# 大概可以用 torch 重写?
f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
vuv_vector = vuv_vector.cpu().numpy()
vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
return f0,vuv_vector
@ -225,11 +226,11 @@ class MaskedMedianPool1d(nn.Module):
mask = mask.unfold(2, self.kernel_size, self.stride)
x = x.contiguous().view(x.size()[:3] + (-1,))
mask = mask.contiguous().view(mask.size()[:3] + (-1,))
mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device)
# Combine the mask with the input tensor
#x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf")))
x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]))
x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device))
# Sort the masked tensor along the last dimension
x_sorted, _ = torch.sort(x_masked, dim=-1)
@ -260,6 +261,7 @@ class CrepePitchExtractor(BasePitchExtractor):
f0_max: float = 1100.0,
threshold: float = 0.05,
keep_zeros: bool = False,
device = None,
model: Literal["full", "tiny"] = "full",
use_fast_filters: bool = True,
):
@ -269,10 +271,13 @@ class CrepePitchExtractor(BasePitchExtractor):
self.model = model
self.use_fast_filters = use_fast_filters
self.hop_length = hop_length
if device is None:
self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.dev = torch.device(device)
if self.use_fast_filters:
self.median_filter = MaskedMedianPool1d(3, 1, 1)
self.mean_filter = MaskedAvgPool1d(3, 1, 1)
self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device)
self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device)
def __call__(self, x, sampling_rate=44100, pad_to=None):
"""Extract pitch using crepe.
@ -290,7 +295,7 @@ class CrepePitchExtractor(BasePitchExtractor):
assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor."
assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels."
x = x.to(self.dev)
f0, pd = torchcrepe.predict(
x,
sampling_rate,

View File

@ -81,7 +81,7 @@ def normalize_f0(f0, x_mask, uv, random_scale=True):
exit(0)
return f0_norm * x_mask
def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None):
x = wav_numpy
if p_len is None:
p_len = x.shape[0]//hop_length
@ -90,7 +90,7 @@ def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_len
f0_min = 50
f0_max = 1100
F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max)
F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device)
f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len)
return f0,uv