From 3160e7d84647d220e6471b3e0c62bda877351d09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Wed, 15 Mar 2023 15:47:59 +0800 Subject: [PATCH] Update model_onnx.py --- onnxexport/model_onnx.py | 1045 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 1045 insertions(+) diff --git a/onnxexport/model_onnx.py b/onnxexport/model_onnx.py index 8b13789..5b7eda8 100644 --- a/onnxexport/model_onnx.py +++ b/onnxexport/model_onnx.py @@ -1 +1,1046 @@ +import sys +import copy +import math +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +import numpy as np + +sys.path.append('../..') +import modules.commons as commons +import modules.modules as modules +import modules.attentions as attentions + +from modules.commons import init_weights, get_padding + +from modules.ddsp import mlp, gru, scale_function, remove_above_nyquist, upsample +from modules.ddsp import harmonic_synth, amp_to_impulse_response, fft_convolve +from modules.ddsp import resample +import utils + +from modules.stft import TorchSTFT + +import torch.distributions as D + +from modules.losses import ( + generator_loss, + discriminator_loss, + feature_loss, + kl_loss +) + +LRELU_SLOPE = 0.1 + + +class PostF0Decoder(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, spk_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = spk_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if spk_channels != 0: + self.cond = nn.Conv1d(spk_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + c_dim, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.pre_net = torch.nn.Linear(c_dim, hidden_channels) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_lengths): + x = x.transpose(1,-1) + x = self.pre_net(x) + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = 
torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.encoder(x * x_mask, x_mask) + x = self.proj(x) * x_mask + return x, x_mask + + +def pad_v2(input_ele, mel_max_length=None): + if mel_max_length: + max_len = mel_max_length + else: + max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) + + out_list = list() + for i, batch in enumerate(input_ele): + if len(batch.shape) == 1: + one_batch_padded = F.pad( + batch, (0, max_len - batch.size(0)), "constant", 0.0 + ) + elif len(batch.shape) == 2: + one_batch_padded = F.pad( + batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 + ) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +class LengthRegulator(nn.Module): + """ Length Regulator """ + + def __init__(self): + super(LengthRegulator, self).__init__() + + def LR(self, x, duration, max_len): + x = torch.transpose(x, 1, 2) + output = list() + mel_len = list() + for batch, expand_target in zip(x, duration): + expanded = self.expand(batch, expand_target) + output.append(expanded) + mel_len.append(expanded.shape[0]) + + if max_len is not None: + output = pad_v2(output, max_len) + else: + output = pad_v2(output) + output = torch.transpose(output, 1, 2) + return output, torch.LongTensor(mel_len) + + def expand(self, batch, predicted): + predicted = torch.squeeze(predicted) + out = list() + + for i, vec in enumerate(batch): + expand_size = predicted[i].item() + state_info_index = torch.unsqueeze(torch.arange(0, expand_size), 1).float() + state_info_length = torch.unsqueeze(torch.Tensor([expand_size] * expand_size), 1).float() + state_info = torch.cat([state_info_index, state_info_length], 1).to(vec.device) + new_vec = vec.expand(max(int(expand_size), 0), -1) + new_vec = torch.cat([new_vec, state_info], 1) + out.append(new_vec) + out = torch.cat(out, 0) + return out + + def forward(self, x, duration, max_len): + output, mel_len = self.LR(x, duration, max_len) + return output, mel_len + + +class PriorDecoder(nn.Module): + def __init__(self, + out_bn_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_speakers=0, + spk_channels=0): + super().__init__() + self.out_bn_channels = out_bn_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.spk_channels = spk_channels + + self.prenet = nn.Conv1d(hidden_channels , hidden_channels, 3, padding=1) + self.decoder = attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_bn_channels, 1) + + if n_speakers != 0: + self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) + + def forward(self, x, x_lengths, spk_emb=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x) * x_mask + + if (spk_emb is not None): + x = x + self.cond(spk_emb) + + x = self.decoder(x * x_mask, x_mask) + + bn = self.proj(x) * x_mask + + return bn, x_mask + + +class Decoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_speakers=0, + spk_channels=0, + in_channels=None): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + 
self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.spk_channels = spk_channels
+
+        self.prenet = nn.Conv1d(in_channels if in_channels is not None else hidden_channels, hidden_channels, 3, padding=1)
+        self.decoder = attentions.FFT(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout)
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+        if n_speakers != 0:
+            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
+
+    def forward(self, x, x_lengths, spk_emb=None):
+        x = torch.detach(x)
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x) * x_mask
+
+        if spk_emb is not None:
+            x = x + self.cond(spk_emb)
+
+        x = self.decoder(x * x_mask, x_mask)
+
+        x = self.proj(x) * x_mask
+
+        return x, x_mask
+
+
+class F0Decoder(nn.Module):
+    def __init__(self,
+                 out_channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout,
+                 n_speakers=0,
+                 spk_channels=0,
+                 in_channels=None):
+        super().__init__()
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.spk_channels = spk_channels
+
+        self.prenet = nn.Conv1d(in_channels if in_channels is not None else hidden_channels, hidden_channels, 3, padding=1)
+        self.decoder = attentions.FFT(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout)
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
+
+        if n_speakers != 0:
+            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
+
+    def forward(self, x, norm_f0, x_lengths, spk_emb=None):
+        x = torch.detach(x)
+        x = x + self.f0_prenet(norm_f0)
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x) * x_mask
+
+        if spk_emb is not None:
+            x = x + self.cond(spk_emb)
+
+        x = self.decoder(x * x_mask, x_mask)
+
+        x = self.proj(x) * x_mask
+
+        return x, x_mask
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
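+        # Layout of the stack built below: one input conv, then (n_layers - 1)
+        # residual convs whose outputs are averaged with their inputs, then a
+        # zero-initialized 1x1 projection (so the block's output starts at zero).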
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x): + x = self.conv_layers[0](x) + x = self.norm_layers[0](x) + x = self.relu_drop(x) + + for i in range(1, self.n_layers): + x_ = self.conv_layers[i](x) + x_ = self.norm_layers[i](x_) + x_ = self.relu_drop(x_) + x = (x + x_) / 2 + x = self.proj(x) + return x + + +class PosteriorEncoder(nn.Module): + def __init__(self, + hps, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, n_speakers=hps.data.n_speakers, spk_channels=hps.model.spk_channels) + # self.enc = ConvReluNorm(hidden_channels, + # hidden_channels, + # hidden_channels, + # kernel_size, + # n_layers, + # 0.1) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + return stats, x_mask + + +class ResBlock3(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock3, self).__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator_Harm(torch.nn.Module): + def __init__(self, hps): + super(Generator_Harm, self).__init__() + self.hps = hps + + self.prenet = Conv1d(hps.model.hidden_channels, hps.model.hidden_channels, 3, padding=1) + + self.net = ConvReluNorm(hps.model.hidden_channels, + hps.model.hidden_channels, + hps.model.hidden_channels, + hps.model.kernel_size, + 8, + hps.model.p_dropout) + + # self.rnn = nn.LSTM(input_size=hps.model.hidden_channels, + # hidden_size=hps.model.hidden_channels, + # num_layers=1, + # bias=True, + # batch_first=True, + # dropout=0.5, + # bidirectional=True) + self.postnet = Conv1d(hps.model.hidden_channels, hps.model.n_harmonic + 1, 3, padding=1) + + def forward(self, f0, harm, mask): + pitch = f0.transpose(1, 2) + harm = self.prenet(harm) + + harm = self.net(harm) * mask + # harm = harm.transpose(1, 2) + # harm, (hs, hc) = self.rnn(harm) + # harm = harm.transpose(1, 2) + + harm = self.postnet(harm) + harm = harm.transpose(1, 2) + param = 
harm
+
+        param = scale_function(param)
+        total_amp = param[..., :1]
+        amplitudes = param[..., 1:]
+        amplitudes = remove_above_nyquist(
+            amplitudes,
+            pitch,
+            self.hps.data.sampling_rate,
+        )
+        amplitudes /= amplitudes.sum(-1, keepdim=True)
+        amplitudes *= total_amp
+
+        amplitudes = upsample(amplitudes, self.hps.data.hop_length)
+        pitch = upsample(pitch, self.hps.data.hop_length)
+
+        n_harmonic = amplitudes.shape[-1]
+        omega = torch.cumsum(2 * math.pi * pitch / self.hps.data.sampling_rate, 1)
+        omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
+        signal_harmonics = torch.sin(omegas) * amplitudes
+        signal_harmonics = signal_harmonics.transpose(1, 2)
+        return signal_harmonics
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, hps, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
+                 upsample_initial_channel, upsample_kernel_sizes, n_speakers=0, spk_channels=0):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
+        self.upsample_rates = upsample_rates
+        self.n_speakers = n_speakers
+
+        # '1' selects the larger ResBlock1; anything else uses ResBlock2
+        resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
+
+        self.downs = nn.ModuleList()
+        for i in range(len(upsample_rates) - 1, -1, -1):
+            u = upsample_rates[i]
+            k = upsample_kernel_sizes[i]
+            # print("down: ", upsample_initial_channel // (2 ** (i + 1)), " -> ", upsample_initial_channel // (2 ** i))
+            self.downs.append(weight_norm(
+                Conv1d(hps.model.n_harmonic + 2, hps.model.n_harmonic + 2,
+                       k, u, padding=k // 2)))
+
+        self.resblocks_downs = nn.ModuleList()
+        for i in range(len(self.downs)):
+            self.resblocks_downs.append(ResBlock3(hps.model.n_harmonic + 2, 3, (1, 3)))
+
+        self.concat_pre = Conv1d(upsample_initial_channel + hps.model.n_harmonic + 2, upsample_initial_channel, 3, 1,
+                                 padding=1)
+        self.concat_conv = nn.ModuleList()
+        for i in range(len(upsample_rates)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.concat_conv.append(Conv1d(ch + hps.model.n_harmonic + 2, ch, 3, 1, padding=1, bias=False))
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(weight_norm(
+                ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
+                                k, u, padding=(k - u) // 2)))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if self.n_speakers != 0:
+            self.cond = nn.Conv1d(spk_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, ddsp, g=None):
+        x = self.conv_pre(x)
+
+        if g is not None:
+            x = x + self.cond(g)
+
+        se = ddsp
+        res_features = [se]
+        for i in range(self.num_upsamples):
+            in_size = se.size(2)
+            se = self.downs[i](se)
+            se = self.resblocks_downs[i](se)
+            up_rate = self.upsample_rates[self.num_upsamples - 1 - i]
+            se = se[:, :, : in_size // up_rate]
+            res_features.append(se)
+
+        x = torch.cat([x, se], 1)
+        x = self.concat_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            in_size = x.size(2)
+            x = self.ups[i](x)
+            # trim to the expected upsampled length, dropping any extra samples
+            x = x[:, :, : in_size * self.upsample_rates[i]]
+
+            x = torch.cat([x, res_features[self.num_upsamples - 1 - i]], 1)
+            x = self.concat_conv[i](x)
+
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+from scipy.signal import get_window
+
+
+def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
+    if win_type == 'None' or win_type is None:
+        window = np.ones(win_len)
+    else:
+        window = get_window(win_type, win_len, fftbins=True)  # **0.5
+
+    N = fft_len
+    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
+    real_kernel = np.real(fourier_basis)
+    imag_kernel = np.imag(fourier_basis)
+    kernel = np.concatenate([real_kernel, imag_kernel], 1).T
+
+    if invers:
+        kernel = np.linalg.pinv(kernel).T
+
+    kernel = kernel * window
+    kernel = kernel[:, None, :]
+    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
+
+
+class ConviSTFT(nn.Module):
+    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
+        super(ConviSTFT, self).__init__()
+        if fft_len is None:
+            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
+        else:
+            self.fft_len = fft_len
+        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
+        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
+        self.register_buffer('weight', kernel)
+        self.feature_type = feature_type
+        self.win_type = win_type
+        self.win_len = win_len
+        self.stride = win_inc
+        self.dim = self.fft_len
+        self.register_buffer('window', window)
+        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
+
+    def forward(self, inputs, t):
+        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
+        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
+        outputs = outputs / (coff + 1e-8)
+        # outputs = torch.where(coff == 0, outputs, outputs / coff)
+        # hard-coded edge trim matching this model's win/hop configuration
+        outputs = outputs[..., 768:-768]
+        return outputs
+
+
+class Generator_Noise(torch.nn.Module):
+    def __init__(self, hps):
+        super(Generator_Noise, self).__init__()
+        self.hps = hps
+        self.win_size = hps.data.win_size
+        self.hop_size = hps.data.hop_length
+        self.fft_size = hps.data.n_fft
+        self.istft_pre = Conv1d(hps.model.hidden_channels, hps.model.hidden_channels, 3, padding=1)
+
+        self.net = ConvReluNorm(hps.model.hidden_channels,
+                                hps.model.hidden_channels,
+                                hps.model.hidden_channels,
+                                hps.model.kernel_size,
+                                8,
+                                hps.model.p_dropout)
+
+        self.istft_amplitude = torch.nn.Conv1d(hps.model.hidden_channels, self.fft_size // 2 + 1, 1, 1)
+        self.window = torch.hann_window(self.win_size)
+        self.istft = ConviSTFT(self.win_size, self.hop_size, self.fft_size)
+
+    def forward(self, x, mask, t_window):
+        istft_x = x
+        istft_x = self.istft_pre(istft_x)
+
+        istft_x = self.net(istft_x) * mask
+
+        amp = self.istft_amplitude(istft_x).unsqueeze(-1)
+        # random phase in (-pi, pi]
+        phase = (torch.rand(amp.shape) * 2 * math.pi - math.pi).to(amp)
+
+        real = amp * torch.cos(phase)
+        imag = amp * torch.sin(phase)
+
+        '''
+        spec = torch.cat([real, imag], 1).squeeze(3)
+        print(spec.shape)
+        istft_x = self.istft(spec)
+
+        spec = torch.cat([real, imag], 3)
+        istft_x = torch.istft(spec, self.fft_size, self.hop_size, self.win_size, 
self.window.to(amp), True, + length=x.shape[2] * self.hop_size, return_complex=False) + ''' + spec = torch.cat([real, imag], 1).squeeze(3) + istft_x = self.istft(spec, t_window) + + return istft_x + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiFrequencyDiscriminator(nn.Module): + def __init__(self, + hop_lengths=[128, 256, 512], + hidden_channels=[256, 512, 512], + domain='double', mel_scale=True): + super(MultiFrequencyDiscriminator, self).__init__() + + self.stfts = nn.ModuleList([ + TorchSTFT(fft_size=x * 4, hop_size=x, win_size=x * 4, + normalized=True, domain=domain, mel_scale=mel_scale) + for x in hop_lengths]) + + self.domain = domain + if domain == 'double': + self.discriminators = nn.ModuleList([ + BaseFrequenceDiscriminator(2, c) + for x, c in zip(hop_lengths, hidden_channels)]) + else: + self.discriminators = nn.ModuleList([ + BaseFrequenceDiscriminator(1, c) + for x, c in zip(hop_lengths, hidden_channels)]) + + def forward(self, x): + scores, feats = list(), list() + 
for stft, layer in zip(self.stfts, self.discriminators): + # print(stft) + mag, phase = stft.transform(x.squeeze()) + if self.domain == 'double': + mag = torch.stack(torch.chunk(mag, 2, dim=1), dim=1) + else: + mag = mag.unsqueeze(1) + + score, feat = layer(mag) + scores.append(score) + feats.append(feat) + return scores, feats + + +class BaseFrequenceDiscriminator(nn.Module): + def __init__(self, in_channels, hidden_channels=512): + super(BaseFrequenceDiscriminator, self).__init__() + + self.discriminator = nn.ModuleList() + self.discriminator += [ + nn.Sequential( + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + in_channels, hidden_channels // 32, + kernel_size=(3, 3), stride=(1, 1))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 32, hidden_channels // 16, + kernel_size=(3, 3), stride=(2, 2))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 16, hidden_channels // 8, + kernel_size=(3, 3), stride=(1, 1))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 8, hidden_channels // 4, + kernel_size=(3, 3), stride=(2, 2))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 4, hidden_channels // 2, + kernel_size=(3, 3), stride=(1, 1))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 2, hidden_channels, + kernel_size=(3, 3), stride=(2, 2))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels, 1, + kernel_size=(3, 3), stride=(1, 1))) + ) + ] + + def forward(self, x): + hiddens = [] + for layer in self.discriminator: + x = layer(x) + hiddens.append(x) + return x, hiddens[-1] + + +class Discriminator(torch.nn.Module): + def __init__(self, hps, use_spectral_norm=False): + super(Discriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + # self.disc_multfrequency = MultiFrequencyDiscriminator(hop_lengths=[int(hps.data.sampling_rate * 2.5 / 1000), + # int(hps.data.sampling_rate * 5 / 1000), + # int(hps.data.sampling_rate * 7.5 / 1000), + # int(hps.data.sampling_rate * 10 / 1000), + # int(hps.data.sampling_rate * 12.5 / 1000), + # int(hps.data.sampling_rate * 15 / 1000)], + # hidden_channels=[256, 256, 256, 256, 256]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + # scores_r, fmaps_r = self.disc_multfrequency(y) + # scores_g, fmaps_g = self.disc_multfrequency(y_hat) + # for i in range(len(scores_r)): + # y_d_rs.append(scores_r[i]) + # y_d_gs.append(scores_g[i]) + # fmap_rs.append(fmaps_r[i]) + # fmap_gs.append(fmaps_g[i]) + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SynthesizerTrn(nn.Module): + """ + Model + """ + + def __init__(self, hps): + super().__init__() + self.hps = hps + + self.text_encoder = 
TextEncoder( + hps.data.c_dim, + hps.model.prior_hidden_channels, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout) + + self.decoder = PriorDecoder( + hps.model.hidden_channels * 2, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels + ) + + self.f0_decoder = F0Decoder( + 1, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels + ) + + self.mel_decoder = Decoder( + hps.data.acoustic_dim, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels + ) + + self.posterior_encoder = PosteriorEncoder( + hps, + hps.data.acoustic_dim, + hps.model.hidden_channels, + hps.model.hidden_channels, 3, 1, 8) + + self.dropout = nn.Dropout(0.2) + + self.LR = LengthRegulator() + + self.dec = Generator(hps, + hps.model.hidden_channels, + hps.model.resblock, + hps.model.resblock_kernel_sizes, + hps.model.resblock_dilation_sizes, + hps.model.upsample_rates, + hps.model.upsample_initial_channel, + hps.model.upsample_kernel_sizes, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels) + + self.dec_harm = Generator_Harm(hps) + + self.dec_noise = Generator_Noise(hps) + + self.f0_prenet = nn.Conv1d(1, hps.model.prior_hidden_channels , 3, padding=1) + self.energy_prenet = nn.Conv1d(1, hps.model.prior_hidden_channels , 3, padding=1) + self.mel_prenet = nn.Conv1d(hps.data.acoustic_dim, hps.model.prior_hidden_channels , 3, padding=1) + + if hps.data.n_speakers > 1: + self.emb_spk = nn.Embedding(hps.data.n_speakers, hps.model.spk_channels) + self.flow = modules.ResidualCouplingBlock(hps.model.prior_hidden_channels, hps.model.hidden_channels, 5, 1, 4,n_speakers=hps.data.n_speakers, gin_channels=hps.model.spk_channels) + + def forward(self, c, f0, mel2ph, t_window, noise=None, g=None): + if len(g.shape) == 2: + g = g.squeeze(0) + if len(f0.shape) == 2: + f0 = f0.unsqueeze(0) + g = self.emb_spk(g).unsqueeze(-1) # [b, h, 1] + + decoder_inp = F.pad(c, [0, 0, 1, 0]) + mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]]) + c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H] + + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + + # Encoder + decoder_input, x_mask = self.text_encoder(c, c_lengths) + y_lengths = c_lengths + + LF0 = 2595. * torch.log10(1. + f0 / 700.) 
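+        # 2595 * log10(1 + f/700) is the standard Hz-to-mel mapping; the
+        # division by 500 below rescales LF0 to a small range for the prenets.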
+        LF0 = LF0 / 500
+
+        # aam: predict an auxiliary mel spectrogram and a per-frame energy
+        # estimate, fed back as extra conditioning for the prior decoder
+        predict_mel, predict_bn_mask = self.mel_decoder(decoder_input + self.f0_prenet(LF0), y_lengths, spk_emb=g)
+        predict_energy = predict_mel.sum(1).unsqueeze(1) / self.hps.data.acoustic_dim
+
+        decoder_input = decoder_input + \
+                        self.f0_prenet(LF0) + \
+                        self.energy_prenet(predict_energy) + \
+                        self.mel_prenet(predict_mel)
+        decoder_output, y_mask = self.decoder(decoder_input, y_lengths, spk_emb=g)
+
+        prior_info = decoder_output
+
+        m_p = prior_info[:, :self.hps.model.hidden_channels, :]
+        logs_p = prior_info[:, self.hps.model.hidden_channels:, :]
+        z_p = m_p + torch.exp(logs_p) * noise
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+
+        prior_z = z
+
+        noise_x = self.dec_noise(prior_z, y_mask, t_window)
+
+        harm_x = self.dec_harm(f0, prior_z, y_mask)
+
+        pitch = upsample(f0.transpose(1, 2), self.hps.data.hop_length)
+        omega = torch.cumsum(2 * math.pi * pitch / self.hps.data.sampling_rate, 1)
+        sin = torch.sin(omega).transpose(1, 2)
+
+        decoder_condition = torch.cat([harm_x, noise_x, sin], dim=1)
+
+        # DSP-conditioned HiFi-GAN vocoder
+        o = self.dec(prior_z, decoder_condition, g=g)
+
+        return o
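+
+
+# A minimal export sketch (an assumption, not part of this patch's verified
+# flow): how SynthesizerTrn might be traced to ONNX. The config path, tensor
+# shapes, and the t_window construction below are illustrative guesses.
+#
+#   hps = utils.get_hparams_from_file("configs/config.json")
+#   net = SynthesizerTrn(hps).eval()
+#   frames = 200
+#   c = torch.randn(1, frames, hps.data.c_dim)            # content features
+#   f0 = torch.rand(1, frames) * 400.0 + 100.0            # per-frame F0 in Hz
+#   mel2ph = torch.arange(1, frames + 1).unsqueeze(0)     # identity alignment
+#   t_window = torch.ones(1, hps.data.win_size, frames)   # iSTFT window frames
+#   noise = torch.randn(1, hps.model.hidden_channels, frames)
+#   g = torch.LongTensor([[0]])                           # speaker id
+#   with torch.no_grad():
+#       torch.onnx.export(net, (c, f0, mel2ph, t_window, noise, g), "model.onnx",
+#                         input_names=["c", "f0", "mel2ph", "t_window", "noise", "g"],
+#                         output_names=["audio"], opset_version=16)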