From fe1f733aff767902b14f08a54bcf66ecb143f41d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?=
 =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?=
 =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?=
 <40709280+NaruseMioShirakana@users.noreply.github.com>
Date: Fri, 10 Mar 2023 19:43:07 +0800
Subject: [PATCH] Add files via upload

---
 configs/config.json      |  113 ++--
 inference/infer_tool.py  |   12 +-
 modules/audio.py         |   92 +++
 modules/ddsp.py          |  190 ++++++
 modules/modules.py       |  535 ++++++++--------
 modules/stft.py          |  512 ++++++++++++++++
 modules/transforms.py    |  193 ++++++
 onnxexport/model_onnx.py | 1106 +++++++++++++++++++++++++++++++-------
 8 files changed, 2299 insertions(+), 454 deletions(-)
 create mode 100644 modules/audio.py
 create mode 100644 modules/ddsp.py
 create mode 100644 modules/stft.py
 create mode 100644 modules/transforms.py

diff --git a/configs/config.json b/configs/config.json
index f19d46d..550059a 100644
--- a/configs/config.json
+++ b/configs/config.json
@@ -1,62 +1,103 @@
 {
   "train": {
-    "log_interval": 200,
-    "eval_interval": 800,
+    "log_interval": 50,
+    "eval_interval": 1000,
     "seed": 1234,
+    "port": 8001,
     "epochs": 10000,
-    "learning_rate": 0.0001,
+    "learning_rate": 0.0002,
     "betas": [
       0.8,
       0.99
     ],
     "eps": 1e-09,
     "batch_size": 6,
+    "accumulation_steps": 1,
     "fp16_run": false,
-    "lr_decay": 0.999875,
+    "lr_decay": 0.998,
     "segment_size": 10240,
     "init_lr_ratio": 1,
     "warmup_epochs": 0,
     "c_mel": 45,
-    "c_kl": 1.0,
-    "use_sr": true,
-    "max_speclen": 512,
-    "port": "8001",
-    "keep_ckpts": 3
+    "keep_ckpts": 4
   },
   "data": {
-    "training_files": "filelists/train.txt",
-    "validation_files": "filelists/val.txt",
+    "data_dir": "dataset",
+    "dataset_type": "SingDataset",
+    "collate_type": "SingCollate",
+    "training_filelist": "filelists/train.txt",
+    "validation_filelist": "filelists/val.txt",
     "max_wav_value": 32768.0,
     "sampling_rate": 44100,
-    "filter_length": 2048,
+    "n_fft": 2048,
+    "fmin": 0,
+    "fmax": 22050,
     "hop_length": 512,
-    "win_length": 2048,
-    "n_mel_channels": 80,
-    "mel_fmin": 0.0,
-    "mel_fmax": 22050
-  },
-  "model": {
-    "inter_channels": 192,
-    "hidden_channels": 192,
-    "filter_channels": 768,
-    "n_heads": 2,
-    "n_layers": 6,
-    "kernel_size": 3,
-    "p_dropout": 0.1,
-    "resblock": "1",
-    "resblock_kernel_sizes": [3,7,11],
-    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
-    "upsample_rates": [ 8, 8, 2, 2, 2],
-    "upsample_initial_channel": 512,
-    "upsample_kernel_sizes": [16,16, 4, 4, 4],
-    "n_layers_q": 3,
-    "use_spectral_norm": false,
-    "gin_channels": 256,
-    "ssl_dim": 256,
+    "win_size": 2048,
+    "acoustic_dim": 80,
+    "c_dim": 256,
+    "min_level_db": -115,
+    "ref_level_db": 20,
+    "min_db": -115,
+    "max_abs_value": 4.0,
     "n_speakers": 200
   },
+  "model": {
+    "hidden_channels": 192,
+    "spk_channels": 192,
+    "filter_channels": 768,
+    "n_heads": 2,
+    "n_layers": 4,
+    "kernel_size": 3,
+    "p_dropout": 0.1,
+    "prior_hidden_channels": 192,
+    "prior_filter_channels": 768,
+    "prior_n_heads": 2,
+    "prior_n_layers": 4,
+    "prior_kernel_size": 3,
+    "prior_p_dropout": 0.1,
+    "resblock": "1",
+    "use_spectral_norm": false,
+    "resblock_kernel_sizes": [
+      3,
+      7,
+      11
+    ],
+    "resblock_dilation_sizes": [
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ],
+      [
+        1,
+        3,
+        5
+      ]
+    ],
+    "upsample_rates": [
+      8,
+      8,
+      4,
+      2
+    ],
+    "upsample_initial_channel": 256,
+    "upsample_kernel_sizes": [
+      16,
+      16,
+      8,
+      4
+    ],
+    "n_harmonic": 64,
+    "n_bands": 65
+  },
   "spk": {
-    "nyaru": 0,
+    "jishuang": 0,
     "huiyu": 1,
     "nen": 2,
     "paimon": 3,
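
The renamed keys matter downstream: modules/audio.py (added below) reads data.n_fft / data.win_size / data.acoustic_dim where the old config said filter_length / win_length / n_mel_channels. A minimal sketch of reading the new layout (the path is the repo default; nothing else is assumed beyond the keys shown above):

import json

with open("configs/config.json") as f:
    cfg = json.load(f)

# Audio front-end parameters now live under "data" with the new names
print(cfg["data"]["n_fft"], cfg["data"]["hop_length"], cfg["data"]["win_size"])
# Speaker-name -> id table consumed by infer_tool.Svc via self.spk2id
print(cfg["spk"]["huiyu"])  # -> 1
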
"paimon": 3, diff --git a/inference/infer_tool.py b/inference/infer_tool.py index 9d22a25..03b97ad 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -127,9 +127,8 @@ class Svc(object): def load_model(self): # 获取模型配置 self.net_g_ms = SynthesizerTrn( - self.hps_ms.data.filter_length // 2 + 1, - self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, - **self.hps_ms.model) + self.hps_ms + ) _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None) if "half" in self.net_g_path and torch.cuda.is_available(): _ = self.net_g_ms.half().eval().to(self.dev) @@ -167,17 +166,14 @@ class Svc(object): cluster_infer_ratio=0, auto_predict_f0=False, noice_scale=0.4): - speaker_id = self.spk2id.__dict__.get(speaker) - if not speaker_id and type(speaker) is int: - if len(self.spk2id.__dict__) >= speaker: - speaker_id = speaker + speaker_id = self.spk2id[speaker] sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0) c, f0, uv = self.get_unit_f0(raw_path, tran, cluster_infer_ratio, speaker) if "half" in self.net_g_path and torch.cuda.is_available(): c = c.half() with torch.no_grad(): start = time.time() - audio = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale)[0,0].data.float() + audio = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale)[0][0,0].data.float() use_time = time.time() - start print("vits use time:{}".format(use_time)) return audio, audio.shape[-1] diff --git a/modules/audio.py b/modules/audio.py new file mode 100644 index 0000000..01b1f6a --- /dev/null +++ b/modules/audio.py @@ -0,0 +1,99 @@ +import numpy as np +from numpy import linalg as LA +import librosa +from scipy.io import wavfile +import soundfile as sf +import librosa.filters + + +def load_wav(wav_path, raw_sr, target_sr=16000, win_size=800, hop_size=200): + audio = librosa.core.load(wav_path, sr=raw_sr)[0] + if raw_sr != target_sr: + audio = librosa.core.resample(audio, + raw_sr, + target_sr, + res_type='kaiser_best') + target_length = (audio.size // hop_size + + win_size // hop_size) * hop_size + pad_len = (target_length - audio.size) // 2 + if audio.size % 2 == 0: + audio = np.pad(audio, (pad_len, pad_len), mode='reflect') + else: + audio = np.pad(audio, (pad_len, pad_len + 1), mode='reflect') + return audio + + +def save_wav(wav, path, sample_rate, norm=False): + if norm: + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + wavfile.write(path, sample_rate, wav.astype(np.int16)) + else: + sf.write(path, wav, sample_rate) + + +_mel_basis = None +_inv_mel_basis = None + + +def _build_mel_basis(hparams): + assert hparams.fmax <= hparams.sampling_rate // 2 + return librosa.filters.mel(hparams.sampling_rate, + hparams.n_fft, + n_mels=hparams.acoustic_dim, + fmin=hparams.fmin, + fmax=hparams.fmax) + + +def _linear_to_mel(spectogram, hparams): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis(hparams) + return np.dot(_mel_basis, spectogram) + + +def _mel_to_linear(mel_spectrogram, hparams): + global _inv_mel_basis + if _inv_mel_basis is None: + _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams)) + return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) + + +def _stft(y, hparams): + return librosa.stft(y=y, + n_fft=hparams.n_fft, + hop_length=hparams.hop_length, + win_length=hparams.win_size) + + +def _amp_to_db(x, hparams): + min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _normalize(S, hparams): + 
diff --git a/modules/ddsp.py b/modules/ddsp.py
new file mode 100644
index 0000000..7ffd7bb
--- /dev/null
+++ b/modules/ddsp.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torch.fft as fft
+import numpy as np
+import librosa as li
+import crepe
+import math
+from scipy.signal import get_window
+
+
+def safe_log(x):
+    return torch.log(x + 1e-7)
+
+
+@torch.no_grad()
+def mean_std_loudness(dataset):
+    mean = 0
+    std = 0
+    n = 0
+    for _, _, l in dataset:
+        n += 1
+        mean += (l.mean().item() - mean) / n
+        std += (l.std().item() - std) / n
+    return mean, std
+
+
+def multiscale_fft(signal, scales, overlap):
+    stfts = []
+    for s in scales:
+        S = torch.stft(
+            signal,
+            s,
+            int(s * (1 - overlap)),
+            s,
+            torch.hann_window(s).to(signal),
+            True,
+            normalized=True,
+            return_complex=True,
+        ).abs()
+        stfts.append(S)
+    return stfts
+
+
+def resample(x, factor: int):
+    batch, frame, channel = x.shape
+    x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)
+
+    window = torch.hann_window(
+        factor * 2,
+        dtype=x.dtype,
+        device=x.device,
+    ).reshape(1, 1, -1)
+    y = torch.zeros(x.shape[0], x.shape[1], factor * x.shape[2]).to(x)
+    y[..., ::factor] = x
+    y[..., -1:] = x[..., -1:]
+    y = torch.nn.functional.pad(y, [factor, factor])
+    y = torch.nn.functional.conv1d(y, window)[..., :-1]
+
+    y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1)
+
+    return y
+
+
+def upsample(signal, factor):
+    signal = signal.permute(0, 2, 1)
+    signal = nn.functional.interpolate(signal, size=signal.shape[-1] * factor)
+    return signal.permute(0, 2, 1)
+
+
+def remove_above_nyquist(amplitudes, pitch, sampling_rate):
+    n_harm = amplitudes.shape[-1]
+    pitches = pitch * torch.arange(1, n_harm + 1).to(pitch)
+    aa = (pitches < sampling_rate / 2).float() + 1e-4
+    return amplitudes * aa
+
+
+def scale_function(x):
+    return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7
+
+
+def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
+    S = li.stft(
+        signal,
+        n_fft=n_fft,
+        hop_length=block_size,
+        win_length=n_fft,
+        center=True,
+    )
+    S = np.log(abs(S) + 1e-7)
+    f = li.fft_frequencies(sampling_rate, n_fft)
+    a_weight = li.A_weighting(f)
+
+    S = S + a_weight.reshape(-1, 1)
+
+    S = np.mean(S, 0)[..., :-1]
+
+    return S
+
+
+def extract_pitch(signal, sampling_rate, block_size):
+    length = signal.shape[-1] // block_size
+    f0 = crepe.predict(
+        signal,
+        sampling_rate,
+        step_size=int(1000 * block_size / sampling_rate),
+        verbose=1,
+        center=True,
+        viterbi=True,
+    )
+    f0 = f0[1].reshape(-1)[:-1]
+
+    if f0.shape[-1] != length:
+        f0 = np.interp(
+            np.linspace(0, 1, length, endpoint=False),
+            np.linspace(0, 1, f0.shape[-1], endpoint=False),
+            f0,
+        )
+
+    return f0
+
+
+def mlp(in_size, hidden_size, n_layers):
+    channels = [in_size] + (n_layers) * [hidden_size]
+    net = []
+    for i in range(n_layers):
+        net.append(nn.Linear(channels[i], channels[i + 1]))
+        net.append(nn.LayerNorm(channels[i + 1]))
+        net.append(nn.LeakyReLU())
+    return nn.Sequential(*net)
+
+
+def gru(n_input, hidden_size):
+    return nn.GRU(n_input * hidden_size, hidden_size, batch_first=True)
+
+
+def harmonic_synth(pitch, amplitudes, sampling_rate):
+    n_harmonic = amplitudes.shape[-1]
+    omega = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
+    omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
+    signal = (torch.sin(omegas) * amplitudes).sum(-1, keepdim=True)
+    return signal
+
+
+def amp_to_impulse_response(amp, target_size):
+    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
+    amp = torch.view_as_complex(amp)
+    amp = fft.irfft(amp)
+
+    filter_size = amp.shape[-1]
+
+    amp = torch.roll(amp, filter_size // 2, -1)
+    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)
+
+    amp = amp * win
+
+    amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
+    amp = torch.roll(amp, -filter_size // 2, -1)
+
+    return amp
+
+
+def fft_convolve(signal, kernel):
+    signal = nn.functional.pad(signal, (0, signal.shape[-1]))
+    kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))
+
+    output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
+    output = output[..., output.shape[-1] // 2:]
+
+    return output
+
+
+def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
+    if win_type == 'None' or win_type is None:
+        window = np.ones(win_len)
+    else:
+        window = get_window(win_type, win_len, fftbins=True)  # **0.5
+
+    N = fft_len
+    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
+    real_kernel = np.real(fourier_basis)
+    imag_kernel = np.imag(fourier_basis)
+    kernel = np.concatenate([real_kernel, imag_kernel], 1).T
+
+    if invers:
+        kernel = np.linalg.pinv(kernel).T
+
+    kernel = kernel * window
+    kernel = kernel[:, None, :]
+    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
+
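
The DDSP helpers operate on frame-rate controls of shape [batch, frames, features]; upsample() brings them to sample rate before harmonic_synth() integrates the phase. A minimal sketch with made-up control tensors (shapes inferred from the functions above, matching the config's hop of 512 and n_harmonic of 64):

import torch

from modules.ddsp import harmonic_synth, remove_above_nyquist, scale_function, upsample

sr, block, n_harmonic = 44100, 512, 64
f0 = torch.full((1, 100, 1), 220.0)           # frame-rate pitch in Hz
amps = scale_function(torch.randn(1, 100, n_harmonic))
amps = remove_above_nyquist(amps, f0, sr)     # mute harmonics above sr / 2
amps = amps / amps.sum(-1, keepdim=True)      # normalize the harmonic distribution
audio = harmonic_synth(upsample(f0, block), upsample(amps, block), sr)
print(audio.shape)  # torch.Size([1, 51200, 1])
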
diff --git a/modules/modules.py b/modules/modules.py
index 54290fd..6dcba59 100644
--- a/modules/modules.py
+++ b/modules/modules.py
@@ -5,182 +5,187 @@
 import scipy
 import torch
 from torch import nn
 from torch.nn import functional as F
+from torch.autograd import Function
+from typing import Any, Optional, Tuple
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm
 
 import modules.commons as commons
 from modules.commons import init_weights, get_padding
-
+from modules.transforms import piecewise_rational_quadratic_transform
 
 LRELU_SLOPE = 0.1
 
 
 class LayerNorm(nn.Module):
-  def __init__(self, channels, eps=1e-5):
-    super().__init__()
-    self.channels = channels
-    self.eps = eps
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
 
-    self.gamma = nn.Parameter(torch.ones(channels))
-    self.beta = nn.Parameter(torch.zeros(channels))
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
 
-  def forward(self, x):
-    x = x.transpose(1, -1)
-    x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-    return x.transpose(1, -1)
-
 
 class ConvReluNorm(nn.Module):
-  def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
-    super().__init__()
-    self.in_channels = in_channels
-    self.hidden_channels = hidden_channels
-    self.out_channels = out_channels
-    self.kernel_size = kernel_size
-    self.n_layers = n_layers
-    self.p_dropout = p_dropout
-    assert n_layers > 1, "Number of layers should be larger than 0."
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 0."
 
-    self.conv_layers = nn.ModuleList()
-    self.norm_layers = nn.ModuleList()
-    self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
-    self.norm_layers.append(LayerNorm(hidden_channels))
-    self.relu_drop = nn.Sequential(
-        nn.ReLU(),
-        nn.Dropout(p_dropout))
-    for _ in range(n_layers-1):
-      self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
-      self.norm_layers.append(LayerNorm(hidden_channels))
-    self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
-    self.proj.weight.data.zero_()
-    self.proj.bias.data.zero_()
+        self.conv_layers = nn.ModuleList()
+        self.norm_layers = nn.ModuleList()
+        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = nn.Sequential(
+            nn.ReLU(),
+            nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
 
-  def forward(self, x, x_mask):
-    x_org = x
-    for i in range(self.n_layers):
-      x = self.conv_layers[i](x * x_mask)
-      x = self.norm_layers[i](x)
-      x = self.relu_drop(x)
-    x = x_org + self.proj(x)
-    return x * x_mask
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
 
 
 class DDSConv(nn.Module):
-  """
-  Dialted and Depth-Separable Convolution
-  """
-  def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
-    super().__init__()
-    self.channels = channels
-    self.kernel_size = kernel_size
-    self.n_layers = n_layers
-    self.p_dropout = p_dropout
+    """
+    Dilated and Depth-Separable Convolution
+    """
 
-    self.drop = nn.Dropout(p_dropout)
-    self.convs_sep = nn.ModuleList()
-    self.convs_1x1 = nn.ModuleList()
-    self.norms_1 = nn.ModuleList()
-    self.norms_2 = nn.ModuleList()
-    for i in range(n_layers):
-      dilation = kernel_size ** i
-      padding = (kernel_size * dilation - dilation) // 2
-      self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
-          groups=channels, dilation=dilation, padding=padding
-      ))
-      self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-      self.norms_1.append(LayerNorm(channels))
-      self.norms_2.append(LayerNorm(channels))
+    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
+        super().__init__()
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
 
-  def forward(self, x, x_mask, g=None):
-    if g is not None:
-      x = x + g
-    for i in range(self.n_layers):
-      y = self.convs_sep[i](x * x_mask)
-      y = self.norms_1[i](y)
-      y = F.gelu(y)
-      y = self.convs_1x1[i](y)
-      y = self.norms_2[i](y)
-      y = F.gelu(y)
-      y = self.drop(y)
-      x = x + y
-    return x * x_mask
+        self.drop = nn.Dropout(p_dropout)
+        self.convs_sep = nn.ModuleList()
+        self.convs_1x1 = nn.ModuleList()
+        self.norms_1 = nn.ModuleList()
+        self.norms_2 = nn.ModuleList()
+        for i in range(n_layers):
+            dilation = kernel_size ** i
+            padding = (kernel_size * dilation - dilation) // 2
+            self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
+                                            groups=channels, dilation=dilation, padding=padding
+                                            ))
+            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+            self.norms_1.append(LayerNorm(channels))
+            self.norms_2.append(LayerNorm(channels))
+
+    def forward(self, x, x_mask, g=None):
+        if g is not None:
+            x = x + g
+        for i in range(self.n_layers):
+            y = self.convs_sep[i](x * x_mask)
+            y = self.norms_1[i](y)
+            y = F.gelu(y)
+            y = self.convs_1x1[i](y)
+            y = self.norms_2[i](y)
+            y = F.gelu(y)
+            y = self.drop(y)
+            x = x + y
+        return x * x_mask
 
 
 class WN(torch.nn.Module):
-  def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
-    super(WN, self).__init__()
-    assert(kernel_size % 2 == 1)
-    self.hidden_channels =hidden_channels
-    self.kernel_size = kernel_size,
-    self.dilation_rate = dilation_rate
-    self.n_layers = n_layers
-    self.gin_channels = gin_channels
-    self.p_dropout = p_dropout
+    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, n_speakers=0, spk_channels=0,
+                 p_dropout=0):
+        super(WN, self).__init__()
+        assert (kernel_size % 2 == 1)
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_speakers = n_speakers
+        self.spk_channels = spk_channels
+        self.p_dropout = p_dropout
 
-    self.in_layers = torch.nn.ModuleList()
-    self.res_skip_layers = torch.nn.ModuleList()
-    self.drop = nn.Dropout(p_dropout)
+        self.in_layers = torch.nn.ModuleList()
+        self.res_skip_layers = torch.nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
 
-    if gin_channels != 0:
-      cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
-      self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+        if n_speakers > 0:
+            cond_layer = torch.nn.Conv1d(spk_channels, 2 * hidden_channels * n_layers, 1)
+            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
 
-    for i in range(n_layers):
-      dilation = dilation_rate ** i
-      padding = int((kernel_size * dilation - dilation) / 2)
-      in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
-                                 dilation=dilation, padding=padding)
-      in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
-      self.in_layers.append(in_layer)
+        for i in range(n_layers):
+            dilation = dilation_rate ** i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
+                                       dilation=dilation, padding=padding)
+            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+            self.in_layers.append(in_layer)
 
-      # last one is not necessary
-      if i < n_layers - 1:
-        res_skip_channels = 2 * hidden_channels
-      else:
-        res_skip_channels = hidden_channels
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_channels
+            else:
+                res_skip_channels = hidden_channels
 
-      res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-      res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
-      self.res_skip_layers.append(res_skip_layer)
+            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
+                                                        name='weight')
+            self.res_skip_layers.append(res_skip_layer)
 
-  def forward(self, x, x_mask, g=None, **kwargs):
-    output = torch.zeros_like(x)
-    n_channels_tensor = torch.IntTensor([self.hidden_channels])
+    def forward(self, x, x_mask, g=None, **kwargs):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
 
-    if g is not None:
-      g = self.cond_layer(g)
+        if g is not None:
+            g = self.cond_layer(g)
 
-    for i in range(self.n_layers):
-      x_in = self.in_layers[i](x)
-      if g is not None:
-        cond_offset = i * 2 * self.hidden_channels
-        g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
-      else:
-        g_l = torch.zeros_like(x_in)
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
 
-      acts = commons.fused_add_tanh_sigmoid_multiply(
-          x_in,
-          g_l,
-          n_channels_tensor)
-      acts = self.drop(acts)
+            acts = commons.fused_add_tanh_sigmoid_multiply(
+                x_in,
+                g_l,
+                n_channels_tensor)
+            acts = self.drop(acts)
 
-      res_skip_acts = self.res_skip_layers[i](acts)
-      if i < self.n_layers - 1:
-        res_acts = res_skip_acts[:,:self.hidden_channels,:]
-        x = (x + res_acts) * x_mask
-        output = output + res_skip_acts[:,self.hidden_channels:,:]
-      else:
-        output = output + res_skip_acts
-    return output * x_mask
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, :self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels:, :]
+            else:
+                output = output + res_skip_acts
+        return output * x_mask
 
-  def remove_weight_norm(self):
-    if self.gin_channels != 0:
-      torch.nn.utils.remove_weight_norm(self.cond_layer)
-    for l in self.in_layers:
-      torch.nn.utils.remove_weight_norm(l)
-    for l in self.res_skip_layers:
-      torch.nn.utils.remove_weight_norm(l)
+    def remove_weight_norm(self):
+        if self.n_speakers > 0:
+            torch.nn.utils.remove_weight_norm(self.cond_layer)
+        for l in self.in_layers:
+            torch.nn.utils.remove_weight_norm(l)
+        for l in self.res_skip_layers:
+            torch.nn.utils.remove_weight_norm(l)
 
 
 class ResBlock1(torch.nn.Module):
@@ -256,87 +261,193 @@ class ResBlock2(torch.nn.Module):
 
 
 class Log(nn.Module):
-  def forward(self, x, x_mask, reverse=False, **kwargs):
-    if not reverse:
-      y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
-      logdet = torch.sum(-y, [1, 2])
-      return y, logdet
-    else:
-      x = torch.exp(x) * x_mask
-      return x
-
+    def forward(self, x, x_mask, reverse=False, **kwargs):
+        if not reverse:
+            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+            logdet = torch.sum(-y, [1, 2])
+            return y, logdet
+        else:
+            x = torch.exp(x) * x_mask
+            return x
+
 
 class Flip(nn.Module):
-  def forward(self, x, *args, reverse=False, **kwargs):
-    x = torch.flip(x, [1])
-    if not reverse:
-      logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-      return x, logdet
-    else:
-      return x
+    def forward(self, x, *args, reverse=False, **kwargs):
+        x = torch.flip(x, [1])
+        if not reverse:
+            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+            return x, logdet
+        else:
+            return x
 
 
 class ElementwiseAffine(nn.Module):
-  def __init__(self, channels):
-    super().__init__()
-    self.channels = channels
-    self.m = nn.Parameter(torch.zeros(channels,1))
-    self.logs = nn.Parameter(torch.zeros(channels,1))
+    def __init__(self, channels):
+        super().__init__()
+        self.channels = channels
+        self.m = nn.Parameter(torch.zeros(channels, 1))
+        self.logs = nn.Parameter(torch.zeros(channels, 1))
 
-  def forward(self, x, x_mask, reverse=False, **kwargs):
-    if not reverse:
-      y = self.m + torch.exp(self.logs) * x
-      y = y * x_mask
-      logdet = torch.sum(self.logs * x_mask, [1,2])
-      return y, logdet
-    else:
-      x = (x - self.m) * torch.exp(-self.logs) * x_mask
-      return x
+    def forward(self, x, x_mask, reverse=False, **kwargs):
+        if not reverse:
+            y = self.m + torch.exp(self.logs) * x
+            y = y * x_mask
+            logdet = torch.sum(self.logs * x_mask, [1, 2])
+            return y, logdet
+        else:
+            x = (x - self.m) * torch.exp(-self.logs) * x_mask
+            return x
 
 
 class ResidualCouplingLayer(nn.Module):
-  def __init__(self,
-      channels,
-      hidden_channels,
-      kernel_size,
-      dilation_rate,
-      n_layers,
-      p_dropout=0,
-      gin_channels=0,
-      mean_only=False):
-    assert channels % 2 == 0, "channels should be divisible by 2"
-    super().__init__()
-    self.channels = channels
-    self.hidden_channels = hidden_channels
-    self.kernel_size = kernel_size
-    self.dilation_rate = dilation_rate
-    self.n_layers = n_layers
-    self.half_channels = channels // 2
-    self.mean_only = mean_only
+    def __init__(self,
+                 channels,
+                 hidden_channels,
+                 kernel_size,
+                 dilation_rate,
+                 n_layers,
+                 p_dropout=0,
+                 n_speakers=0,
+                 spk_channels=0,
+                 mean_only=False):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
 
-    self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-    self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
-    self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-    self.post.weight.data.zero_()
-    self.post.bias.data.zero_()
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, n_speakers=n_speakers,
+                      spk_channels=spk_channels)
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
 
-  def forward(self, x, x_mask, g=None, reverse=False):
-    x0, x1 = torch.split(x, [self.half_channels]*2, 1)
-    h = self.pre(x0) * x_mask
-    h = self.enc(h, x_mask, g=g)
-    stats = self.post(h) * x_mask
-    if not self.mean_only:
-      m, logs = torch.split(stats, [self.half_channels]*2, 1)
-    else:
-      m = stats
-      logs = torch.zeros_like(m)
-
-    if not reverse:
-      x1 = m + x1 * torch.exp(logs) * x_mask
-      x = torch.cat([x0, x1], 1)
-      logdet = torch.sum(logs, [1,2])
-      return x, logdet
-    else:
-      x1 = (x1 - m) * torch.exp(-logs) * x_mask
-      x = torch.cat([x0, x1], 1)
-      return x
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
+class ResidualCouplingBlock(nn.Module):
+    def __init__(self,
+                 channels,
+                 hidden_channels,
+                 kernel_size,
+                 dilation_rate,
+                 n_layers,
+                 n_flows=4,
+                 n_speakers=0,
+                 gin_channels=0):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+        for i in range(n_flows):
+            self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+                                                    n_speakers=n_speakers, spk_channels=gin_channels, mean_only=True))
+            self.flows.append(Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class ConvFlow(nn.Module):
+    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
+        super().__init__()
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.num_bins = num_bins
+        self.tail_bound = tail_bound
+        self.half_channels = in_channels // 2
+
+        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
+        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0)
+        h = self.convs(h, x_mask, g=g)
+        h = self.proj(h) * x_mask
+
+        b, c, t = x0.shape
+        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
+
+        unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
+        unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
+        unnormalized_derivatives = h[..., 2 * self.num_bins:]
+
+        x1, logabsdet = piecewise_rational_quadratic_transform(x1,
+                                                               unnormalized_widths,
+                                                               unnormalized_heights,
+                                                               unnormalized_derivatives,
+                                                               inverse=reverse,
+                                                               tails='linear',
+                                                               tail_bound=self.tail_bound
+                                                               )
+
+        x = torch.cat([x0, x1], 1) * x_mask
+        logdet = torch.sum(logabsdet * x_mask, [1, 2])
+        if not reverse:
+            return x, logdet
+        else:
+            return x
+
+
+class ResStack(nn.Module):
+    def __init__(self, channel, kernel_size=3, base=3, nums=4):
+        super(ResStack, self).__init__()
+
+        self.layers = nn.ModuleList([
+            nn.Sequential(
+                nn.LeakyReLU(),
+                nn.utils.weight_norm(nn.Conv1d(channel, channel,
+                                               kernel_size=kernel_size, dilation=base ** i, padding=base ** i)),
+                nn.LeakyReLU(),
+                nn.utils.weight_norm(nn.Conv1d(channel, channel,
+                                               kernel_size=kernel_size, dilation=1, padding=1)),
+            )
+            for i in range(nums)
+        ])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = x + layer(x)
+        return x
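
The signature change from gin_channels to n_speakers/spk_channels ripples through WN and ResidualCouplingLayer, and the new ResidualCouplingBlock replaces the copy that the old onnxexport/model_onnx.py defined locally (see the diff below). A minimal round-trip sketch with assumed sizes matching the config:

import torch

from modules.modules import ResidualCouplingBlock

flow = ResidualCouplingBlock(channels=192, hidden_channels=192, kernel_size=5,
                             dilation_rate=1, n_layers=4,
                             n_speakers=200, gin_channels=192)
z = torch.randn(2, 192, 50)
x_mask = torch.ones(2, 1, 50)
g = torch.randn(2, 192, 1)                    # speaker conditioning vector
z_p = flow(z, x_mask, g=g)                    # training direction
z_rec = flow(z_p, x_mask, g=g, reverse=True)  # inference direction
print(torch.allclose(z, z_rec, atol=1e-4))    # coupling layers are invertible
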
r"""Create a frequency bin conversion matrix. + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_mels (int): Number of mel filterbanks + sample_rate (int): Sample rate of the audio waveform + norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + + Returns: + Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) + meaning number of frequencies to highlight/apply to x the number of filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A * create_fb_matrix(A.size(-1), ...)``. + """ + + if norm is not None and norm != "slaney": + raise ValueError("norm must be one of None or 'slaney'") + + # freq bins + # Equivalent filterbank construction by Librosa + all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) + + # calculate mel freq bins + # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) + m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) + m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) + m_pts = torch.linspace(m_min, m_max, n_mels + 2) + # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) + f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) + # calculate the difference between each mel point and each stft freq point in hertz + f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) + slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2) + # create overlapping triangles + down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) + up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) + fb = torch.min(down_slopes, up_slopes) + fb = torch.clamp(fb, 1e-6, 1) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels]) + fb *= enorm.unsqueeze(0) + return fb + + +def lfilter( + waveform: Tensor, + a_coeffs: Tensor, + b_coeffs: Tensor, + clamp: bool = True, +) -> Tensor: + r"""Perform an IIR filter by evaluating difference equation. + + Args: + waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``. + Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. + Must be same size as b_coeffs (pad with 0's as necessary). + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``. + Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. + Must be same size as a_coeffs (pad with 0's as necessary). + clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) + + Returns: + Tensor: Waveform with dimension of ``(..., time)``. 
+ """ + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + + assert (a_coeffs.size(0) == b_coeffs.size(0)) + assert (len(waveform.size()) == 2) + assert (waveform.device == a_coeffs.device) + assert (b_coeffs.device == a_coeffs.device) + + device = waveform.device + dtype = waveform.dtype + n_channel, n_sample = waveform.size() + n_order = a_coeffs.size(0) + n_sample_padded = n_sample + n_order - 1 + assert (n_order > 0) + + # Pad the input and create output + padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device) + padded_waveform[:, (n_order - 1):] = waveform + padded_output_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device) + + # Set up the coefficients matrix + # Flip coefficients' order + a_coeffs_flipped = a_coeffs.flip(0) + b_coeffs_flipped = b_coeffs.flip(0) + + # calculate windowed_input_signal in parallel + # create indices of original with shape (n_channel, n_order, n_sample) + window_idxs = torch.arange(n_sample, device=device).unsqueeze(0) + torch.arange(n_order, device=device).unsqueeze(1) + window_idxs = window_idxs.repeat(n_channel, 1, 1) + window_idxs += (torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded) + window_idxs = window_idxs.long() + # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample) + input_signal_windows = torch.matmul(b_coeffs_flipped, torch.take(padded_waveform, window_idxs)) + + input_signal_windows.div_(a_coeffs[0]) + a_coeffs_flipped.div_(a_coeffs[0]) + for i_sample, o0 in enumerate(input_signal_windows.t()): + windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)] + o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1) + padded_output_waveform[:, i_sample + n_order - 1] = o0 + + output = padded_output_waveform[:, (n_order - 1):] + + if clamp: + output = torch.clamp(output, min=-1., max=1.) + + # unpack batch + output = output.reshape(shape[:-1] + output.shape[-1:]) + + return output + + + +def biquad( + waveform: Tensor, + b0: float, + b1: float, + b2: float, + a0: float, + a1: float, + a2: float +) -> Tensor: + r"""Perform a biquad filter of input tensor. Initial conditions set to 0. + https://en.wikipedia.org/wiki/Digital_biquad_filter + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + b0 (float): numerator coefficient of current input, x[n] + b1 (float): numerator coefficient of input one time step ago x[n-1] + b2 (float): numerator coefficient of input two time steps ago x[n-2] + a0 (float): denominator coefficient of current output y[n], typically 1 + a1 (float): denominator coefficient of current output y[n-1] + a2 (float): denominator coefficient of current output y[n-2] + + Returns: + Tensor: Waveform with dimension of `(..., time)` + """ + + device = waveform.device + dtype = waveform.dtype + + output_waveform = lfilter( + waveform, + torch.tensor([a0, a1, a2], dtype=dtype, device=device), + torch.tensor([b0, b1, b2], dtype=dtype, device=device) + ) + return output_waveform + + + +def _dB2Linear(x: float) -> float: + return math.exp(x * math.log(10) / 20.0) + + +def highpass_biquad( + waveform: Tensor, + sample_rate: int, + cutoff_freq: float, + Q: float = 0.707 +) -> Tensor: + r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 
44100 (Hz) + cutoff_freq (float): filter cutoff frequency + Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform dimension of `(..., time)` + """ + w0 = 2 * math.pi * cutoff_freq / sample_rate + alpha = math.sin(w0) / 2. / Q + + b0 = (1 + math.cos(w0)) / 2 + b1 = -1 - math.cos(w0) + b2 = b0 + a0 = 1 + alpha + a1 = -2 * math.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + + +def lowpass_biquad( + waveform: Tensor, + sample_rate: int, + cutoff_freq: float, + Q: float = 0.707 +) -> Tensor: + r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation. + + Args: + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + cutoff_freq (float): filter cutoff frequency + Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + """ + w0 = 2 * math.pi * cutoff_freq / sample_rate + alpha = math.sin(w0) / 2 / Q + + b0 = (1 - math.cos(w0)) / 2 + b1 = 1 - math.cos(w0) + b2 = b0 + a0 = 1 + alpha + a1 = -2 * math.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + + +class MelScale(torch.nn.Module): + r"""Turn a normal STFT into a mel frequency STFT, using a conversion + matrix. This uses triangular filter banks. + + User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). + + Args: + n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) + n_stft (int, optional): Number of bins in STFT. Calculated from first input + if None is given. See ``n_fft`` in :class:`Spectrogram`. 
(Default: ``None``) + """ + __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] + + def __init__(self, + n_mels: int = 128, + sample_rate: int = 24000, + f_min: float = 0., + f_max: Optional[float] = None, + n_stft: Optional[int] = None) -> None: + super(MelScale, self).__init__() + self.n_mels = n_mels + self.sample_rate = sample_rate + self.f_max = f_max if f_max is not None else float(sample_rate // 2) + self.f_min = f_min + + assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) + + fb = torch.empty(0) if n_stft is None else create_fb_matrix( + n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) + self.register_buffer('fb', fb) + + def forward(self, specgram: Tensor) -> Tensor: + r""" + Args: + specgram (Tensor): A spectrogram STFT of dimension (..., freq, time). + + Returns: + Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). + """ + + # pack batch + shape = specgram.size() + specgram = specgram.reshape(-1, shape[-2], shape[-1]) + + if self.fb.numel() == 0: + tmp_fb = create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate) + # Attributes cannot be reassigned outside __init__ so workaround + self.fb.resize_(tmp_fb.size()) + self.fb.copy_(tmp_fb) + + # (channel, frequency, time).transpose(...) dot (frequency, n_mels) + # -> (channel, time, n_mels).transpose(...) + mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) + + # unpack batch + mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:]) + + return mel_specgram + + +class TorchSTFT(torch.nn.Module): + def __init__(self, fft_size, hop_size, win_size, + normalized=False, domain='linear', + mel_scale=False, ref_level_db=20, min_level_db=-100): + super().__init__() + self.fft_size = fft_size + self.hop_size = hop_size + self.win_size = win_size + self.ref_level_db = ref_level_db + self.min_level_db = min_level_db + self.window = torch.hann_window(win_size) + self.normalized = normalized + self.domain = domain + self.mel_scale = MelScale(n_mels=(fft_size // 2 + 1), + n_stft=(fft_size // 2 + 1)) if mel_scale else None + + def transform(self, x): + x_stft = torch.stft(x.to(torch.float32), self.fft_size, self.hop_size, self.win_size, + self.window.type_as(x), normalized=self.normalized) + real = x_stft[..., 0] + imag = x_stft[..., 1] + mag = torch.clamp(real ** 2 + imag ** 2, min=1e-7) + mag = torch.sqrt(mag) + phase = torch.atan2(imag, real) + + if self.mel_scale is not None: + mag = self.mel_scale(mag) + + if self.domain == 'log': + mag = 20 * torch.log10(mag) - self.ref_level_db + mag = torch.clamp((mag - self.min_level_db) / -self.min_level_db, 0, 1) + return mag, phase + elif self.domain == 'linear': + return mag, phase + elif self.domain == 'double': + log_mag = 20 * torch.log10(mag) - self.ref_level_db + log_mag = torch.clamp((log_mag - self.min_level_db) / -self.min_level_db, 0, 1) + return torch.cat((mag, log_mag), dim=1), phase + + def complex(self, x): + x_stft = torch.stft(x, self.fft_size, self.hop_size, self.win_size, + self.window.type_as(x), normalized=self.normalized) + real = x_stft[..., 0] + imag = x_stft[..., 1] + return real, imag + + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, + window='hann'): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = 
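
TorchSTFT packages magnitude/phase analysis with the same ref_level_db/min_level_db normalization used by modules/audio.py. A minimal sketch (assuming a torch version whose torch.stft still accepts the legacy real-valued output, i.e. without return_complex, as the code above does):

import torch

from modules.stft import TorchSTFT

stft = TorchSTFT(fft_size=2048, hop_size=512, win_size=2048, domain='log')
wav = torch.randn(1, 44100)        # [batch, samples]
mag, phase = stft.transform(wav)   # mag: [1, 1025, n_frames], clamped to [0, 1]
print(mag.shape, phase.shape)
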
diff --git a/modules/transforms.py b/modules/transforms.py
new file mode 100644
index 0000000..4793d67
--- /dev/null
+++ b/modules/transforms.py
@@ -0,0 +1,193 @@
+import torch
+from torch.nn import functional as F
+
+import numpy as np
+
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+def piecewise_rational_quadratic_transform(inputs,
+                                           unnormalized_widths,
+                                           unnormalized_heights,
+                                           unnormalized_derivatives,
+                                           inverse=False,
+                                           tails=None,
+                                           tail_bound=1.,
+                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+                                           min_derivative=DEFAULT_MIN_DERIVATIVE):
+    if tails is None:
+        spline_fn = rational_quadratic_spline
+        spline_kwargs = {}
+    else:
+        spline_fn = unconstrained_rational_quadratic_spline
+        spline_kwargs = {
+            'tails': tails,
+            'tail_bound': tail_bound
+        }
+
+    outputs, logabsdet = spline_fn(
+        inputs=inputs,
+        unnormalized_widths=unnormalized_widths,
+        unnormalized_heights=unnormalized_heights,
+        unnormalized_derivatives=unnormalized_derivatives,
+        inverse=inverse,
+        min_bin_width=min_bin_width,
+        min_bin_height=min_bin_height,
+        min_derivative=min_derivative,
+        **spline_kwargs
+    )
+    return outputs, logabsdet
+
+
+def searchsorted(bin_locations, inputs, eps=1e-6):
+    bin_locations[..., -1] += eps
+    return torch.sum(
+        inputs[..., None] >= bin_locations,
+        dim=-1
+    ) - 1
+
+
+def unconstrained_rational_quadratic_spline(inputs,
+                                            unnormalized_widths,
+                                            unnormalized_heights,
+                                            unnormalized_derivatives,
+                                            inverse=False,
+                                            tails='linear',
+                                            tail_bound=1.,
+                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
+    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+    outside_interval_mask = ~inside_interval_mask
+
+    outputs = torch.zeros_like(inputs)
+    logabsdet = torch.zeros_like(inputs)
+
+    if tails == 'linear':
+        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+        constant = np.log(np.exp(1 - min_derivative) - 1)
+        unnormalized_derivatives[..., 0] = constant
+        unnormalized_derivatives[..., -1] = constant
+
+        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        logabsdet[outside_interval_mask] = 0
+    else:
+        raise RuntimeError('{} tails are not implemented.'.format(tails))
+
+    outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
+        inputs=inputs[inside_interval_mask],
+        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+        inverse=inverse,
+        left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
+        min_bin_width=min_bin_width,
+        min_bin_height=min_bin_height,
+        min_derivative=min_derivative
+    )
+
+    return outputs, logabsdet
+
+
+def rational_quadratic_spline(inputs,
+                              unnormalized_widths,
+                              unnormalized_heights,
+                              unnormalized_derivatives,
+                              inverse=False,
+                              left=0., right=1., bottom=0., top=1.,
+                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+                              min_derivative=DEFAULT_MIN_DERIVATIVE):
+    if torch.min(inputs) < left or torch.max(inputs) > right:
+        raise ValueError('Input to a transform is not within its domain')
+
+    num_bins = unnormalized_widths.shape[-1]
+
+    if min_bin_width * num_bins > 1.0:
+        raise ValueError('Minimal bin width too large for the number of bins')
+    if min_bin_height * num_bins > 1.0:
+        raise ValueError('Minimal bin height too large for the number of bins')
+
+    widths = F.softmax(unnormalized_widths, dim=-1)
+    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+    cumwidths = torch.cumsum(widths, dim=-1)
+    cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
+    cumwidths = (right - left) * cumwidths + left
+    cumwidths[..., 0] = left
+    cumwidths[..., -1] = right
+    widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+    heights = F.softmax(unnormalized_heights, dim=-1)
+    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+    cumheights = torch.cumsum(heights, dim=-1)
+    cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
+    cumheights = (top - bottom) * cumheights + bottom
+    cumheights[..., 0] = bottom
+    cumheights[..., -1] = top
+    heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+    if inverse:
+        bin_idx = searchsorted(cumheights, inputs)[..., None]
+    else:
+        bin_idx = searchsorted(cumwidths, inputs)[..., None]
+
+    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+    delta = heights / widths
+    input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+    input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+    if inverse:
+        a = (((inputs - input_cumheights) * (input_derivatives
+                                             + input_derivatives_plus_one
+                                             - 2 * input_delta)
+              + input_heights * (input_delta - input_derivatives)))
+        b = (input_heights * input_derivatives
+             - (inputs - input_cumheights) * (input_derivatives
+                                              + input_derivatives_plus_one
+                                              - 2 * input_delta))
+        c = - input_delta * (inputs - input_cumheights)
+
+        discriminant = b.pow(2) - 4 * a * c
+        assert (discriminant >= 0).all()
+
+        root = (2 * c) / (-b - torch.sqrt(discriminant))
+        outputs = root * input_bin_widths + input_cumwidths
+
+        theta_one_minus_theta = root * (1 - root)
+        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+                                     * theta_one_minus_theta)
+        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
+                                                     + 2 * input_delta * theta_one_minus_theta
+                                                     + input_derivatives * (1 - root).pow(2))
+        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+        return outputs, -logabsdet
+    else:
+        theta = (inputs - input_cumwidths) / input_bin_widths
+        theta_one_minus_theta = theta * (1 - theta)
+
+        numerator = input_heights * (input_delta * theta.pow(2)
+                                     + input_derivatives * theta_one_minus_theta)
+        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+                                     * theta_one_minus_theta)
+        outputs = input_cumheights + numerator / denominator
+
+        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
+                                                     + 2 * input_delta * theta_one_minus_theta
+                                                     + input_derivatives * (1 - theta).pow(2))
+        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+        return outputs, logabsdet
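
piecewise_rational_quadratic_transform is the monotonic spline used by ConvFlow in modules/modules.py: with tails='linear' the transform is the identity outside ±tail_bound and invertible everywhere. A minimal round-trip sketch with random parameters:

import torch

from modules.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.rand(2, 64) * 8 - 4          # inputs inside the [-5, 5] tail bound
w = torch.randn(2, 64, num_bins)       # unnormalized bin widths
h = torch.randn(2, 64, num_bins)       # unnormalized bin heights
d = torch.randn(2, 64, num_bins - 1)   # unnormalized inner-knot derivatives
y, logabsdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails='linear', tail_bound=5.0)
x_rec, _ = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails='linear', tail_bound=5.0)
print(torch.allclose(x, x_rec, atol=1e-5))  # forward then inverse recovers x
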
cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (((inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta) + + input_heights * (input_delta - input_derivatives))) + b = (input_heights * input_derivatives + - (inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta)) + c = - input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/onnxexport/model_onnx.py b/onnxexport/model_onnx.py index e28bae9..b50406c 100644 --- a/onnxexport/model_onnx.py +++ b/onnxexport/model_onnx.py @@ -1,64 +1,379 @@ +import sys +import copy +import math import torch from torch import nn from torch.nn import functional as F - -import modules.attentions as attentions -import modules.commons as commons -import modules.modules as modules - from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +import numpy as np + +sys.path.append('../..') +import modules.commons as commons +import modules.modules as modules +import modules.attentions as attentions -import utils from modules.commons import init_weights, get_padding -from 
vdecoder.hifigan.models import Generator -from utils import f0_to_coarse + +from modules.ddsp import mlp, gru, scale_function, remove_above_nyquist, upsample +from modules.ddsp import harmonic_synth, amp_to_impulse_response, fft_convolve +from modules.ddsp import resample +import utils + +from modules.stft import TorchSTFT + +import torch.distributions as D + +from modules.losses import ( + generator_loss, + discriminator_loss, + feature_loss, + kl_loss +) + +LRELU_SLOPE = 0.1 -class ResidualCouplingBlock(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0): +class PostF0Decoder(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, spk_channels=0): super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels + + self.in_channels = in_channels + self.filter_channels = filter_channels self.kernel_size = kernel_size - self.dilation_rate = dilation_rate + self.p_dropout = p_dropout + self.gin_channels = spk_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if spk_channels != 0: + self.cond = nn.Conv1d(spk_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + c_dim, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append( - modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, - gin_channels=gin_channels, mean_only=True)) - self.flows.append(modules.Flip()) + self.pre_net = torch.nn.Linear(c_dim, hidden_channels) - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_lengths): + x = x.transpose(1,-1) + x = self.pre_net(x) + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.encoder(x * x_mask, x_mask) + x = self.proj(x) * x_mask + return x, x_mask + + +def pad_v2(input_ele, mel_max_length=None): + if mel_max_length: + max_len = mel_max_length + else: + max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) + + out_list = list() + for i, batch in enumerate(input_ele): + 
if len(batch.shape) == 1: + one_batch_padded = F.pad( + batch, (0, max_len - batch.size(0)), "constant", 0.0 + ) + elif len(batch.shape) == 2: + one_batch_padded = F.pad( + batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 + ) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +class LengthRegulator(nn.Module): + """ Length Regulator """ + + def __init__(self): + super(LengthRegulator, self).__init__() + + def LR(self, x, duration, max_len): + x = torch.transpose(x, 1, 2) + output = list() + mel_len = list() + for batch, expand_target in zip(x, duration): + expanded = self.expand(batch, expand_target) + output.append(expanded) + mel_len.append(expanded.shape[0]) + + if max_len is not None: + output = pad_v2(output, max_len) else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) + output = pad_v2(output) + output = torch.transpose(output, 1, 2) + return output, torch.LongTensor(mel_len) + + def expand(self, batch, predicted): + predicted = torch.squeeze(predicted) + out = list() + + for i, vec in enumerate(batch): + expand_size = predicted[i].item() + state_info_index = torch.unsqueeze(torch.arange(0, expand_size), 1).float() + state_info_length = torch.unsqueeze(torch.Tensor([expand_size] * expand_size), 1).float() + state_info = torch.cat([state_info_index, state_info_length], 1).to(vec.device) + new_vec = vec.expand(max(int(expand_size), 0), -1) + new_vec = torch.cat([new_vec, state_info], 1) + out.append(new_vec) + out = torch.cat(out, 0) + return out + + def forward(self, x, duration, max_len): + output, mel_len = self.LR(x, duration, max_len) + return output, mel_len + + +class PriorDecoder(nn.Module): + def __init__(self, + out_bn_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_speakers=0, + spk_channels=0): + super().__init__() + self.out_bn_channels = out_bn_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.spk_channels = spk_channels + + self.prenet = nn.Conv1d(hidden_channels , hidden_channels, 3, padding=1) + self.decoder = attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_bn_channels, 1) + + if n_speakers != 0: + self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) + + def forward(self, x, x_lengths, spk_emb=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x) * x_mask + + if (spk_emb is not None): + x = x + self.cond(spk_emb) + + x = self.decoder(x * x_mask, x_mask) + + bn = self.proj(x) * x_mask + + return bn, x_mask + + +class Decoder(nn.Module): + def __init__(self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_speakers=0, + spk_channels=0, + in_channels=None): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.spk_channels = spk_channels + + self.prenet = nn.Conv1d(in_channels if in_channels is not None else hidden_channels, hidden_channels, 3, padding=1) + self.decoder = attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + 
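+            # attentions.FFT is the FastSpeech-style feed-forward Transformer
+            # stack (despite the name, not a Fourier transform)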
n_layers,
+            kernel_size,
+            p_dropout)
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+        if n_speakers != 0:
+            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
+
+    def forward(self, x, x_lengths, spk_emb=None):
+        x = torch.detach(x)
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x) * x_mask
+
+        if spk_emb is not None:
+            x = x + self.cond(spk_emb)
+
+        x = self.decoder(x * x_mask, x_mask)
+
+        x = self.proj(x) * x_mask
+
+        return x, x_mask
+
+class F0Decoder(nn.Module):
+    def __init__(self,
+                 out_channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout,
+                 n_speakers=0,
+                 spk_channels=0,
+                 in_channels=None):
+        super().__init__()
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.spk_channels = spk_channels
+
+        self.prenet = nn.Conv1d(in_channels if in_channels is not None else hidden_channels, hidden_channels, 3, padding=1)
+        self.decoder = attentions.FFT(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout)
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
+
+        if n_speakers != 0:
+            self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
+
+    def forward(self, x, norm_f0, x_lengths, spk_emb=None):
+        x = torch.detach(x)
+        x += self.f0_prenet(norm_f0)
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x) * x_mask
+
+        if spk_emb is not None:
+            x = x + self.cond(spk_emb)
+
+        x = self.decoder(x * x_mask, x_mask)
+
+        x = self.proj(x) * x_mask
+
+        return x, x_mask
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
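+        # n_layers conv -> LayerNorm blocks; forward() averages each block's
+        # output with its input, (x + x_) / 2, as a residual connection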
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x): + x = self.conv_layers[0](x) + x = self.norm_layers[0](x) + x = self.relu_drop(x) + + for i in range(1, self.n_layers): + x_ = self.conv_layers[i](x) + x_ = self.norm_layers[i](x_) + x_ = self.relu_drop(x_) + x = (x + x_) / 2 + x = self.proj(x) return x -class Encoder(nn.Module): +class PosteriorEncoder(nn.Module): def __init__(self, + hps, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, - n_layers, - gin_channels=0): + n_layers): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -66,57 +381,320 @@ class Encoder(nn.Module): self.kernel_size = kernel_size self.dilation_rate = dilation_rate self.n_layers = n_layers - self.gin_channels = gin_channels self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, n_speakers=hps.data.n_speakers, spk_channels=hps.model.spk_channels) + # self.enc = ConvReluNorm(hidden_channels, + # hidden_channels, + # hidden_channels, + # kernel_size, + # n_layers, + # 0.1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x, x_lengths, g=None): - # print(x.shape,x_lengths.shape) x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask + return stats, x_mask -class TextEncoder(nn.Module): - def __init__(self, - out_channels, - hidden_channels, - kernel_size, - n_layers, - gin_channels=0, - filter_channels=None, - n_heads=None, - p_dropout=None): +class ResBlock3(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock3, self).__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator_Harm(torch.nn.Module): + def __init__(self, hps): + super(Generator_Harm, self).__init__() + self.hps = hps + + self.prenet = Conv1d(hps.model.hidden_channels, hps.model.hidden_channels, 3, padding=1) + + self.net = ConvReluNorm(hps.model.hidden_channels, + hps.model.hidden_channels, + hps.model.hidden_channels, + hps.model.kernel_size, + 8, + hps.model.p_dropout) + + # self.rnn = nn.LSTM(input_size=hps.model.hidden_channels, + # 
hidden_size=hps.model.hidden_channels,
+        #                    num_layers=1,
+        #                    bias=True,
+        #                    batch_first=True,
+        #                    dropout=0.5,
+        #                    bidirectional=True)
+        self.postnet = Conv1d(hps.model.hidden_channels, hps.model.n_harmonic + 1, 3, padding=1)
+
+    def forward(self, f0, harm, mask):
+        pitch = f0.transpose(1, 2)
+        harm = self.prenet(harm)
+
+        harm = self.net(harm) * mask
+        # harm = harm.transpose(1, 2)
+        # harm, (hs, hc) = self.rnn(harm)
+        # harm = harm.transpose(1, 2)
+
+        harm = self.postnet(harm)
+        harm = harm.transpose(1, 2)
+        param = harm
+
+        param = scale_function(param)
+        total_amp = param[..., :1]
+        amplitudes = param[..., 1:]
+        amplitudes = remove_above_nyquist(
+            amplitudes,
+            pitch,
+            self.hps.data.sampling_rate,
+        )
+        amplitudes /= amplitudes.sum(-1, keepdim=True)
+        amplitudes *= total_amp
+
+        amplitudes = upsample(amplitudes, self.hps.data.hop_length)
+        pitch = upsample(pitch, self.hps.data.hop_length)
+
+        n_harmonic = amplitudes.shape[-1]
+        omega = torch.cumsum(2 * math.pi * pitch / self.hps.data.sampling_rate, 1)
+        omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
+        signal_harmonics = (torch.sin(omegas) * amplitudes)
+        signal_harmonics = signal_harmonics.transpose(1, 2)
+        return signal_harmonics
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, hps, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
+                 upsample_initial_channel, upsample_kernel_sizes, n_speakers=0, spk_channels=0):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
+        self.upsample_rates = upsample_rates
+        self.n_speakers = n_speakers
+
+        resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
+
+        self.downs = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            i = len(upsample_rates) - 1 - i
+            u = upsample_rates[i]
+            k = upsample_kernel_sizes[i]
+            # print("down: ", upsample_initial_channel // (2 ** (i + 1)), " -> ", upsample_initial_channel // (2 ** i))
+            self.downs.append(weight_norm(
+                Conv1d(hps.model.n_harmonic + 2, hps.model.n_harmonic + 2,
+                       k, u, padding=k // 2)))
+
+        self.resblocks_downs = nn.ModuleList()
+        for i in range(len(self.downs)):
+            j = len(upsample_rates) - 1 - i
+            self.resblocks_downs.append(ResBlock3(hps.model.n_harmonic + 2, 3, (1, 3)))
+
+        self.concat_pre = Conv1d(upsample_initial_channel + hps.model.n_harmonic + 2, upsample_initial_channel, 3, 1,
+                                 padding=1)
+        self.concat_conv = nn.ModuleList()
+        for i in range(len(upsample_rates)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.concat_conv.append(Conv1d(ch + hps.model.n_harmonic + 2, ch, 3, 1, padding=1, bias=False))
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(weight_norm(
+                ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
+                                k, u, padding=(k - u) // 2)))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if self.n_speakers != 0:
+            self.cond = nn.Conv1d(spk_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, ddsp, g=None):
+
+        x = self.conv_pre(x)
+
+        if g is not None:
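+            # add the speaker embedding, broadcast over time, before fusing the DDSP branch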
+            x = x + self.cond(g)
+
+        se = ddsp
+        res_features = [se]
+        for i in range(self.num_upsamples):
+            in_size = se.size(2)
+            se = self.downs[i](se)
+            se = self.resblocks_downs[i](se)
+            up_rate = self.upsample_rates[self.num_upsamples - 1 - i]
+            se = se[:, :, : in_size // up_rate]
+            res_features.append(se)
+
+        x = torch.cat([x, se], 1)
+        x = self.concat_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            in_size = x.size(2)
+            x = self.ups[i](x)
+            # keep the time dimension consistent: drop the extra trailing samples
+            x = x[:, :, : in_size * self.upsample_rates[i]]
+
+            x = torch.cat([x, res_features[self.num_upsamples - 1 - i]], 1)
+            x = self.concat_conv[i](x)
+
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+from scipy.signal import get_window
+def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
+    if win_type == 'None' or win_type is None:
+        window = np.ones(win_len)
+    else:
+        window = get_window(win_type, win_len, fftbins=True)  # **0.5
+
+    N = fft_len
+    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
+    real_kernel = np.real(fourier_basis)
+    imag_kernel = np.imag(fourier_basis)
+    kernel = np.concatenate([real_kernel, imag_kernel], 1).T
+
+    if invers:
+        kernel = np.linalg.pinv(kernel).T
+
+    kernel = kernel * window
+    kernel = kernel[:, None, :]
+    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
+
+
+class ConviSTFT(nn.Module):
+    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
+        super(ConviSTFT, self).__init__()
+        if fft_len is None:
+            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
+        else:
+            self.fft_len = fft_len
+        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
+        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
+        self.register_buffer('weight', kernel)
+        self.feature_type = feature_type
+        self.win_type = win_type
+        self.win_len = win_len
+        self.stride = win_inc
+        self.dim = self.fft_len
+        self.register_buffer('window', window)
+        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
+
+    def forward(self, inputs, t):
+        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
+        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
+        outputs = outputs / (coff + 1e-8)
+        # outputs = torch.where(coff == 0, outputs, outputs / coff)
+        outputs = outputs[..., 768:-768]
+        return outputs
+
+
+class Generator_Noise(torch.nn.Module):
+    def __init__(self, hps):
+        super(Generator_Noise, self).__init__()
+        self.hps = hps
+        self.win_size = hps.data.win_size
+        self.hop_size = hps.data.hop_length
+        self.fft_size = hps.data.n_fft
+        self.istft_pre = Conv1d(hps.model.hidden_channels, hps.model.hidden_channels, 3, padding=1)
+
+        self.net = ConvReluNorm(hps.model.hidden_channels,
+                                hps.model.hidden_channels,
+                                hps.model.hidden_channels,
+                                hps.model.kernel_size,
+                                8,
+                                hps.model.p_dropout)
+
+        self.istft_amplitude = torch.nn.Conv1d(hps.model.hidden_channels, self.fft_size // 2 + 1, 1, 1)
+        self.window = torch.hann_window(self.win_size)
+        self.istft = ConviSTFT(self.win_size, self.hop_size, self.fft_size)
+
+    def forward(self, x,
mask, t_window): + istft_x = x + istft_x = self.istft_pre(istft_x) + + istft_x = self.net(istft_x) * mask + + amp = self.istft_amplitude(istft_x).unsqueeze(-1) + phase = (torch.rand(amp.shape) * 2 * 3.14 - 3.14).to(amp) + + real = amp * torch.cos(phase) + imag = amp * torch.sin(phase) + + ''' + spec = torch.cat([real, imag], 1).squeeze(3) + print(spec.shape) + istft_x = self.istft(spec) + + spec = torch.cat([real, imag], 3) + istft_x = torch.istft(spec, self.fft_size, self.hop_size, self.win_size, self.window.to(amp), True, + length=x.shape[2] * self.hop_size, return_complex=False) + ''' + spec = torch.cat([real, imag], 1).squeeze(3) + istft_x = self.istft(spec, t_window) + + return istft_x + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): super().__init__() - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.gin_channels = gin_channels - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - self.f0_emb = nn.Embedding(256, hidden_channels) + self.channels = channels + self.eps = eps - self.enc_ = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) - def forward(self, x, x_mask, f0=None, z=None): - x = x + self.f0_emb(f0).transpose(1, 2) - x = self.enc_(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + z * torch.exp(logs)) * x_mask - return z, m, logs, x_mask + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) class DiscriminatorP(torch.nn.Module): @@ -184,152 +762,284 @@ class DiscriminatorS(torch.nn.Module): return x, fmap -class F0Decoder(nn.Module): +class MultiFrequencyDiscriminator(nn.Module): def __init__(self, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - spk_channels=0): - super().__init__() - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.spk_channels = spk_channels + hop_lengths=[128, 256, 512], + hidden_channels=[256, 512, 512], + domain='double', mel_scale=True): + super(MultiFrequencyDiscriminator, self).__init__() - self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1) - self.decoder = attentions.FFT( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1) - self.cond = nn.Conv1d(spk_channels, hidden_channels, 1) + self.stfts = nn.ModuleList([ + TorchSTFT(fft_size=x * 4, hop_size=x, win_size=x * 4, + normalized=True, domain=domain, mel_scale=mel_scale) + for x in hop_lengths]) - def forward(self, x, norm_f0, x_mask, spk_emb=None): - x = torch.detach(x) - if spk_emb is not None: - x = x + self.cond(spk_emb) - x += self.f0_prenet(norm_f0) - x = self.prenet(x) * x_mask - x = self.decoder(x * x_mask, x_mask) - x = self.proj(x) * x_mask - return x + self.domain = domain + if domain == 'double': + self.discriminators = nn.ModuleList([ + BaseFrequenceDiscriminator(2, c) + for x, c in zip(hop_lengths, hidden_channels)]) + else: + 
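+            # magnitude-only input: one channel per STFT resolution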
self.discriminators = nn.ModuleList([ + BaseFrequenceDiscriminator(1, c) + for x, c in zip(hop_lengths, hidden_channels)]) + + def forward(self, x): + scores, feats = list(), list() + for stft, layer in zip(self.stfts, self.discriminators): + # print(stft) + mag, phase = stft.transform(x.squeeze()) + if self.domain == 'double': + mag = torch.stack(torch.chunk(mag, 2, dim=1), dim=1) + else: + mag = mag.unsqueeze(1) + + score, feat = layer(mag) + scores.append(score) + feats.append(feat) + return scores, feats + + +class BaseFrequenceDiscriminator(nn.Module): + def __init__(self, in_channels, hidden_channels=512): + super(BaseFrequenceDiscriminator, self).__init__() + + self.discriminator = nn.ModuleList() + self.discriminator += [ + nn.Sequential( + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + in_channels, hidden_channels // 32, + kernel_size=(3, 3), stride=(1, 1))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 32, hidden_channels // 16, + kernel_size=(3, 3), stride=(2, 2))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 16, hidden_channels // 8, + kernel_size=(3, 3), stride=(1, 1))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 8, hidden_channels // 4, + kernel_size=(3, 3), stride=(2, 2))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 4, hidden_channels // 2, + kernel_size=(3, 3), stride=(1, 1))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels // 2, hidden_channels, + kernel_size=(3, 3), stride=(2, 2))) + ), + nn.Sequential( + nn.LeakyReLU(0.2, True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.utils.weight_norm(nn.Conv2d( + hidden_channels, 1, + kernel_size=(3, 3), stride=(1, 1))) + ) + ] + + def forward(self, x): + hiddens = [] + for layer in self.discriminator: + x = layer(x) + hiddens.append(x) + return x, hiddens[-1] + + +class Discriminator(torch.nn.Module): + def __init__(self, hps, use_spectral_norm=False): + super(Discriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + # self.disc_multfrequency = MultiFrequencyDiscriminator(hop_lengths=[int(hps.data.sampling_rate * 2.5 / 1000), + # int(hps.data.sampling_rate * 5 / 1000), + # int(hps.data.sampling_rate * 7.5 / 1000), + # int(hps.data.sampling_rate * 10 / 1000), + # int(hps.data.sampling_rate * 12.5 / 1000), + # int(hps.data.sampling_rate * 15 / 1000)], + # hidden_channels=[256, 256, 256, 256, 256]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + # scores_r, fmaps_r = self.disc_multfrequency(y) + # scores_g, fmaps_g = self.disc_multfrequency(y_hat) + # for i in range(len(scores_r)): + # y_d_rs.append(scores_r[i]) + # y_d_gs.append(scores_g[i]) + # fmap_rs.append(fmaps_r[i]) + # fmap_gs.append(fmaps_g[i]) + return 
y_d_rs, y_d_gs, fmap_rs, fmap_gs class SynthesizerTrn(nn.Module): """ - Synthesizer for Training - """ + Model + """ - def __init__(self, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels, - ssl_dim, - n_speakers, - sampling_rate=44100, - **kwargs): + def __init__(self, hps): super().__init__() - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.gin_channels = gin_channels - self.ssl_dim = ssl_dim - self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.hps = hps - self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) + self.text_encoder = TextEncoder( + hps.data.c_dim, + hps.model.prior_hidden_channels, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout) - self.enc_p = TextEncoder( - inter_channels, - hidden_channels, - filter_channels=filter_channels, - n_heads=n_heads, - n_layers=n_layers, - kernel_size=kernel_size, - p_dropout=p_dropout + self.decoder = PriorDecoder( + hps.model.hidden_channels * 2, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels ) - hps = { - "sampling_rate": sampling_rate, - "inter_channels": inter_channels, - "resblock": resblock, - "resblock_kernel_sizes": resblock_kernel_sizes, - "resblock_dilation_sizes": resblock_dilation_sizes, - "upsample_rates": upsample_rates, - "upsample_initial_channel": upsample_initial_channel, - "upsample_kernel_sizes": upsample_kernel_sizes, - "gin_channels": gin_channels, - } - self.dec = Generator(h=hps) - self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + self.f0_decoder = F0Decoder( 1, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - spk_channels=gin_channels + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels ) - self.emb_uv = nn.Embedding(2, hidden_channels) - self.predict_f0 = False - def forward(self, c, f0, mel2ph, uv, noise=None, g=None): + self.mel_decoder = Decoder( + hps.data.acoustic_dim, + hps.model.prior_hidden_channels, + hps.model.prior_filter_channels, + hps.model.prior_n_heads, + hps.model.prior_n_layers, + hps.model.prior_kernel_size, + hps.model.prior_p_dropout, + 
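+            # this mel decoder drives the auxiliary acoustic-model ("aam")
+            # path: its predicted mel supplies the energy estimate and the
+            # mel prenet input in forward()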
n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels + ) + + self.posterior_encoder = PosteriorEncoder( + hps, + hps.data.acoustic_dim, + hps.model.hidden_channels, + hps.model.hidden_channels, 3, 1, 8) + + self.dropout = nn.Dropout(0.2) + + self.LR = LengthRegulator() + + self.dec = Generator(hps, + hps.model.hidden_channels, + hps.model.resblock, + hps.model.resblock_kernel_sizes, + hps.model.resblock_dilation_sizes, + hps.model.upsample_rates, + hps.model.upsample_initial_channel, + hps.model.upsample_kernel_sizes, + n_speakers=hps.data.n_speakers, + spk_channels=hps.model.spk_channels) + + self.dec_harm = Generator_Harm(hps) + + self.dec_noise = Generator_Noise(hps) + + self.f0_prenet = nn.Conv1d(1, hps.model.prior_hidden_channels , 3, padding=1) + self.energy_prenet = nn.Conv1d(1, hps.model.prior_hidden_channels , 3, padding=1) + self.mel_prenet = nn.Conv1d(hps.data.acoustic_dim, hps.model.prior_hidden_channels , 3, padding=1) + + if hps.data.n_speakers > 1: + self.emb_spk = nn.Embedding(hps.data.n_speakers, hps.model.spk_channels) + self.flow = modules.ResidualCouplingBlock(hps.model.prior_hidden_channels, hps.model.hidden_channels, 5, 1, 4,n_speakers=hps.data.n_speakers, gin_channels=hps.model.spk_channels) + + def forward(self, c, f0, mel2ph, t_window, noise=None, g=None): + if len(g.shape) == 2: + g = g.squeeze(0) + if len(f0.shape) == 2: + f0 = f0.unsqueeze(0) + g = self.emb_spk(g).unsqueeze(-1) # [b, h, 1] decoder_inp = F.pad(c, [0, 0, 1, 0]) mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]]) c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H] c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) - g = g.unsqueeze(0) - g = self.emb_g(g).transpose(1, 2) - x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) - x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) - if self.predict_f0: - lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500 - norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) - pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) - f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) + # Encoder + decoder_input, x_mask = self.text_encoder(c, c_lengths) + y_lengths = c_lengths + + LF0 = 2595. * torch.log10(1. + f0 / 700.) 
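+        # F0 in Hz -> mel scale, 2595 * log10(1 + f0 / 700), then divided by
+        # 500 below to keep the prenet input in a small numeric range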
+ LF0 = LF0 / 500 + + # aam + predict_mel, predict_bn_mask = self.mel_decoder(decoder_input + self.f0_prenet(LF0), y_lengths, spk_emb=g) + predict_energy = predict_mel.sum(1).unsqueeze(1) / self.hps.data.acoustic_dim + + decoder_input = decoder_input + \ + self.f0_prenet(LF0) + \ + self.energy_prenet(predict_energy) + \ + self.mel_prenet(predict_mel) + decoder_output, y_mask = self.decoder(decoder_input, y_lengths, spk_emb=g) + + prior_info = decoder_output + + m_p = prior_info[:, :self.hps.model.hidden_channels, :] + logs_p = prior_info[:, self.hps.model.hidden_channels:, :] + z_p = m_p + torch.exp(logs_p) * noise + z = self.flow(z_p, y_mask, g=g, reverse=True) + + prior_z = z + + noise_x = self.dec_noise(prior_z, y_mask, t_window) + + harm_x = self.dec_harm(f0, prior_z, y_mask) + + pitch = upsample(f0.transpose(1, 2), self.hps.data.hop_length) + omega = torch.cumsum(2 * math.pi * pitch / self.hps.data.sampling_rate, 1) + sin = torch.sin(omega).transpose(1, 2) + + decoder_condition = torch.cat([harm_x, noise_x, sin], axis=1) + + # dsp based HiFiGAN vocoder + o = self.dec(prior_z, decoder_condition, g=g) - z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise) - z = self.flow(z_p, c_mask, g=g, reverse=True) - o = self.dec(z * c_mask, g=g, f0=f0) return o
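
As a sanity check for reviewers, here is a minimal tracing sketch for the new forward(c, f0, mel2ph, t_window, noise, g) signature. The shapes, the 200-frame clip, and the output file name are assumptions inferred from the code above rather than taken from the project's export script, so treat it as an illustration that may need adjustment:

import torch

import utils
from onnxexport.model_onnx import SynthesizerTrn

hps = utils.get_hparams_from_file("configs/config.json")
model = SynthesizerTrn(hps)
model.eval()

frames = 200                                               # hypothetical clip length in frames
c = torch.randn(1, frames, hps.data.c_dim)                 # content features [b, t, c_dim]
f0 = torch.rand(1, 1, frames) * 400.0 + 100.0              # F0 contour in Hz
mel2ph = torch.arange(1, frames + 1).unsqueeze(0)          # identity alignment (1-based, matches the F.pad)
t_window = torch.ones(1, hps.data.win_size, frames)        # ConviSTFT overlap-add normalizer
noise = torch.randn(1, hps.model.hidden_channels, frames)  # z_p sampling noise
g = torch.LongTensor([[0]])                                # speaker id

with torch.no_grad():
    torch.onnx.export(
        model,
        (c, f0, mel2ph, t_window, noise, g),
        "svc.onnx",
        input_names=["c", "f0", "mel2ph", "t_window", "noise", "g"],
        output_names=["audio"],
        dynamic_axes={"c": [1], "f0": [2], "mel2ph": [1],
                      "t_window": [2], "noise": [2]},
        opset_version=16)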