diff --git a/flask_api.py b/flask_api.py index 8cc236a..b3f1e06 100644 --- a/flask_api.py +++ b/flask_api.py @@ -30,10 +30,13 @@ def voice_change_model(): # 模型推理 if raw_infer: - out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path) + # out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path) + out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0, + auto_predict_f0=False, noice_scale=0.4, f0_filter=False) tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample) else: - out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path) + out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0, + auto_predict_f0=False, noice_scale=0.4, f0_filter=False) tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample) # 返回音频 out_wav_path = io.BytesIO() @@ -50,7 +53,8 @@ if __name__ == '__main__': # 每个模型和config是唯一对应的 model_name = "logs/32k/G_174000-Copy1.pth" config_name = "configs/config.json" - svc_model = Svc(model_name, config_name) + cluster_model_path = "logs/44k/kmeans_10000.pt" + svc_model = Svc(model_name, config_name, cluster_model_path=cluster_model_path) svc = RealTimeVC() # 此处与vst插件对应,不建议更改 app.run(port=6842, host="0.0.0.0", debug=False, threaded=False) diff --git a/inference/infer_tool.py b/inference/infer_tool.py index dd1799a..0ea8397 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -108,6 +108,9 @@ def split_list_by_n(list_collection, n, pre=0): yield list_collection[i-pre if i-pre>=0 else i: i + n] +class F0FilterException(Exception): + pass + class Svc(object): def __init__(self, net_g_path, config_path, device=None, @@ -142,11 +145,15 @@ class Svc(object): - def get_unit_f0(self, in_path, tran, cluster_infer_ratio, speaker): + def get_unit_f0(self, in_path, tran, cluster_infer_ratio, speaker, f0_filter): wav, sr = librosa.load(in_path, sr=self.target_sample) f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size) + + if f0_filter and sum(f0) == 0: + raise F0FilterException("未检测到人声") + f0, uv = utils.interpolate_f0(f0) f0 = torch.FloatTensor(f0) uv = torch.FloatTensor(uv) @@ -170,13 +177,15 @@ class Svc(object): def infer(self, speaker, tran, raw_path, cluster_infer_ratio=0, auto_predict_f0=False, - noice_scale=0.4): + noice_scale=0.4, + f0_filter=False): + speaker_id = self.spk2id.__dict__.get(speaker) if not speaker_id and type(speaker) is int: if len(self.spk2id.__dict__) >= speaker: speaker_id = speaker sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0) - c, f0, uv = self.get_unit_f0(raw_path, tran, cluster_infer_ratio, speaker) + c, f0, uv = self.get_unit_f0(raw_path, tran, cluster_infer_ratio, speaker, f0_filter) if "half" in self.net_g_path and torch.cuda.is_available(): c = c.half() with torch.no_grad(): @@ -185,7 +194,7 @@ class Svc(object): use_time = time.time() - start print("vits use time:{}".format(use_time)) return audio, audio.shape[-1] - + def clear_empty(self): # 清理显存 torch.cuda.empty_cache() @@ -252,14 +261,25 @@ class RealTimeVC: """输入输出都是1维numpy 音频波形数组""" - def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path): + def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path, + cluster_infer_ratio=0, + auto_predict_f0=False, + noice_scale=0.4, + f0_filter=False): + import maad audio, sr = torchaudio.load(input_wav_path) audio = audio.cpu().numpy()[0] temp_wav = io.BytesIO() if self.last_chunk is None: input_wav_path.seek(0) - audio, sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path) + + audio, sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, + cluster_infer_ratio=cluster_infer_ratio, + auto_predict_f0=auto_predict_f0, + noice_scale=noice_scale, + f0_filter=f0_filter) + audio = audio.cpu().numpy() self.last_chunk = audio[-self.pre_len:] self.last_o = audio @@ -268,7 +288,13 @@ class RealTimeVC: audio = np.concatenate([self.last_chunk, audio]) soundfile.write(temp_wav, audio, sr, format="wav") temp_wav.seek(0) - audio, sr = svc_model.infer(speaker_id, f_pitch_change, temp_wav) + + audio, sr = svc_model.infer(speaker_id, f_pitch_change, temp_wav, + cluster_infer_ratio=cluster_infer_ratio, + auto_predict_f0=auto_predict_f0, + noice_scale=noice_scale, + f0_filter=f0_filter) + audio = audio.cpu().numpy() ret = maad.util.crossfade(self.last_o, audio, self.pre_len) self.last_chunk = audio[-self.pre_len:]