diff --git a/configs_template/config_template.json b/configs_template/config_template.json index 670329c..377a5ec 100644 --- a/configs_template/config_template.json +++ b/configs_template/config_template.json @@ -12,6 +12,7 @@ "eps": 1e-09, "batch_size": 6, "fp16_run": false, + "half_type": "fp16", "lr_decay": 0.999875, "segment_size": 10240, "init_lr_ratio": 1, @@ -53,6 +54,7 @@ "upsample_initial_channel": 512, "upsample_kernel_sizes": [16,16, 4, 4, 4], "n_layers_q": 3, + "n_flow_layer": 4, "use_spectral_norm": false, "gin_channels": 768, "ssl_dim": 768, @@ -60,7 +62,9 @@ "vocoder_name":"nsf-hifigan", "speech_encoder":"vec768l12", "speaker_embedding":false, - "vol_embedding":false + "vol_embedding":false, + "use_depthwise_conv":false, + "use_automatic_f0_prediction": true }, "spk": { "nyaru": 0, diff --git a/models.py b/models.py index 67909a5..5adbded 100644 --- a/models.py +++ b/models.py @@ -318,6 +318,9 @@ class SynthesizerTrn(nn.Module): sampling_rate=44100, vol_embedding=False, vocoder_name = "nsf-hifigan", + use_depthwise_conv = False, + use_automatic_f0_prediction = True, + n_flow_layer = 4, **kwargs): super().__init__() @@ -340,6 +343,8 @@ class SynthesizerTrn(nn.Module): self.ssl_dim = ssl_dim self.vol_embedding = vol_embedding self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.use_depthwise_conv = use_depthwise_conv + self.use_automatic_f0_prediction = use_automatic_f0_prediction if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels) @@ -364,9 +369,11 @@ class SynthesizerTrn(nn.Module): "upsample_initial_channel": upsample_initial_channel, "upsample_kernel_sizes": upsample_kernel_sizes, "gin_channels": gin_channels, + "use_depthwise_conv":use_depthwise_conv } - + modules.set_Conv1dModel(self.use_depthwise_conv) + if vocoder_name == "nsf-hifigan": from vdecoder.hifigan.models import Generator self.dec = Generator(h=hps) @@ -379,17 +386,18 @@ class SynthesizerTrn(nn.Module): self.dec = Generator(h=hps) self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - self.f0_decoder = F0Decoder( - 1, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - spk_channels=gin_channels - ) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels) + if self.use_automatic_f0_prediction: + self.f0_decoder = F0Decoder( + 1, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + spk_channels=gin_channels + ) self.emb_uv = nn.Embedding(2, hidden_channels) self.character_mix = False @@ -409,12 +417,16 @@ class SynthesizerTrn(nn.Module): # ssl prenet x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol - + # f0 predict - lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 - norm_lf0 = utils.normalize_f0(lf0, x_mask, uv) - pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) - + if self.use_automatic_f0_prediction: + lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 + norm_lf0 = utils.normalize_f0(lf0, x_mask, uv) + pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) + else: + lf0 = 0 + norm_lf0 = 0 + pred_lf0 = 0 # encoder z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0)) z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) @@ -428,6 +440,7 @@ class SynthesizerTrn(nn.Module): return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 + @torch.no_grad() def infer(self, c, f0, uv, g=None, noice_scale=0.35, seed=52468, predict_f0=False, vol = None): if c.device == torch.device("cuda"): @@ -449,11 +462,13 @@ class SynthesizerTrn(nn.Module): x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) # vol proj + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol + - if predict_f0: + if self.use_automatic_f0_prediction and predict_f0: lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) diff --git a/modules/DSConv.py b/modules/DSConv.py new file mode 100644 index 0000000..9909521 --- /dev/null +++ b/modules/DSConv.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm, remove_weight_norm + +class Depthwise_Separable_Conv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + dilation = 1, + bias = True, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight') + self.point_conv = remove_weight_norm(self.point_conv, name = 'weight') + +class Depthwise_Separable_TransposeConv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + output_padding = 0, + bias = True, + dilation = 1, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + remove_weight_norm(self.depth_conv, name = 'weight') + remove_weight_norm(self.point_conv, name = 'weight') + + +def weight_norm_modules(module, name = 'weight', dim = 0): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.weight_norm() + return module + else: + return weight_norm(module,name,dim) + +def remove_weight_norm_modules(module, name = 'weight'): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.remove_weight_norm() + else: + remove_weight_norm(module,name) \ No newline at end of file diff --git a/modules/commons.py b/modules/commons.py index d8ba67d..761379d 100644 --- a/modules/commons.py +++ b/modules/commons.py @@ -24,10 +24,12 @@ def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ - if classname.find("Conv") != -1: + if "Depthwise_Separable" in classname: + m.depth_conv.weight.data.normal_(mean, std) + m.point_conv.weight.data.normal_(mean, std) + elif classname.find("Conv") != -1: m.weight.data.normal_(mean, std) - def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2) diff --git a/modules/mel_processing.py b/modules/mel_processing.py index 0795b05..8ac0717 100644 --- a/modules/mel_processing.py +++ b/modules/mel_processing.py @@ -51,10 +51,13 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = y.squeeze(1) - + + y_dtype = y.dtype + if y.dtype == torch.bfloat16: y = y.to(torch.float32) spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) - spec = torch.view_as_real(spec) + spec = torch.view_as_real(spec).to(y_dtype) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec diff --git a/modules/modules.py b/modules/modules.py index 6af6227..ba67df6 100644 --- a/modules/modules.py +++ b/modules/modules.py @@ -2,13 +2,20 @@ import torch from torch import nn from torch.nn import Conv1d from torch.nn import functional as F -from torch.nn.utils import remove_weight_norm, weight_norm + +from modules.DSConv import weight_norm_modules, remove_weight_norm_modules, Depthwise_Separable_Conv1D import modules.commons as commons -from modules.commons import get_padding, init_weights +from modules.commons import init_weights, get_padding LRELU_SLOPE = 0.1 +Conv1dModel = nn.Conv1d + +def set_Conv1dModel(use_depthwise_conv): + global Conv1dModel + Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d + class LayerNorm(nn.Module): def __init__(self, channels, eps=1e-5): @@ -38,13 +45,13 @@ class ConvReluNorm(nn.Module): self.conv_layers = nn.ModuleList() self.norm_layers = nn.ModuleList() - self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.relu_drop = nn.Sequential( nn.ReLU(), nn.Dropout(p_dropout)) for _ in range(n_layers-1): - self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.conv_layers.append(Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj.weight.data.zero_() @@ -60,47 +67,6 @@ class ConvReluNorm(nn.Module): return x * x_mask -class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, - groups=channels, dilation=dilation, padding=padding - )) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask - - class WN(torch.nn.Module): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): super(WN, self).__init__() @@ -118,14 +84,14 @@ class WN(torch.nn.Module): if gin_channels != 0: cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + self.cond_layer = weight_norm_modules(cond_layer, name='weight') for i in range(n_layers): dilation = dilation_rate ** i padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + in_layer = Conv1dModel(hidden_channels, 2*hidden_channels, kernel_size, dilation=dilation, padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + in_layer = weight_norm_modules(in_layer, name='weight') self.in_layers.append(in_layer) # last one is not necessary @@ -135,7 +101,7 @@ class WN(torch.nn.Module): res_skip_channels = hidden_channels res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + res_skip_layer = weight_norm_modules(res_skip_layer, name='weight') self.res_skip_layers.append(res_skip_layer) def forward(self, x, x_mask, g=None, **kwargs): @@ -170,32 +136,32 @@ class WN(torch.nn.Module): def remove_weight_norm(self): if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) + remove_weight_norm_modules(self.cond_layer) for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) + remove_weight_norm_modules(l) for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) + remove_weight_norm_modules(l) class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() self.convs1 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))) ]) self.convs1.apply(init_weights) self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) ]) self.convs2.apply(init_weights) @@ -217,18 +183,18 @@ class ResBlock1(torch.nn.Module): def remove_weight_norm(self): for l in self.convs1: - remove_weight_norm(l) + remove_weight_norm_modules(l) for l in self.convs2: - remove_weight_norm(l) + remove_weight_norm_modules(l) class ResBlock2(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super(ResBlock2, self).__init__() self.convs = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))) ]) self.convs.apply(init_weights) @@ -246,7 +212,7 @@ class ResBlock2(torch.nn.Module): def remove_weight_norm(self): for l in self.convs: - remove_weight_norm(l) + remove_weight_norm_modules(l) class Log(nn.Module): diff --git a/train.py b/train.py index d2356b5..dbdec11 100644 --- a/train.py +++ b/train.py @@ -52,7 +52,7 @@ def run(rank, n_gpus, hps): utils.check_git_hash(hps.model_dir) writer = SummaryWriter(log_dir=hps.model_dir) writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) - + # for pytorch on win, backend use gloo dist.init_process_group(backend= 'gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus, rank=rank) torch.manual_seed(hps.train.seed) @@ -139,6 +139,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade train_loader, eval_loader = loaders if writers is not None: writer, writer_eval = writers + + half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16 # train_loader.batch_sampler.set_epoch(epoch) global global_step @@ -160,8 +162,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade hps.data.sampling_rate, hps.data.mel_fmin, hps.data.mel_fmax) - - with autocast(enabled=hps.train.fp16_run): + + with autocast(enabled=hps.train.fp16_run, dtype=half_type): y_hat, ids_slice, z_mask, \ (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths, spec_lengths=lengths,vol = volume) @@ -182,25 +184,26 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade # Discriminator y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - with autocast(enabled=False): + with autocast(enabled=False, dtype=half_type): loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) loss_disc_all = loss_disc - + optim_d.zero_grad() scaler.scale(loss_disc_all).backward() scaler.unscale_(optim_d) grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) scaler.step(optim_d) + - with autocast(enabled=hps.train.fp16_run): + with autocast(enabled=hps.train.fp16_run, dtype=half_type): # Generator y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) - with autocast(enabled=False): + with autocast(enabled=False, dtype=half_type): loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl loss_fm = feature_loss(fmap_r, fmap_g) loss_gen, losses_gen = generator_loss(y_d_hat_g) - loss_lf0 = F.mse_loss(pred_lf0, lf0) + loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0 loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0 optim_g.zero_grad() scaler.scale(loss_gen_all).backward() @@ -232,13 +235,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade image_dict = { "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), - "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), - "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), - pred_lf0[0, 0, :].detach().cpu().numpy()), - "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), - norm_lf0[0, 0, :].detach().cpu().numpy()) + "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()) } + if net_g.module.use_automatic_f0_prediction: + image_dict.update({ + "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), + pred_lf0[0, 0, :].detach().cpu().numpy()), + "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), + norm_lf0[0, 0, :].detach().cpu().numpy()) + }) + utils.summarize( writer=writer, global_step=global_step, @@ -319,4 +326,4 @@ def evaluate(hps, generator, eval_loader, writer_eval): if __name__ == "__main__": - main() + main() \ No newline at end of file