Updata vol infer

This commit is contained in:
ylzz1997 2023-05-28 22:29:27 +08:00
parent 358369d032
commit 649ecd4c7e
1 changed files with 6 additions and 4 deletions

View File

@ -136,6 +136,7 @@ class Svc(object):
self.target_sample = self.hps_ms.data.sampling_rate self.target_sample = self.hps_ms.data.sampling_rate
self.hop_size = self.hps_ms.data.hop_length self.hop_size = self.hps_ms.data.hop_length
self.spk2id = self.hps_ms.spk self.spk2id = self.hps_ms.spk
self.vol_embedding = self.hps_ms.model.vol_embedding
try: try:
self.speech_encoder = self.hps_ms.model.speech_encoder self.speech_encoder = self.hps_ms.model.speech_encoder
except Exception as e: except Exception as e:
@ -233,16 +234,17 @@ class Svc(object):
c = c.half() c = c.half()
with torch.no_grad(): with torch.no_grad():
start = time.time() start = time.time()
vol = None
if not self.only_diffusion: if not self.only_diffusion:
audio,f0 = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale) vol = self.volume_extractor.extract(audio[None,:])[None,:].to(self.dev) if self.vol_embedding else None
audio,f0 = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale,vol=vol)
audio = audio[0,0].data.float() audio = audio[0,0].data.float()
if self.shallow_diffusion: audio_mel = self.vocoder.extract(audio[None,:],self.target_sample) if self.shallow_diffusion else None
audio_mel = self.vocoder.extract(audio[None,:],self.target_sample)
else: else:
audio = torch.FloatTensor(wav).to(self.dev) audio = torch.FloatTensor(wav).to(self.dev)
audio_mel = None audio_mel = None
if self.only_diffusion or self.shallow_diffusion: if self.only_diffusion or self.shallow_diffusion:
vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol==None else vol[:,:,None]
f0 = f0[:,:,None] f0 = f0[:,:,None]
c = c.transpose(-1,-2) c = c.transpose(-1,-2)
audio_mel = self.diffusion_model( audio_mel = self.diffusion_model(