diff --git a/inference/infer_tool.py b/inference/infer_tool.py index bdfb8fb..710b06a 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -203,9 +203,10 @@ class Svc(object): def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05): - f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold) - - f0, uv = f0_predictor_object.compute_f0_uv(wav) + if not hasattr(self,"f0_predictor_object") or self.f0_predictor_object is None or f0_predictor != self.f0_predictor_object.name: + self.f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold) + f0, uv = self.f0_predictor_object.compute_f0_uv(wav) + if f0_filter and sum(f0) == 0: raise F0FilterException("No voice detected") f0 = torch.FloatTensor(f0).to(self.dev) diff --git a/modules/F0Predictor/CrepeF0Predictor.py b/modules/F0Predictor/CrepeF0Predictor.py index 086ca10..c0854b6 100644 --- a/modules/F0Predictor/CrepeF0Predictor.py +++ b/modules/F0Predictor/CrepeF0Predictor.py @@ -13,6 +13,7 @@ class CrepeF0Predictor(F0Predictor): self.device = device self.threshold = threshold self.sampling_rate = sampling_rate + self.name = "crepe" def compute_f0(self,wav,p_len=None): x = torch.FloatTensor(wav).to(self.device) diff --git a/modules/F0Predictor/DioF0Predictor.py b/modules/F0Predictor/DioF0Predictor.py index ef470a4..178dd2e 100644 --- a/modules/F0Predictor/DioF0Predictor.py +++ b/modules/F0Predictor/DioF0Predictor.py @@ -10,6 +10,7 @@ class DioF0Predictor(F0Predictor): self.f0_min = f0_min self.f0_max = f0_max self.sampling_rate = sampling_rate + self.name = "dio" def interpolate_f0(self,f0): ''' diff --git a/modules/F0Predictor/FCPEF0Predictor.py b/modules/F0Predictor/FCPEF0Predictor.py index 57d5a6c..91913c7 100644 --- a/modules/F0Predictor/FCPEF0Predictor.py +++ b/modules/F0Predictor/FCPEF0Predictor.py @@ -23,6 +23,7 @@ class FCPEF0Predictor(F0Predictor): self.threshold = threshold self.sampling_rate = sampling_rate self.dtype = dtype + self.name = "fcpe" def repeat_expand( self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" @@ -89,7 +90,7 @@ class FCPEF0Predictor(F0Predictor): p_len = x.shape[0] // self.hop_length else: assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" - f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold) + f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] if torch.all(f0 == 0): rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) return rtn, rtn @@ -101,7 +102,7 @@ class FCPEF0Predictor(F0Predictor): p_len = x.shape[0] // self.hop_length else: assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" - f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold) + f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] if torch.all(f0 == 0): rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) return rtn, rtn diff --git a/modules/F0Predictor/HarvestF0Predictor.py b/modules/F0Predictor/HarvestF0Predictor.py index fe279f6..f36b332 100644 --- a/modules/F0Predictor/HarvestF0Predictor.py +++ b/modules/F0Predictor/HarvestF0Predictor.py @@ -10,6 +10,7 @@ class HarvestF0Predictor(F0Predictor): self.f0_min = f0_min self.f0_max = f0_max self.sampling_rate = sampling_rate + self.name = "harvest" def interpolate_f0(self,f0): ''' diff --git a/modules/F0Predictor/PMF0Predictor.py b/modules/F0Predictor/PMF0Predictor.py index cb7355f..2af3f6e 100644 --- a/modules/F0Predictor/PMF0Predictor.py +++ b/modules/F0Predictor/PMF0Predictor.py @@ -10,7 +10,7 @@ class PMF0Predictor(F0Predictor): self.f0_min = f0_min self.f0_max = f0_max self.sampling_rate = sampling_rate - + self.name = "pm" def interpolate_f0(self,f0): ''' diff --git a/modules/F0Predictor/RMVPEF0Predictor.py b/modules/F0Predictor/RMVPEF0Predictor.py index 63412ae..9313887 100644 --- a/modules/F0Predictor/RMVPEF0Predictor.py +++ b/modules/F0Predictor/RMVPEF0Predictor.py @@ -22,6 +22,7 @@ class RMVPEF0Predictor(F0Predictor): self.threshold = threshold self.sampling_rate = sampling_rate self.dtype = dtype + self.name = "rmvpe" def repeat_expand( self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" diff --git a/modules/F0Predictor/fcpe/model.py b/modules/F0Predictor/fcpe/model.py index 9e2e0c0..38a8304 100644 --- a/modules/F0Predictor/fcpe/model.py +++ b/modules/F0Predictor/fcpe/model.py @@ -1,10 +1,7 @@ -import os - import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import yaml from torch.nn.utils import weight_norm from torchaudio.transforms import Resample @@ -146,10 +143,11 @@ class FCPE(nn.Module): class FCPEInfer: def __init__(self, model_path, device=None, dtype=torch.float32): - config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') - with open(config_file, "r") as config: - args = yaml.safe_load(config) - self.args = DotDict(args) + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + ckpt = torch.load(model_path, map_location=torch.device(self.device)) + self.args = DotDict(ckpt["config"]) self.dtype = dtype model = FCPE( input_channel=self.args.model.input_channel, @@ -167,25 +165,19 @@ class FCPEInfer: f0_min=self.args.model.f0_min, confidence=self.args.model.confidence, ) - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device ckpt = torch.load(model_path, map_location=torch.device(self.device)) model.to(self.device).to(self.dtype) model.load_state_dict(ckpt['model']) model.eval() self.model = model self.wav2mel = Wav2Mel(self.args) - self.args = args @torch.no_grad() def __call__(self, audio, sr, threshold=0.05): self.model.threshold = threshold - audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) + audio = audio[None,:] mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype) - mel_f0 = self.model(mel=mel, infer=True, return_hz_f0=True) - # f0 = (mel_f0.exp() - 1) * 700 - f0 = mel_f0 + f0 = self.model(mel=mel, infer=True, return_hz_f0=True) return f0 diff --git a/utils.py b/utils.py index eba11db..95b6d88 100644 --- a/utils.py +++ b/utils.py @@ -102,8 +102,8 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs): from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) elif f0_predictor == "fcpe": - from modules.F0Predictor.FCPEF0Predictor import FCEF0Predictor - f0_predictor_object = FCEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) + from modules.F0Predictor.FCPEF0Predictor import FCPEF0Predictor + f0_predictor_object = FCPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) else: raise Exception("Unknown f0 predictor") return f0_predictor_object