1046 lines
37 KiB
Python
1046 lines
37 KiB
Python
import sys
|
|
import copy
|
|
import math
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
|
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
|
import numpy as np
|
|
|
|
sys.path.append('../..')
|
|
import modules.commons as commons
|
|
import modules.modules as modules
|
|
import modules.attentions as attentions
|
|
|
|
from modules.commons import init_weights, get_padding
|
|
|
|
from modules.ddsp import mlp, gru, scale_function, remove_above_nyquist, upsample
|
|
from modules.ddsp import harmonic_synth, amp_to_impulse_response, fft_convolve
|
|
from modules.ddsp import resample
|
|
import utils
|
|
|
|
from modules.stft import TorchSTFT
|
|
|
|
import torch.distributions as D
|
|
|
|
from modules.losses import (
|
|
generator_loss,
|
|
discriminator_loss,
|
|
feature_loss,
|
|
kl_loss
|
|
)
|
|
|
|
# Negative slope passed to F.leaky_relu inside ResBlock3.forward.
LRELU_SLOPE = 0.1
|
|
|
|
|
|
class PostF0Decoder(nn.Module):
    """Two-stage convolutional predictor that emits a single output channel.

    Both the input and the optional conditioning tensor are detached on entry,
    so no gradient flows from this module back into whatever produced them.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, spk_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = spk_channels

        # kernel_size // 2 keeps the time length unchanged for odd kernels.
        same_pad = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # The conditioning projection only exists when a speaker width is given.
        if spk_channels != 0:
            self.cond = nn.Conv1d(spk_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked one-channel prediction for ``x``.

        Args:
            x: input features with ``in_channels`` channels on dim 1.
            x_mask: mask multiplied into the activations before every conv
                and into the final output.
            g: optional conditioning tensor with ``spk_channels`` channels;
                requires that the module was built with ``spk_channels != 0``.
        """
        hidden = x.detach()
        if g is not None:
            hidden = hidden + self.cond(g.detach())

        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            hidden = conv(hidden * x_mask)
            hidden = norm(torch.relu(hidden))
            hidden = self.drop(hidden)

        return self.proj(hidden * x_mask) * x_mask
|
|
|
|
|
|
class TextEncoder(nn.Module):
    """Linear pre-net, attention encoder and 1x1 output projection.

    The pre-net maps ``c_dim`` features to ``hidden_channels``; the encoder is
    ``attentions.Encoder``; the projection maps to ``out_channels``.
    """

    def __init__(self,
                 c_dim,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.pre_net = torch.nn.Linear(c_dim, hidden_channels)

        encoder_args = (hidden_channels, filter_channels, n_heads,
                        n_layers, kernel_size, p_dropout)
        self.encoder = attentions.Encoder(*encoder_args)
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_lengths):
        """Encode ``x`` and return ``(encoded, mask)``.

        Args:
            x: input whose dim 1 holds the ``c_dim`` features (it is swapped
                to the last dim for the linear layer and swapped back).
            x_lengths: per-item valid lengths handed to
                ``commons.sequence_mask``.
        """
        # Linear acts on the last dim, so move features there and back again.
        hidden = self.pre_net(x.transpose(1, -1)).transpose(1, -1)  # [b, h, t]

        valid = commons.sequence_mask(x_lengths, hidden.size(2))
        x_mask = valid.unsqueeze(1).to(hidden.dtype)

        encoded = self.encoder(hidden * x_mask, x_mask)
        return self.proj(encoded) * x_mask, x_mask
