commit
f152769ceb
|
@ -192,6 +192,7 @@ class Svc(object):
|
|||
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
|
||||
**self.hps_ms.model)
|
||||
_ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
|
||||
self.dtype = list(self.net_g_ms.parameters())[0].dtype
|
||||
if "half" in self.net_g_path and torch.cuda.is_available():
|
||||
_ = self.net_g_ms.half().eval().to(self.dev)
|
||||
else:
|
||||
|
@ -276,8 +277,9 @@ class Svc(object):
|
|||
sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
|
||||
c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter,f0_predictor,cr_threshold=cr_threshold)
|
||||
n_frames = f0.size(1)
|
||||
if "half" in self.net_g_path and torch.cuda.is_available():
|
||||
c = c.half()
|
||||
c = c.to(self.dtype)
|
||||
f0 = f0.to(self.dtype)
|
||||
uv = uv.to(self.dtype)
|
||||
with torch.no_grad():
|
||||
start = time.time()
|
||||
vol = None
|
||||
|
@ -289,6 +291,10 @@ class Svc(object):
|
|||
else:
|
||||
audio = torch.FloatTensor(wav).to(self.dev)
|
||||
audio_mel = None
|
||||
if self.dtype != torch.float32:
|
||||
c = c.to(torch.float32)
|
||||
f0 = f0.to(torch.float32)
|
||||
uv = uv.to(torch.float32)
|
||||
if self.only_diffusion or self.shallow_diffusion:
|
||||
vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None]
|
||||
if self.shallow_diffusion and second_encoding:
|
||||
|
|
2
utils.py
2
utils.py
|
@ -153,6 +153,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|||
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
|
||||
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
saved_state_dict = checkpoint_dict['model']
|
||||
model = model.to(list(saved_state_dict.values())[0].dtype)
|
||||
if hasattr(model, 'module'):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
|
@ -165,6 +166,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|||
new_state_dict[k] = saved_state_dict[k]
|
||||
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
|
||||
except Exception:
|
||||
if "enc_q" not in k or "emb_g" not in k:
|
||||
print("error, %s is not in the checkpoint" % k)
|
||||
logger.info("%s is not in the checkpoint" % k)
|
||||
new_state_dict[k] = v
|
||||
|
|
|
@ -266,7 +266,7 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||
"""
|
||||
# source for harmonic branch
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
|
|
|
@ -279,7 +279,7 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||
"""
|
||||
# source for harmonic branch
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
|
|
Loading…
Reference in New Issue