Update vol emb

ylzz1997 2023-05-28 21:47:32 +08:00
parent f5abe16a40
commit 358369d032
5 changed files with 52 additions and 25 deletions

configs_template/config_template.json

@@ -56,7 +56,8 @@
     "ssl_dim": 768,
     "n_speakers": 200,
     "speech_encoder":"vec768l12",
-    "speaker_embedding":false
+    "speaker_embedding":false,
+    "vol_embedding":false
   },
   "spk": {
     "nyaru": 0,

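The new vol_embedding flag gates everything below: volume extraction in preprocessing, .vol.npy loading in the dataset, and the linear volume projection in the model. A minimal sketch of reading the flag, assuming the codebase's usual hparams loader (illustrative, not part of this commit):

import utils

# Hedged sketch: load the config and check the new flag.
hps = utils.get_hparams_from_file("configs/config.json")
if hps.model.vol_embedding:
    print("volume embedding enabled: each utterance needs a .vol.npy file")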
data_utils.py

@@ -23,7 +23,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         3) computes spectrograms from audio files.
     """
 
-    def __init__(self, audiopaths, hparams, all_in_mem: bool = False):
+    def __init__(self, audiopaths, hparams, all_in_mem: bool = False, vol_aug: bool = False):
         self.audiopaths = load_filepaths_and_text(audiopaths)
         self.max_wav_value = hparams.data.max_wav_value
         self.sampling_rate = hparams.data.sampling_rate
@@ -34,6 +34,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         self.use_sr = hparams.train.use_sr
         self.spec_len = hparams.train.max_speclen
         self.spk_map = hparams.spk
+        self.vol_emb = hparams.model.vol_embedding
 
         random.seed(1234)
         random.shuffle(self.audiopaths)
@@ -72,17 +73,23 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         c = torch.load(filename+ ".soft.pt")
         c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
+        if self.vol_emb:
+            volume_path = filename + ".vol.npy"
+            volume = np.load(volume_path)
+            volume = torch.from_numpy(volume).float()
+        else:
+            volume = None
 
         lmin = min(c.size(-1), spec.size(-1))
         assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename)
         assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
         spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
         audio_norm = audio_norm[:, :lmin * self.hop_length]
-        return c, f0, spec, audio_norm, spk, uv
+        if volume is not None:
+            volume = volume[:lmin]
+        return c, f0, spec, audio_norm, spk, uv, volume
 
-    def random_slice(self, c, f0, spec, audio_norm, spk, uv):
+    def random_slice(self, c, f0, spec, audio_norm, spk, uv, volume):
         # if spec.shape[1] < 30:
         #     print("skip too short audio:", filename)
         #     return None
@@ -91,8 +98,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         end = start + 790
         spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end]
         audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length]
-        return c, f0, spec, audio_norm, spk, uv
+        if volume is not None:
+            volume = volume[start:end]
+        return c, f0, spec, audio_norm, spk, uv, volume
 
     def __getitem__(self, index):
         if self.all_in_mem:
@@ -124,12 +132,14 @@ class TextAudioCollate:
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
         spkids = torch.LongTensor(len(batch), 1)
         uv_padded = torch.FloatTensor(len(batch), max_c_len)
+        volume_padded = torch.FloatTensor(len(batch), max_c_len)
 
         c_padded.zero_()
         spec_padded.zero_()
         f0_padded.zero_()
         wav_padded.zero_()
         uv_padded.zero_()
+        volume_padded.zero_()
 
         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]
@@ -151,5 +161,9 @@ class TextAudioCollate:
             uv = row[5]
             uv_padded[i, :uv.size(0)] = uv
+            volume = row[6]
+            if volume is not None:
+                volume_padded[i, :volume.size(0)] = volume
+            else:
+                volume_padded = None
 
-        return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded
+        return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded, volume_padded

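Two notes on the dataset changes above: the new vol_aug argument is accepted by __init__ but not used in any hunk shown here, and volume is effectively all-or-nothing per dataset (driven by self.vol_emb), so every row of a batch agrees. A hypothetical mixed batch would break, because once volume_padded is set to None it can no longer be indexed for a later row that does carry volume. A self-contained sketch of the padding contract, with made-up shapes:

import torch

# Hypothetical batch of two items, both carrying frame-level volume (vol_embedding on).
max_c_len = 120
vols = [torch.rand(100), torch.rand(120)]   # per-frame volume, same length as c/f0

volume_padded = torch.zeros(len(vols), max_c_len)
for i, v in enumerate(vols):
    volume_padded[i, :v.size(0)] = v        # right-pad with zeros, as in the collate
# With vol_embedding off, every item yields volume=None and the collate
# returns None instead of a tensor.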
models.py

@@ -16,7 +16,6 @@ from modules.commons import init_weights, get_padding
 from vdecoder.hifigan.models import Generator
 from utils import f0_to_coarse
 
-
 class ResidualCouplingBlock(nn.Module):
     def __init__(self,
                  channels,
@@ -253,7 +252,6 @@ class SpeakerEncoder(torch.nn.Module):
         return embed
 
-
 class F0Decoder(nn.Module):
     def __init__(self,
                  out_channels,
@@ -322,6 +320,7 @@ class SynthesizerTrn(nn.Module):
                  ssl_dim,
                  n_speakers,
                  sampling_rate=44100,
+                 vol_embedding=False,
                  **kwargs):
 
         super().__init__()
@@ -342,7 +341,10 @@ class SynthesizerTrn(nn.Module):
         self.segment_size = segment_size
         self.gin_channels = gin_channels
         self.ssl_dim = ssl_dim
+        self.vol_embedding = vol_embedding
         self.emb_g = nn.Embedding(n_speakers, gin_channels)
+        if vol_embedding:
+            self.emb_vol = nn.Linear(1, hidden_channels)
 
         self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
@@ -389,11 +391,15 @@ class SynthesizerTrn(nn.Module):
             self.speaker_map = self.speaker_map.unsqueeze(0)
             self.character_mix = True
 
-    def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
+    def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None, vol=None):
         g = self.emb_g(g).transpose(1, 2)
+        # vol proj
+        vol = self.emb_vol(vol[:, :, None]).transpose(1, 2) if vol is not None and self.vol_embedding else 0
+
         # ssl prenet
         x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
-        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
+        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
 
         # f0 predict
         lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
@@ -412,20 +418,24 @@ class SynthesizerTrn(nn.Module):
         o = self.dec(z_slice, g=g, f0=pitch_slice)
 
         return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0
 
-    def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
+    def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False, vol=None):
         c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
         g = self.emb_g(g).transpose(1, 2)
         x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
-        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
+        # vol proj
+        vol = self.emb_vol(vol[:, :, None]).transpose(1, 2) if vol is not None and self.vol_embedding else 0
+        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
 
         if predict_f0:
             lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
             norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
             pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
             f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
 
         z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale)
         z = self.flow(z_p, c_mask, g=g, reverse=True)
         o = self.dec(z * c_mask, g=g, f0=f0)
         return o, f0

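The model-side change is one nn.Linear(1, hidden_channels) applied per frame and added onto the prenet output in both forward and infer. A standalone shape check of that projection (hidden_channels=192 is an assumed config value, not something this diff pins down):

import torch
import torch.nn as nn

hidden_channels = 192                              # assumed config value
emb_vol = nn.Linear(1, hidden_channels)

vol = torch.rand(2, 100)                           # [batch, frames], as loaded from .vol.npy
vol_h = emb_vol(vol[:, :, None]).transpose(1, 2)   # [B, T, 1] -> [B, T, H] -> [B, H, T]
x = torch.zeros(2, hidden_channels, 100)           # stands in for self.pre(c) * x_mask
x = x + vol_h                                      # the same elementwise add as in forward/infer
print(x.shape)                                     # torch.Size([2, 192, 100])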
preprocess_hubert_f0.py

@@ -78,12 +78,14 @@ def process_one(filename, hmodel, f0p, diff=False, mel_extractor=None):
             spec = torch.squeeze(spec, 0)
             torch.save(spec, spec_path)
 
-    if diff:
+    if diff or hps.model.vol_embedding:
         volume_path = filename + ".vol.npy"
         volume_extractor = utils.Volume_Extractor(hop_length)
         if not os.path.exists(volume_path):
             volume = volume_extractor.extract(audio_norm)
             np.save(volume_path, volume.to('cpu').numpy())
+
+    if diff:
         mel_path = filename + ".mel.npy"
         if not os.path.exists(mel_path) and mel_extractor is not None:
             mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate)

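utils.Volume_Extractor itself is not part of this diff; from how it is used above (one value per hop_length of audio), it is presumably a hop-wise RMS. A minimal sketch under that assumption; the name, padding, and centering are guesses, not the repo's implementation:

import torch

def extract_volume(audio: torch.Tensor, hop_length: int) -> torch.Tensor:
    # audio: [1, samples] normalized waveform -> volume: [frames]
    n_frames = audio.size(-1) // hop_length
    frames = audio[0, : n_frames * hop_length].reshape(n_frames, hop_length)
    return frames.pow(2).mean(dim=1).sqrt()    # RMS per hop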
train.py

@@ -155,7 +155,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
     net_g.train()
     net_d.train()
     for batch_idx, items in enumerate(train_loader):
-        c, f0, spec, y, spk, lengths, uv = items
+        c, f0, spec, y, spk, lengths, uv, volume = items
         g = spk.cuda(rank, non_blocking=True)
         spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
         c = c.cuda(rank, non_blocking=True)
@@ -173,7 +173,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
         with autocast(enabled=hps.train.fp16_run):
             y_hat, ids_slice, z_mask, \
                 (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths,
-                                                                                    spec_lengths=lengths)
+                                                                                    spec_lengths=lengths, vol=volume)
 
         y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
         y_hat_mel = mel_spectrogram_torch(
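The hunk shown cuts off before the device transfers, but volume presumably has to reach the GPU like the other batch tensors; the unshown part of this diff may already handle it. A hedged one-liner mirroring the neighbouring .cuda calls:

# Hypothetical transfer, matching the style of the surrounding tensors.
volume = volume.cuda(rank, non_blocking=True) if volume is not None else None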