|
|
|
|
|
|
def pad_v2(input_ele, mel_max_length=None):
    """Zero-pad tensors along dim 0 to a common length and stack them.

    Args:
        input_ele: sequence of 1-D or 2-D tensors; dim 0 is the axis padded.
        mel_max_length: target length. When falsy (``None`` or ``0``) the
            longest element's length is used instead.

    Returns:
        A tensor of shape ``[len(input_ele), max_len]`` for 1-D elements or
        ``[len(input_ele), max_len, feat]`` for 2-D elements.

    Raises:
        ValueError: if an element is neither 1-D nor 2-D. Previously such an
            element either hit an unbound local or silently re-appended the
            previous element's padded tensor.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(ele.size(0) for ele in input_ele)

    padded = []
    for ele in input_ele:
        missing = max_len - ele.size(0)
        if ele.dim() == 1:
            pad_spec = (0, missing)
        elif ele.dim() == 2:
            # F.pad lists the last dim first: leave features alone, pad dim 0.
            pad_spec = (0, 0, 0, missing)
        else:
            raise ValueError(
                "pad_v2 supports only 1-D or 2-D tensors, got %d-D" % ele.dim()
            )
        padded.append(F.pad(ele, pad_spec, "constant", 0.0))

    return torch.stack(padded)
|
|
|
|
|
|
class LengthRegulator(nn.Module):
    """Length regulator: repeat every time step by its own repeat count.

    Each repeated frame also gets two extra features appended: its index
    inside the repeated run and the run length.
    """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand every item of the batch and pad the results to one length.

        Args:
            x: batch whose dims 1 and 2 are swapped before expansion, so that
                iterating an item yields one vector per time step.
            duration: per-item repeat counts, iterated in lockstep with ``x``.
            max_len: length forwarded to ``pad_v2``; ``None`` pads to the
                longest expanded item.

        Returns:
            ``(output, lengths)`` where ``output`` has dims 1 and 2 swapped
            back and ``lengths`` is a LongTensor of the expanded lengths.
        """
        x = torch.transpose(x, 1, 2)
        expanded = [self.expand(seq, reps) for seq, reps in zip(x, duration)]
        mel_len = [seq.shape[0] for seq in expanded]

        if max_len is not None:
            output = pad_v2(expanded, max_len)
        else:
            output = pad_v2(expanded)
        output = torch.transpose(output, 1, 2)
        return output, torch.LongTensor(mel_len)

    def expand(self, batch, predicted):
        """Repeat each row of ``batch`` and append (index, run-length) columns.

        Args:
            batch: 2-D tensor ``[steps, channels]``.
            predicted: repeat count per step; any shape that flattens to
                ``steps`` entries. Counts are truncated to ``int`` and
                negative counts are treated as zero.

        Returns:
            Tensor ``[sum(counts), channels + 2]``.
        """
        # reshape(-1) rather than squeeze(): squeeze() turns a single-step
        # duration vector into a 0-dim tensor that cannot be indexed.
        predicted = predicted.reshape(-1)
        out = []

        for i, vec in enumerate(batch):
            # Clamp once and use the same count everywhere; the clamp used to
            # be applied only to the expand() call, so a negative count
            # crashed in torch.arange and a float count in list repetition.
            repeat = max(int(predicted[i].item()), 0)

            index = torch.arange(repeat, dtype=torch.float32, device=vec.device)
            run_length = torch.full_like(index, float(repeat))
            state_info = torch.stack([index, run_length], dim=1)

            new_vec = vec.expand(repeat, -1)
            out.append(torch.cat([new_vec, state_info], 1))

        return torch.cat(out, 0)

    def forward(self, x, duration, max_len):
        """Alias of :meth:`LR`."""
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
|
|
|
|
|
|
class PriorDecoder(nn.Module):
    """Conv pre-net, ``attentions.FFT`` stack and 1x1 projection.

    The projection emits ``out_bn_channels`` channels. A speaker-conditioning
    1x1 convolution is created only when ``n_speakers != 0``.
    """

    def __init__(self,
                 out_bn_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 n_speakers=0,
                 spk_channels=0):
        super().__init__()
        self.out_bn_channels = out_bn_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.spk_channels = spk_channels

        self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
        fft_args = (hidden_channels, filter_channels, n_heads,
                    n_layers, kernel_size, p_dropout)
        self.decoder = attentions.FFT(*fft_args)
        self.proj = nn.Conv1d(hidden_channels, out_bn_channels, 1)

        if n_speakers != 0:
            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)

    def forward(self, x, x_lengths, spk_emb=None):
        """Return ``(projected, mask)`` for the input sequence.

        Args:
            x: features with ``hidden_channels`` channels on dim 1.
            x_lengths: valid lengths handed to ``commons.sequence_mask``.
            spk_emb: optional speaker embedding added after the pre-net;
                requires that ``self.cond`` exists (``n_speakers != 0``).
        """
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = valid.unsqueeze(1).to(x.dtype)

        hidden = self.prenet(x) * x_mask
        if spk_emb is not None:
            hidden = hidden + self.cond(spk_emb)

        hidden = self.decoder(hidden * x_mask, x_mask)
        return self.proj(hidden) * x_mask, x_mask
|
|
|
|
|
|
class Decoder(nn.Module):
    """Conv pre-net, ``attentions.FFT`` stack and 1x1 projection on a detached input.

    The input is detached on entry, so gradients from this module do not
    reach whatever produced it.
    """

    def __init__(self,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 n_speakers=0,
                 spk_channels=0,
                 in_channels=None):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.spk_channels = spk_channels

        # The pre-net input width defaults to the hidden width.
        prenet_in = hidden_channels if in_channels is None else in_channels
        self.prenet = nn.Conv1d(prenet_in, hidden_channels, 3, padding=1)
        fft_args = (hidden_channels, filter_channels, n_heads,
                    n_layers, kernel_size, p_dropout)
        self.decoder = attentions.FFT(*fft_args)
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

        if n_speakers != 0:
            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)

    def forward(self, x, x_lengths, spk_emb=None):
        """Return ``(decoded, mask)``.

        Args:
            x: input features (detached before use).
            x_lengths: valid lengths handed to ``commons.sequence_mask``.
            spk_emb: optional speaker embedding added after the pre-net;
                requires that ``self.cond`` exists (``n_speakers != 0``).
        """
        detached = x.detach()
        valid = commons.sequence_mask(x_lengths, detached.size(2))
        x_mask = valid.unsqueeze(1).to(detached.dtype)

        hidden = self.prenet(detached) * x_mask
        if spk_emb is not None:
            hidden = hidden + self.cond(spk_emb)

        hidden = self.decoder(hidden * x_mask, x_mask)
        return self.proj(hidden) * x_mask, x_mask
