Merge pull request #279 from svc-develop-team/4.1-Latest

To Latest
This commit is contained in:
YuriHead 2023-07-09 03:46:54 +08:00 committed by GitHub
commit f152769ceb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 7 deletions

View File

@ -192,6 +192,7 @@ class Svc(object):
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
**self.hps_ms.model)
_ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
self.dtype = list(self.net_g_ms.parameters())[0].dtype
if "half" in self.net_g_path and torch.cuda.is_available():
_ = self.net_g_ms.half().eval().to(self.dev)
else:
@ -276,8 +277,9 @@ class Svc(object):
sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter,f0_predictor,cr_threshold=cr_threshold)
n_frames = f0.size(1)
if "half" in self.net_g_path and torch.cuda.is_available():
c = c.half()
c = c.to(self.dtype)
f0 = f0.to(self.dtype)
uv = uv.to(self.dtype)
with torch.no_grad():
start = time.time()
vol = None
@ -289,6 +291,10 @@ class Svc(object):
else:
audio = torch.FloatTensor(wav).to(self.dev)
audio_mel = None
if self.dtype != torch.float32:
c = c.to(torch.float32)
f0 = f0.to(torch.float32)
uv = uv.to(torch.float32)
if self.only_diffusion or self.shallow_diffusion:
vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None]
if self.shallow_diffusion and second_encoding:

View File

@ -153,6 +153,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
optimizer.load_state_dict(checkpoint_dict['optimizer'])
saved_state_dict = checkpoint_dict['model']
model = model.to(list(saved_state_dict.values())[0].dtype)
if hasattr(model, 'module'):
state_dict = model.module.state_dict()
else:
@ -165,9 +166,10 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
except Exception:
print("error, %s is not in the checkpoint" % k)
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if "enc_q" not in k or "emb_g" not in k:
print("error, %s is not in the checkpoint" % k)
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if hasattr(model, 'module'):
model.module.load_state_dict(new_state_dict)
else:

View File

@ -266,7 +266,7 @@ class SourceModuleHnNSF(torch.nn.Module):
"""
# source for harmonic branch
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3

View File

@ -279,7 +279,7 @@ class SourceModuleHnNSF(torch.nn.Module):
"""
# source for harmonic branch
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3