From 7108c1e86c8622c91855d5a4c062638de924e51d Mon Sep 17 00:00:00 2001
From: YuriHead
Date: Sun, 9 Jul 2023 03:38:41 +0800
Subject: [PATCH] Float16 model infer

---
 inference/infer_tool.py             | 10 ++++++++--
 utils.py                            |  8 +++++---
 vdecoder/hifigan/models.py          |  2 +-
 vdecoder/hifiganwithsnake/models.py |  2 +-
 4 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/inference/infer_tool.py b/inference/infer_tool.py
index 6ab669d..ad9516a 100644
--- a/inference/infer_tool.py
+++ b/inference/infer_tool.py
@@ -192,6 +192,7 @@ class Svc(object):
             self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
             **self.hps_ms.model)
         _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
+        self.dtype = list(self.net_g_ms.parameters())[0].dtype
         if "half" in self.net_g_path and torch.cuda.is_available():
             _ = self.net_g_ms.half().eval().to(self.dev)
         else:
@@ -276,8 +277,9 @@ class Svc(object):
         sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
         c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter,f0_predictor,cr_threshold=cr_threshold)
         n_frames = f0.size(1)
-        if "half" in self.net_g_path and torch.cuda.is_available():
-            c = c.half()
+        c = c.to(self.dtype)
+        f0 = f0.to(self.dtype)
+        uv = uv.to(self.dtype)
         with torch.no_grad():
             start = time.time()
             vol = None
@@ -289,6 +291,10 @@ class Svc(object):
             else:
                 audio = torch.FloatTensor(wav).to(self.dev)
                 audio_mel = None
+            if self.dtype != torch.float32:
+                c = c.to(torch.float32)
+                f0 = f0.to(torch.float32)
+                uv = uv.to(torch.float32)
             if self.only_diffusion or self.shallow_diffusion:
                 vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None]
                 if self.shallow_diffusion and second_encoding:
diff --git a/utils.py b/utils.py
index b9fe21a..88aa2cd 100644
--- a/utils.py
+++ b/utils.py
@@ -150,6 +150,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
     if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
         optimizer.load_state_dict(checkpoint_dict['optimizer'])
     saved_state_dict = checkpoint_dict['model']
+    model = model.to(list(saved_state_dict.values())[0].dtype)
     if hasattr(model, 'module'):
         state_dict = model.module.state_dict()
     else:
@@ -162,9 +163,10 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
             new_state_dict[k] = saved_state_dict[k]
             assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
         except:  # noqa: E722 I have no idea about this CC: @ylzz1997
-            print("error, %s is not in the checkpoint" % k)
-            logger.info("%s is not in the checkpoint" % k)
-            new_state_dict[k] = v
+            if "enc_q" not in k or "emb_g" not in k:
+                print("error, %s is not in the checkpoint" % k)
+                logger.info("%s is not in the checkpoint" % k)
+                new_state_dict[k] = v
     if hasattr(model, 'module'):
         model.module.load_state_dict(new_state_dict)
     else:
diff --git a/vdecoder/hifigan/models.py b/vdecoder/hifigan/models.py
index 10eca45..8e79752 100644
--- a/vdecoder/hifigan/models.py
+++ b/vdecoder/hifigan/models.py
@@ -266,7 +266,7 @@ class SourceModuleHnNSF(torch.nn.Module):
         """
         # source for harmonic branch
         sine_wavs, uv, _ = self.l_sin_gen(x)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
 
         # source for noise branch, in the same shape as uv
         noise = torch.randn_like(uv) * self.sine_amp / 3
diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py
index ab9bcd1..9b64f9c 100644
--- a/vdecoder/hifiganwithsnake/models.py
+++ b/vdecoder/hifiganwithsnake/models.py
@@ -279,7 +279,7 @@ class SourceModuleHnNSF(torch.nn.Module):
         """
         # source for harmonic branch
         sine_wavs, uv, _ = self.l_sin_gen(x)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
 
         # source for noise branch, in the same shape as uv
         noise = torch.randn_like(uv) * self.sine_amp / 3
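
Note (not part of the patch): a minimal sketch of the dtype-tracking pattern the diff introduces, under the assumption that the generator was loaded from a half-precision checkpoint. The patch records the loaded model's parameter dtype once (self.dtype), casts inference inputs (c, f0, uv) to it before the forward pass, and casts back to float32 where downstream code still expects fp32 tensors. torch.nn.Linear below is only a stand-in for the real net_g_ms.

    import torch

    # Stand-in for a generator loaded from an fp16 checkpoint.
    net = torch.nn.Linear(4, 4).half()

    # Record the parameter dtype once, as the patch does with self.dtype.
    dtype = next(net.parameters()).dtype   # torch.float16 here

    # Cast fp32 inputs to the model dtype before inference...
    c = torch.randn(1, 4).to(dtype)
    assert c.dtype == torch.float16

    # ...and back to float32 for code paths that expect fp32 tensors.
    c = c.to(torch.float32)
    assert c.dtype == torch.float32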