|
|
|
|
class F0Decoder(nn.Module):
    """FFT-based decoder that adds an embedded F0 contour to a detached input.

    ``f0_prenet`` is a ``Conv1d`` with one input channel, so ``norm_f0`` is
    expected to carry a single channel on dim 1.
    """

    def __init__(self,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 n_speakers=0,
                 spk_channels=0,
                 in_channels=None):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.spk_channels = spk_channels

        # The pre-net input width defaults to the hidden width.
        self.prenet = nn.Conv1d(in_channels if in_channels is not None else hidden_channels, hidden_channels, 3, padding=1)
        self.decoder = attentions.FFT(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout)
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)

        if n_speakers != 0:
            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)

    def forward(self, x, norm_f0, x_lengths, spk_emb=None):
        """Return ``(decoded, mask)``.

        Args:
            x: input features; detached before use and never modified.
            norm_f0: F0 contour with one channel on dim 1.
            x_lengths: valid lengths handed to ``commons.sequence_mask``.
            spk_emb: optional speaker embedding added after the pre-net;
                requires that ``self.cond`` exists (``n_speakers != 0``).
        """
        # detach() returns a tensor that shares storage with the caller's
        # ``x``. The previous in-place ``+=`` therefore wrote the F0 embedding
        # into the caller's tensor; an out-of-place add leaves it untouched.
        x = torch.detach(x) + self.f0_prenet(norm_f0)
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x) * x_mask

        if spk_emb is not None:
            x = x + self.cond(spk_emb)

        x = self.decoder(x * x_mask, x_mask)

        x = self.proj(x) * x_mask

        return x, x_mask
|
|
|
|
|
|
class ConvReluNorm(nn.Module):
    """Stack of conv -> layer-norm -> ReLU -> dropout layers with averaging skips.

    The first layer maps ``in_channels`` to ``hidden_channels``. Every later
    layer's output is averaged with its input. The final 1x1 projection is
    zero-initialised, so a freshly built module outputs zeros.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message used to say "larger than 0", contradicting the check.
        assert n_layers > 1, "Number of layers should be larger than 1."

        same_pad = kernel_size // 2

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=same_pad))
        self.norm_layers.append(LayerNorm(hidden_channels))
        # One activation/dropout module is shared by every layer.
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=same_pad))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x):
        """Apply all layers to ``x`` and return the projected result."""
        x = self.conv_layers[0](x)
        x = self.norm_layers[0](x)
        x = self.relu_drop(x)

        for i in range(1, self.n_layers):
            x_ = self.conv_layers[i](x)
            x_ = self.norm_layers[i](x_)
            x_ = self.relu_drop(x_)
            # Average, rather than sum, the skip path and the new branch.
            x = (x + x_) / 2
        x = self.proj(x)
        return x
|
|
|
|
|
|
class PosteriorEncoder(nn.Module):
    """1x1 pre-conv, ``modules.WN`` encoder and 1x1 projection.

    The projection emits ``out_channels * 2`` channels, returned as one
    ``stats`` tensor.
    """

    def __init__(self,
                 hps,
                 in_channels,
                 out_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        # Speaker settings come from the hyper-parameter object.
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            n_speakers=hps.data.n_speakers,
            spk_channels=hps.model.spk_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(stats, mask)`` for the input sequence.

        Args:
            x: input with ``in_channels`` channels on dim 1.
            x_lengths: valid lengths handed to ``commons.sequence_mask``.
            g: optional conditioning forwarded to the WN encoder.
        """
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = valid.unsqueeze(1).to(x.dtype)

        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        return self.proj(hidden) * x_mask, x_mask
