This commit is contained in:
白叶 藤原 2023-05-29 00:41:22 +08:00
parent 6491e7b1ef
commit 94ab3e4984
2 changed files with 8 additions and 5 deletions

View File

@ -122,6 +122,7 @@ class Svc(object):
diffusion_config_path="configs/diffusion.yaml",
shallow_diffusion = False,
only_diffusion = False,
spk_mix_enable = False
):
self.net_g_path = net_g_path
self.only_diffusion = only_diffusion
@ -150,14 +151,15 @@ class Svc(object):
self.hop_size = self.diffusion_args.data.block_size
self.spk2id = self.diffusion_args.spk
self.speech_encoder = self.diffusion_args.data.encoder
self.diffusion_model.init_spkmix(len(self.spk2id))
if spk_mix_enable:
self.diffusion_model.init_spkmix(len(self.spk2id))
else:
print("No diffusion model or config found. Shallow diffusion mode will False")
self.shallow_diffusion = self.only_diffusion = False
# load hubert and model
if not self.only_diffusion:
self.load_model()
self.load_model(spk_mix_enable)
self.hubert_model = utils.get_speech_encoder(self.speech_encoder,device=self.dev)
self.volume_extractor = utils.Volume_Extractor(self.hop_size)
else:
@ -171,7 +173,7 @@ class Svc(object):
from modules.enhancer import Enhancer
self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model',device=self.dev)
def load_model(self):
def load_model(self, spk_mix_enable=False):
# get model configuration
self.net_g_ms = SynthesizerTrn(
self.hps_ms.data.filter_length // 2 + 1,
@ -182,7 +184,8 @@ class Svc(object):
_ = self.net_g_ms.half().eval().to(self.dev)
else:
_ = self.net_g_ms.eval().to(self.dev)
self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev)
if spk_mix_enable:
self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev)
def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05):

View File

@ -80,7 +80,7 @@ def main():
only_diffusion = args.only_diffusion
shallow_diffusion = args.shallow_diffusion
use_spk_mix = args.use_spk_mix
svc_model = Svc(args.model_path, args.config_path, args.device, args.cluster_model_path,enhance,diffusion_model_path,diffusion_config_path,shallow_diffusion,only_diffusion)
svc_model = Svc(args.model_path, args.config_path, args.device, args.cluster_model_path,enhance,diffusion_model_path,diffusion_config_path,shallow_diffusion,only_diffusion,use_spk_mix)
infer_tool.mkdir(["raw", "results"])
if len(spk_mix_map)<=1:
use_spk_mix = False