diff --git a/inference/infer_tool.py b/inference/infer_tool.py index 34957ef..84e512e 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -122,6 +122,7 @@ class Svc(object): diffusion_config_path="configs/diffusion.yaml", shallow_diffusion = False, only_diffusion = False, + spk_mix_enable = False ): self.net_g_path = net_g_path self.only_diffusion = only_diffusion @@ -150,14 +151,15 @@ class Svc(object): self.hop_size = self.diffusion_args.data.block_size self.spk2id = self.diffusion_args.spk self.speech_encoder = self.diffusion_args.data.encoder - self.diffusion_model.init_spkmix(len(self.spk2id)) + if spk_mix_enable: + self.diffusion_model.init_spkmix(len(self.spk2id)) else: print("No diffusion model or config found. Shallow diffusion mode will False") self.shallow_diffusion = self.only_diffusion = False # load hubert and model if not self.only_diffusion: - self.load_model() + self.load_model(spk_mix_enable) self.hubert_model = utils.get_speech_encoder(self.speech_encoder,device=self.dev) self.volume_extractor = utils.Volume_Extractor(self.hop_size) else: @@ -171,7 +173,7 @@ class Svc(object): from modules.enhancer import Enhancer self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model',device=self.dev) - def load_model(self): + def load_model(self, spk_mix_enable=False): # get model configuration self.net_g_ms = SynthesizerTrn( self.hps_ms.data.filter_length // 2 + 1, @@ -182,7 +184,8 @@ class Svc(object): _ = self.net_g_ms.half().eval().to(self.dev) else: _ = self.net_g_ms.eval().to(self.dev) - self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev) + if spk_mix_enable: + self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev) def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05): diff --git a/inference_main.py b/inference_main.py index 37640cd..3997872 100644 --- a/inference_main.py +++ b/inference_main.py @@ -80,7 +80,7 @@ def main(): only_diffusion = args.only_diffusion shallow_diffusion = args.shallow_diffusion use_spk_mix = args.use_spk_mix - svc_model = Svc(args.model_path, args.config_path, args.device, args.cluster_model_path,enhance,diffusion_model_path,diffusion_config_path,shallow_diffusion,only_diffusion) + svc_model = Svc(args.model_path, args.config_path, args.device, args.cluster_model_path,enhance,diffusion_model_path,diffusion_config_path,shallow_diffusion,only_diffusion,use_spk_mix) infer_tool.mkdir(["raw", "results"]) if len(spk_mix_map)<=1: use_spk_mix = False