|
|
|
|
|
|
class ResBlock3(torch.nn.Module):
    """Residual block with a single weight-normalised dilated convolution.

    Only ``dilation[0]`` is used; later entries of ``dilation`` are ignored.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock3, self).__init__()
        first_dilation = dilation[0]
        conv = Conv1d(
            channels,
            channels,
            kernel_size,
            1,
            dilation=first_dilation,
            padding=get_padding(kernel_size, first_dilation),
        )
        self.convs = nn.ModuleList([weight_norm(conv)])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply leaky-ReLU -> (mask) -> conv with a residual connection."""
        has_mask = x_mask is not None
        for conv in self.convs:
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if has_mask:
                branch = branch * x_mask
            x = conv(branch) + x
        if has_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm re-parameterisation from every conv."""
        # Inside a method the bare name resolves to the function imported at
        # module level, not to this method.
        for conv in self.convs:
            remove_weight_norm(conv)
|
|
|
|
|
|
class Generator_Harm(torch.nn.Module):
    """Predict per-harmonic amplitudes and synthesise a bank of sinusoids.

    ``postnet`` emits ``n_harmonic + 1`` channels: the first is used as a
    total amplitude, the rest as the harmonic distribution.
    """

    def __init__(self, hps):
        super(Generator_Harm, self).__init__()
        self.hps = hps

        width = hps.model.hidden_channels
        self.prenet = Conv1d(width, width, 3, padding=1)

        self.net = ConvReluNorm(width,
                                width,
                                width,
                                hps.model.kernel_size,
                                8,
                                hps.model.p_dropout)

        self.postnet = Conv1d(width, hps.model.n_harmonic + 1, 3, padding=1)

    def forward(self, f0, harm, mask):
        """Return the harmonic signals with the harmonic index on dim 1.

        Args:
            f0: pitch track; dims 1 and 2 are swapped before use.
            harm: hidden features fed to the amplitude network.
            mask: mask multiplied into the network output.
        """
        sample_rate = self.hps.data.sampling_rate
        hop = self.hps.data.hop_length

        pitch = f0.transpose(1, 2)

        hidden = self.prenet(harm)
        hidden = self.net(hidden) * mask
        hidden = self.postnet(hidden)

        param = scale_function(hidden.transpose(1, 2))
        total_amp = param[..., :1]
        amplitudes = remove_above_nyquist(param[..., 1:], pitch, sample_rate)
        # Normalise the harmonic distribution, then scale by the total amplitude.
        amplitudes /= amplitudes.sum(-1, keepdim=True)
        amplitudes *= total_amp

        amplitudes = upsample(amplitudes, hop)
        pitch = upsample(pitch, hop)

        # Integrate the per-step frequency into a running phase, then scale it
        # by the harmonic numbers 1..n.
        harmonic_count = amplitudes.shape[-1]
        phase = torch.cumsum(2 * math.pi * pitch / sample_rate, 1)
        harmonic_numbers = torch.arange(1, harmonic_count + 1).to(phase)
        signal_harmonics = torch.sin(phase * harmonic_numbers) * amplitudes
        return signal_harmonics.transpose(1, 2)
|
|
|
|
|
|
class Generator(torch.nn.Module):
    """Upsampling vocoder conditioned on a multi-resolution DSP signal.

    The conditioning signal (``n_harmonic + 2`` channels) is downsampled
    stage by stage with strided convolutions. Each resolution is concatenated
    into the matching stage of the transposed-convolution upsampling path.
    """

    def __init__(self, hps, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                 upsample_initial_channel, upsample_kernel_sizes, n_speakers=0, spk_channels=0):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.upsample_rates = upsample_rates
        self.n_speakers = n_speakers

        # NOTE(review): `modules.R` looks like a truncated attribute name
        # (possibly `modules.ResBlock2`); it is only evaluated when
        # resblock != '1'. Confirm against modules.py before using that path.
        resblock = modules.ResBlock1 if resblock == '1' else modules.R

        cond_channels = hps.model.n_harmonic + 2

        # Downsampling path: walks the upsample configuration in reverse.
        self.downs = nn.ModuleList()
        for i in range(min(len(upsample_rates), len(upsample_kernel_sizes))):
            rev = len(upsample_rates) - 1 - i
            u = upsample_rates[rev]
            k = upsample_kernel_sizes[rev]
            self.downs.append(weight_norm(
                Conv1d(cond_channels, cond_channels,
                       k, u, padding=k // 2)))

        self.resblocks_downs = nn.ModuleList()
        for _ in range(len(self.downs)):
            self.resblocks_downs.append(ResBlock3(cond_channels, 3, (1, 3)))

        self.concat_pre = Conv1d(upsample_initial_channel + cond_channels, upsample_initial_channel, 3, 1,
                                 padding=1)
        self.concat_conv = nn.ModuleList()
        for i in range(len(upsample_rates)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.concat_conv.append(Conv1d(ch + cond_channels, ch, 3, 1, padding=1, bias=False))

        # Upsampling path: the channel count halves at every stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        # `ch` is the channel count of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if self.n_speakers != 0:
            self.cond = nn.Conv1d(spk_channels, upsample_initial_channel, 1)

    def forward(self, x, ddsp, g=None):
        """Generate a one-channel signal squashed by ``tanh``.

        Args:
            x: latent features with ``initial_channel`` channels.
            ddsp: conditioning signal with ``n_harmonic + 2`` channels.
            g: optional speaker conditioning; requires ``n_speakers != 0``.
        """
        x = self.conv_pre(x)

        if g is not None:
            x = x + self.cond(g)

        # Build the conditioning pyramid; res_features[0] is the finest level.
        se = ddsp
        res_features = [se]
        for i in range(self.num_upsamples):
            in_size = se.size(2)
            se = self.downs[i](se)
            se = self.resblocks_downs[i](se)
            up_rate = self.upsample_rates[self.num_upsamples - 1 - i]
            se = se[:, :, : in_size // up_rate]
            res_features.append(se)

        x = torch.cat([x, se], 1)
        x = self.concat_pre(x)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            in_size = x.size(2)
            x = self.ups[i](x)
            # Trim the time axis to exactly in_size * rate so it lines up
            # with the conditioning feature of the same resolution.
            x = x[:, :, : in_size * self.upsample_rates[i]]

            x = torch.cat([x, res_features[self.num_upsamples - 1 - i]], 1)
            x = self.concat_conv[i](x)

            # Average the outputs of this stage's parallel residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from every weight-normalised submodule.

        ``self.downs`` and ``self.resblocks_downs`` are built with
        ``weight_norm`` as well, so they are handled here too; previously
        they were left re-parameterised.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        for l in self.downs:
            remove_weight_norm(l)
        for l in self.resblocks_downs:
            l.remove_weight_norm()
|
|
|
|
|
|
from scipy.signal import get_window
|
|
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build windowed real-DFT kernels for a convolution-based (i)STFT.

    Args:
        win_len: window length; also the width of each kernel row.
        win_inc: hop size (accepted for interface symmetry, not used here).
        fft_len: FFT size; the kernel has ``2 * (fft_len // 2 + 1)`` rows.
        win_type: window name for ``scipy.signal.get_window``; ``None`` or the
            string ``'None'`` selects a rectangular window.
        invers: when true, use the transposed pseudo-inverse of the basis.

    Returns:
        ``(kernel, window)`` as float32 tensors shaped ``[rows, 1, win_len]``
        and ``[1, win_len, 1]``.
    """
    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)

    # One-sided DFT of the identity gives one complex basis vector per row;
    # real and imaginary parts are stacked into a single real-valued matrix.
    basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    kernel = np.concatenate([basis.real, basis.imag], axis=1).T

    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = (kernel * window)[:, None, :]
    kernel_tensor = torch.from_numpy(kernel.astype(np.float32))
    window_tensor = torch.from_numpy(window[None, :, None].astype(np.float32))
    return kernel_tensor, window_tensor
|
|
|
|
|
|
class ConviSTFT(nn.Module):
    """Inverse STFT implemented with transposed 1-D convolutions.

    The synthesis kernel comes from ``init_kernels(..., invers=True)`` and is
    stored as a non-trainable buffer.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # Next power of two >= win_len. The builtin int replaces np.int,
            # which was removed in NumPy 1.24 and raised AttributeError here.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        # Identity kernel: each input channel is placed at its own offset, so
        # a transposed conv with it overlap-adds the channels.
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, t):
        """Synthesise a waveform from stacked real/imaginary spectra.

        Args:
            inputs: spectrum whose dim 1 matches the kernel's row count.
            t: tensor with ``win_len`` channels that is overlap-added to give
                the per-sample normaliser (presumably the squared window
                tiled over frames; confirm against the caller).
        """
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # NOTE(review): hard-coded trim of 768 samples per side; presumably
        # tied to one specific win_len/hop configuration. Confirm before
        # using other sizes.
        outputs = outputs[..., 768:-768]
        return outputs
|
|
|
|
|
|
class Generator_Noise(torch.nn.Module):
    """Predict a magnitude spectrum, attach random phase and invert it.

    The inverse transform is the convolutional ``ConviSTFT`` built from the
    window, hop and FFT sizes in ``hps.data``.
    """

    def __init__(self, hps):
        super(Generator_Noise, self).__init__()
        self.hps = hps
        self.win_size = hps.data.win_size
        self.hop_size = hps.data.hop_length
        self.fft_size = hps.data.n_fft
        self.istft_pre = Conv1d(hps.model.hidden_channels, hps.model.hidden_channels, 3, padding=1)

        self.net = ConvReluNorm(hps.model.hidden_channels,
                                hps.model.hidden_channels,
                                hps.model.hidden_channels,
                                hps.model.kernel_size,
                                8,
                                hps.model.p_dropout)

        self.istft_amplitude = torch.nn.Conv1d(hps.model.hidden_channels, self.fft_size // 2 + 1, 1, 1)
        # Plain attribute (not a buffer); forward() does not read it.
        self.window = torch.hann_window(self.win_size)
        self.istft = ConviSTFT(self.win_size, self.hop_size, self.fft_size)

    def forward(self, x, mask, t_window):
        """Return the noise-branch waveform.

        Args:
            x: hidden features with ``hps.model.hidden_channels`` channels.
            mask: mask multiplied into the network output.
            t_window: normaliser input forwarded to ``ConviSTFT.forward``.
        """
        istft_x = self.istft_pre(x)
        istft_x = self.net(istft_x) * mask

        amp = self.istft_amplitude(istft_x).unsqueeze(-1)
        # Uniform phase in [-pi, pi), sampled with amp's device and dtype.
        # This used to be drawn on the CPU with 3.14 as an approximation of
        # pi and then copied over on every call.
        phase = (torch.rand_like(amp) * 2 - 1) * math.pi

        real = amp * torch.cos(phase)
        imag = amp * torch.sin(phase)

        # Stack real and imaginary parts on the channel axis for ConviSTFT.
        spec = torch.cat([real, imag], 1).squeeze(3)
        istft_x = self.istft(spec, t_window)

        return istft_x
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation over dim 1 (channels) of a channels-first tensor."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable per-channel scale and shift.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalise ``x`` across dim 1 and return a tensor of the same shape."""
        # F.layer_norm works on trailing dims, so swap channels to the end,
        # normalise, and swap back.
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(
            channels_last, (self.channels,), self.gamma, self.beta, self.eps
        )
        return normed.transpose(1, -1)
|
|
|
|
|
|
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into 2-D by ``period``.

    After folding, five ``(kernel_size, 1)`` convolutions and a final
    ``(3, 1)`` convolution are applied.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # NOTE(review): get_padding is evaluated once here instead of once
        # per layer; this assumes it has no side effects.
        pad = (get_padding(kernel_size, 1), 0)
        # (in_channels, out_channels, stride) per layer; the last layer is
        # not strided.
        layer_specs = [
            (1, 32, (stride, 1)),
            (32, 128, (stride, 1)),
            (128, 512, (stride, 1)),
            (512, 1024, (stride, 1)),
            (1024, 1024, 1),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), step, padding=pad))
            for c_in, c_out, step in layer_specs
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for a ``[b, c, t]`` input."""
        fmap = []

        # 1d to 2d: reflect-pad the tail so the length divides by the period.
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
|
|
|
|
|
|
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of grouped, strided 1-D convolutions."""

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in, out, kernel, stride, groups, padding) for each layer.
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, kernel, step, groups=groups, padding=pad))
            for c_in, c_out, kernel, step, groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for the input waveform."""
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
|
|
|
|
|
|
class MultiFrequencyDiscriminator(nn.Module):
    """Run one spectrogram discriminator per STFT resolution.

    For every hop length an STFT with ``fft_size = win_size = 4 * hop`` is
    paired with a ``BaseFrequenceDiscriminator``.
    """

    def __init__(self,
                 hop_lengths=(128, 256, 512),
                 hidden_channels=(256, 512, 512),
                 domain='double', mel_scale=True):
        # The defaults are tuples so the shared default objects cannot be
        # mutated; any sequence is still accepted.
        super(MultiFrequencyDiscriminator, self).__init__()

        self.stfts = nn.ModuleList([
            TorchSTFT(fft_size=hop * 4, hop_size=hop, win_size=hop * 4,
                      normalized=True, domain=domain, mel_scale=mel_scale)
            for hop in hop_lengths])

        self.domain = domain
        # In the 'double' domain forward() splits the STFT output into two
        # stacked channels; otherwise a single channel is used.
        in_channels = 2 if domain == 'double' else 1
        self.discriminators = nn.ModuleList([
            BaseFrequenceDiscriminator(in_channels, c)
            for _, c in zip(hop_lengths, hidden_channels)])

    def forward(self, x):
        """Return ``(scores, feats)``, one entry per resolution.

        Args:
            x: waveform, presumably ``[batch, 1, time]`` (confirm against the
                caller).
        """
        scores, feats = list(), list()
        for stft, layer in zip(self.stfts, self.discriminators):
            # squeeze(1) removes only the channel axis. A bare squeeze() also
            # removed the batch axis whenever the batch size was 1.
            mag, phase = stft.transform(x.squeeze(1))
            if self.domain == 'double':
                mag = torch.stack(torch.chunk(mag, 2, dim=1), dim=1)
            else:
                mag = mag.unsqueeze(1)

            score, feat = layer(mag)
            scores.append(score)
            feats.append(feat)
        return scores, feats
|
|
|
|
|
|
class BaseFrequenceDiscriminator(nn.Module):
    """Seven-stage 2-D convolutional discriminator for spectrogram inputs.

    Channel widths grow from ``hidden_channels // 32`` to ``hidden_channels``
    and the last stage emits one channel. Every stage is reflection pad ->
    weight-normalised 3x3 conv; all but the first start with an in-place
    LeakyReLU(0.2). Strides alternate between 1 and 2.
    """

    def __init__(self, in_channels, hidden_channels=512):
        super(BaseFrequenceDiscriminator, self).__init__()

        widths = [in_channels]
        widths += [hidden_channels // div for div in (32, 16, 8, 4, 2)]
        widths += [hidden_channels, 1]
        strides = [(1, 1), (2, 2)] * 3 + [(1, 1)]

        self.discriminator = nn.ModuleList()
        for idx, stride in enumerate(strides):
            # The first stage has no activation in front of it.
            stage = [] if idx == 0 else [nn.LeakyReLU(0.2, True)]
            stage.append(nn.ReflectionPad2d((1, 1, 1, 1)))
            stage.append(nn.utils.weight_norm(nn.Conv2d(
                widths[idx], widths[idx + 1],
                kernel_size=(3, 3), stride=stride)))
            self.discriminator.append(nn.Sequential(*stage))

    def forward(self, x):
        """Return the final stage's output twice, as ``(score, feature)``."""
        for stage in self.discriminator:
            x = stage(x)
        return x, x
|
|
|
|
|
|
class Discriminator(torch.nn.Module):
    """Ensemble of one scale discriminator and five period discriminators.

    ``hps`` is accepted but not read by the active code. A multi-frequency
    (STFT) discriminator was once wired in here and is currently disabled.
    """

    def __init__(self, hps, use_spectral_norm=False):
        super(Discriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        members.extend(
            DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            for period in periods
        )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Score real (``y``) and generated (``y_hat``) audio with every member.

        Returns:
            ``(real_scores, fake_scores, real_fmaps, fake_fmaps)``, each a
            list with one entry per sub-discriminator.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            real_score, real_fmap = disc(y)
            fake_score, fake_fmap = disc(y_hat)
            y_d_rs.append(real_score)
            y_d_gs.append(fake_score)
            fmap_rs.append(real_fmap)
            fmap_gs.append(fake_fmap)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
|
|
|
|
|
class SynthesizerTrn(nn.Module):
    """Full synthesis model.

    It wires together the content encoder, the auxiliary mel decoder, the
    prior decoder, the reverse flow, the harmonic and noise DSP generators
    and the final vocoder.
    """

    def __init__(self, hps):
        super().__init__()
        self.hps = hps

        self.text_encoder = TextEncoder(
            hps.data.c_dim,
            hps.model.prior_hidden_channels,
            hps.model.prior_hidden_channels,
            hps.model.prior_filter_channels,
            hps.model.prior_n_heads,
            hps.model.prior_n_layers,
            hps.model.prior_kernel_size,
            hps.model.prior_p_dropout)

        # Emits 2 * hidden_channels channels: prior mean and log-scale.
        self.decoder = PriorDecoder(
            hps.model.hidden_channels * 2,
            hps.model.prior_hidden_channels,
            hps.model.prior_filter_channels,
            hps.model.prior_n_heads,
            hps.model.prior_n_layers,
            hps.model.prior_kernel_size,
            hps.model.prior_p_dropout,
            n_speakers=hps.data.n_speakers,
            spk_channels=hps.model.spk_channels
        )

        self.f0_decoder = F0Decoder(
            1,
            hps.model.prior_hidden_channels,
            hps.model.prior_filter_channels,
            hps.model.prior_n_heads,
            hps.model.prior_n_layers,
            hps.model.prior_kernel_size,
            hps.model.prior_p_dropout,
            n_speakers=hps.data.n_speakers,
            spk_channels=hps.model.spk_channels
        )

        self.mel_decoder = Decoder(
            hps.data.acoustic_dim,
            hps.model.prior_hidden_channels,
            hps.model.prior_filter_channels,
            hps.model.prior_n_heads,
            hps.model.prior_n_layers,
            hps.model.prior_kernel_size,
            hps.model.prior_p_dropout,
            n_speakers=hps.data.n_speakers,
            spk_channels=hps.model.spk_channels
        )

        self.posterior_encoder = PosteriorEncoder(
            hps,
            hps.data.acoustic_dim,
            hps.model.hidden_channels,
            hps.model.hidden_channels, 3, 1, 8)

        self.dropout = nn.Dropout(0.2)

        self.LR = LengthRegulator()

        self.dec = Generator(hps,
                             hps.model.hidden_channels,
                             hps.model.resblock,
                             hps.model.resblock_kernel_sizes,
                             hps.model.resblock_dilation_sizes,
                             hps.model.upsample_rates,
                             hps.model.upsample_initial_channel,
                             hps.model.upsample_kernel_sizes,
                             n_speakers=hps.data.n_speakers,
                             spk_channels=hps.model.spk_channels)

        self.dec_harm = Generator_Harm(hps)

        self.dec_noise = Generator_Noise(hps)

        self.f0_prenet = nn.Conv1d(1, hps.model.prior_hidden_channels, 3, padding=1)
        self.energy_prenet = nn.Conv1d(1, hps.model.prior_hidden_channels, 3, padding=1)
        self.mel_prenet = nn.Conv1d(hps.data.acoustic_dim, hps.model.prior_hidden_channels, 3, padding=1)

        # The speaker embedding table only exists for multi-speaker setups.
        if hps.data.n_speakers > 1:
            self.emb_spk = nn.Embedding(hps.data.n_speakers, hps.model.spk_channels)
        self.flow = modules.ResidualCouplingBlock(hps.model.prior_hidden_channels, hps.model.hidden_channels, 5, 1, 4, n_speakers=hps.data.n_speakers, gin_channels=hps.model.spk_channels)

    def forward(self, c, f0, mel2ph, t_window, noise=None, g=None):
        """Synthesise a waveform from content features and a pitch track.

        Args:
            c: content features; a zero row is prepended on dim 1 and rows
                are then gathered with ``mel2ph``, so index 0 selects that
                zero row.
            f0: pitch track; a 2-D tensor gets a leading batch axis.
            mel2ph: integer indices into dim 1 of the padded ``c``.
            t_window: normaliser input forwarded to the noise generator's
                inverse STFT.
            noise: sample used to draw from the prior; standard normal noise
                shaped like the prior mean is drawn when ``None``.
            g: speaker ids, or ``None`` to run without speaker conditioning.
                Passing ids requires ``self.emb_spk`` to exist
                (``hps.data.n_speakers > 1``).

        Returns:
            The vocoder output.
        """
        if len(f0.shape) == 2:
            f0 = f0.unsqueeze(0)
        # The signature defaults g to None, but the embedding lookup used to
        # run unconditionally and crashed on that default.
        if g is not None:
            if len(g.shape) == 2:
                g = g.squeeze(0)
            g = self.emb_spk(g).unsqueeze(-1)  # [b, h, 1]

        decoder_inp = F.pad(c, [0, 0, 1, 0])
        mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]])
        c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2)

        # Every item is treated as full length.
        c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)

        # Encoder
        decoder_input, x_mask = self.text_encoder(c, c_lengths)
        y_lengths = c_lengths

        # Mel-scale warping of f0, divided by 500.
        LF0 = 2595. * torch.log10(1. + f0 / 700.)
        LF0 = LF0 / 500
        # Embed the pitch once and reuse it for both consumers below.
        f0_emb = self.f0_prenet(LF0)

        # Auxiliary acoustic model: predicts a mel whose mean over channels
        # is used as an energy feature.
        predict_mel, predict_bn_mask = self.mel_decoder(decoder_input + f0_emb, y_lengths, spk_emb=g)
        predict_energy = predict_mel.sum(1).unsqueeze(1) / self.hps.data.acoustic_dim

        decoder_input = decoder_input + \
                        f0_emb + \
                        self.energy_prenet(predict_energy) + \
                        self.mel_prenet(predict_mel)
        decoder_output, y_mask = self.decoder(decoder_input, y_lengths, spk_emb=g)

        prior_info = decoder_output

        # Split into prior mean and log-scale, then sample.
        m_p = prior_info[:, :self.hps.model.hidden_channels, :]
        logs_p = prior_info[:, self.hps.model.hidden_channels:, :]
        if noise is None:
            # The signature allows None, which previously crashed in the
            # multiplication below.
            noise = torch.randn_like(m_p)
        z_p = m_p + torch.exp(logs_p) * noise
        z = self.flow(z_p, y_mask, g=g, reverse=True)

        prior_z = z

        noise_x = self.dec_noise(prior_z, y_mask, t_window)

        harm_x = self.dec_harm(f0, prior_z, y_mask)

        # Sample-rate sine at the fundamental, from the integrated phase.
        pitch = upsample(f0.transpose(1, 2), self.hps.data.hop_length)
        omega = torch.cumsum(2 * math.pi * pitch / self.hps.data.sampling_rate, 1)
        sin = torch.sin(omega).transpose(1, 2)

        decoder_condition = torch.cat([harm_x, noise_x, sin], dim=1)

        # dsp based HiFiGAN vocoder
        o = self.dec(prior_z, decoder_condition, g=g)

        return o
|