Add files via upload
This commit is contained in:
parent
deadcc671b
commit
fe1f733aff
|
@ -1,62 +1,103 @@
|
|||
{
|
||||
"train": {
|
||||
"log_interval": 200,
|
||||
"eval_interval": 800,
|
||||
"log_interval": 50,
|
||||
"eval_interval": 1000,
|
||||
"seed": 1234,
|
||||
"port": 8001,
|
||||
"epochs": 10000,
|
||||
"learning_rate": 0.0001,
|
||||
"learning_rate": 0.0002,
|
||||
"betas": [
|
||||
0.8,
|
||||
0.99
|
||||
],
|
||||
"eps": 1e-09,
|
||||
"batch_size": 6,
|
||||
"accumulation_steps": 1,
|
||||
"fp16_run": false,
|
||||
"lr_decay": 0.999875,
|
||||
"lr_decay": 0.998,
|
||||
"segment_size": 10240,
|
||||
"init_lr_ratio": 1,
|
||||
"warmup_epochs": 0,
|
||||
"c_mel": 45,
|
||||
"c_kl": 1.0,
|
||||
"use_sr": true,
|
||||
"max_speclen": 512,
|
||||
"port": "8001",
|
||||
"keep_ckpts": 3
|
||||
"keep_ckpts":4
|
||||
},
|
||||
"data": {
|
||||
"training_files": "filelists/train.txt",
|
||||
"validation_files": "filelists/val.txt",
|
||||
"data_dir": "dataset",
|
||||
"dataset_type": "SingDataset",
|
||||
"collate_type": "SingCollate",
|
||||
"training_filelist": "filelists/train.txt",
|
||||
"validation_filelist": "filelists/val.txt",
|
||||
"max_wav_value": 32768.0,
|
||||
"sampling_rate": 44100,
|
||||
"filter_length": 2048,
|
||||
"n_fft": 2048,
|
||||
"fmin": 0,
|
||||
"fmax": 22050,
|
||||
"hop_length": 512,
|
||||
"win_length": 2048,
|
||||
"n_mel_channels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": 22050
|
||||
},
|
||||
"model": {
|
||||
"inter_channels": 192,
|
||||
"hidden_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": 0.1,
|
||||
"resblock": "1",
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
"upsample_rates": [ 8, 8, 2, 2, 2],
|
||||
"upsample_initial_channel": 512,
|
||||
"upsample_kernel_sizes": [16,16, 4, 4, 4],
|
||||
"n_layers_q": 3,
|
||||
"use_spectral_norm": false,
|
||||
"gin_channels": 256,
|
||||
"ssl_dim": 256,
|
||||
"win_size": 2048,
|
||||
"acoustic_dim": 80,
|
||||
"c_dim": 256,
|
||||
"min_level_db": -115,
|
||||
"ref_level_db": 20,
|
||||
"min_db": -115,
|
||||
"max_abs_value": 4.0,
|
||||
"n_speakers": 200
|
||||
},
|
||||
"model": {
|
||||
"hidden_channels": 192,
|
||||
"spk_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"n_heads": 2,
|
||||
"n_layers": 4,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": 0.1,
|
||||
"prior_hidden_channels": 192,
|
||||
"prior_filter_channels": 768,
|
||||
"prior_n_heads": 2,
|
||||
"prior_n_layers": 4,
|
||||
"prior_kernel_size": 3,
|
||||
"prior_p_dropout": 0.1,
|
||||
"resblock": "1",
|
||||
"use_spectral_norm": false,
|
||||
"resblock_kernel_sizes": [
|
||||
3,
|
||||
7,
|
||||
11
|
||||
],
|
||||
"resblock_dilation_sizes": [
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
]
|
||||
],
|
||||
"upsample_rates": [
|
||||
8,
|
||||
8,
|
||||
4,
|
||||
2
|
||||
],
|
||||
"upsample_initial_channel": 256,
|
||||
"upsample_kernel_sizes": [
|
||||
16,
|
||||
16,
|
||||
8,
|
||||
4
|
||||
],
|
||||
"n_harmonic": 64,
|
||||
"n_bands": 65
|
||||
},
|
||||
"spk": {
|
||||
"nyaru": 0,
|
||||
"jishuang": 0,
|
||||
"huiyu": 1,
|
||||
"nen": 2,
|
||||
"paimon": 3,
|
||||
|
|
|
@ -127,9 +127,8 @@ class Svc(object):
|
|||
def load_model(self):
|
||||
# 获取模型配置
|
||||
self.net_g_ms = SynthesizerTrn(
|
||||
self.hps_ms.data.filter_length // 2 + 1,
|
||||
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
|
||||
**self.hps_ms.model)
|
||||
self.hps_ms
|
||||
)
|
||||
_ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
|
||||
if "half" in self.net_g_path and torch.cuda.is_available():
|
||||
_ = self.net_g_ms.half().eval().to(self.dev)
|
||||
|
@ -167,17 +166,14 @@ class Svc(object):
|
|||
cluster_infer_ratio=0,
|
||||
auto_predict_f0=False,
|
||||
noice_scale=0.4):
|
||||
speaker_id = self.spk2id.__dict__.get(speaker)
|
||||
if not speaker_id and type(speaker) is int:
|
||||
if len(self.spk2id.__dict__) >= speaker:
|
||||
speaker_id = speaker
|
||||
speaker_id = self.spk2id[speaker]
|
||||
sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
|
||||
c, f0, uv = self.get_unit_f0(raw_path, tran, cluster_infer_ratio, speaker)
|
||||
if "half" in self.net_g_path and torch.cuda.is_available():
|
||||
c = c.half()
|
||||
with torch.no_grad():
|
||||
start = time.time()
|
||||
audio = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale)[0,0].data.float()
|
||||
audio = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale)[0][0,0].data.float()
|
||||
use_time = time.time() - start
|
||||
print("vits use time:{}".format(use_time))
|
||||
return audio, audio.shape[-1]
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
import numpy as np
|
||||
from numpy import linalg as LA
|
||||
import librosa
|
||||
from scipy.io import wavfile
|
||||
import soundfile as sf
|
||||
import librosa.filters
|
||||
|
||||
|
||||
def load_wav(wav_path, raw_sr, target_sr=16000, win_size=800, hop_size=200):
    """Load a waveform, resample it and reflect-pad it to a hop-aligned length.

    Args:
        wav_path: Path of the audio file handed to ``librosa``.
        raw_sr: Sampling rate the file is loaded at.
        target_sr: Sampling rate of the returned signal.
        win_size: Analysis window length in samples.
        hop_size: Analysis hop length in samples.

    Returns:
        1-D array of length
        ``(n // hop_size + win_size // hop_size) * hop_size`` where ``n`` is
        the number of samples after resampling.
    """
    audio = librosa.core.load(wav_path, sr=raw_sr)[0]
    if raw_sr != target_sr:
        # Sample rates are passed by keyword: librosa >= 0.10 rejects them
        # positionally, and older releases accept the same keyword names.
        audio = librosa.core.resample(audio,
                                      orig_sr=raw_sr,
                                      target_sr=target_sr,
                                      res_type='kaiser_best')

    target_length = (audio.size // hop_size +
                     win_size // hop_size) * hop_size
    total_pad = target_length - audio.size
    pad_left = total_pad // 2
    # The right side takes whatever is left, so the result always has exactly
    # `target_length` samples, including when `hop_size` is odd.
    pad_right = total_pad - pad_left
    return np.pad(audio, (pad_left, pad_right), mode='reflect')
|
||||
|
||||
|
||||
def save_wav(wav, path, sample_rate, norm=False):
    """Write a waveform to ``path``.

    Args:
        wav: Array of audio samples.
        path: Destination file path.
        sample_rate: Sampling rate stored in the file.
        norm: When true, scale the signal so its peak is 32767 (the peak is
            floored at 0.01 to avoid dividing by ~0) and write 16-bit PCM
            through ``scipy.io.wavfile``. Otherwise the samples are written
            unchanged through ``soundfile``.
    """
    if norm:
        # Scale into a new array: the previous in-place `*=` altered the
        # caller's buffer and raised for integer-typed input.
        scaled = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        wavfile.write(path, sample_rate, scaled.astype(np.int16))
    else:
        sf.write(path, wav, sample_rate)
|
||||
|
||||
|
||||
_mel_basis = None
|
||||
_inv_mel_basis = None
|
||||
|
||||
|
||||
def _build_mel_basis(hparams):
    """Build the mel filter bank described by ``hparams``.

    Reads ``sampling_rate``, ``n_fft``, ``acoustic_dim`` (number of mel
    bands), ``fmin`` and ``fmax`` from ``hparams``.

    Raises:
        AssertionError: If ``fmax`` exceeds half the sampling rate.
    """
    assert hparams.fmax <= hparams.sampling_rate // 2
    # `sr` and `n_fft` are keyword-only in librosa >= 0.10; the keyword form
    # is also accepted by earlier releases.
    return librosa.filters.mel(sr=hparams.sampling_rate,
                               n_fft=hparams.n_fft,
                               n_mels=hparams.acoustic_dim,
                               fmin=hparams.fmin,
                               fmax=hparams.fmax)
|
||||
|
||||
|
||||
def _linear_to_mel(spectogram, hparams):
    """Project a linear-frequency spectrogram onto the mel filter bank.

    The filter bank is built from ``hparams`` on first use and then kept in
    the module-level ``_mel_basis``; later calls reuse it whatever
    ``hparams`` they pass.
    """
    global _mel_basis
    basis = _mel_basis
    if basis is None:
        basis = _build_mel_basis(hparams)
        _mel_basis = basis
    return np.dot(basis, spectogram)
|
||||
|
||||
|
||||
def _mel_to_linear(mel_spectrogram, hparams):
    """Map a mel spectrogram back to linear frequency bins.

    Uses the pseudo-inverse of the mel filter bank, computed once and cached
    in the module-level ``_inv_mel_basis``. The result is floored at 1e-10.
    """
    global _inv_mel_basis
    inverse = _inv_mel_basis
    if inverse is None:
        inverse = np.linalg.pinv(_build_mel_basis(hparams))
        _inv_mel_basis = inverse
    linear = np.dot(inverse, mel_spectrogram)
    return np.maximum(1e-10, linear)
|
||||
|
||||
|
||||
def _stft(y, hparams):
    """Return ``librosa.stft`` of ``y`` using the FFT size, hop and window
    length (``win_size``) stored on ``hparams``."""
    stft_kwargs = {
        "n_fft": hparams.n_fft,
        "hop_length": hparams.hop_length,
        "win_length": hparams.win_size,
    }
    return librosa.stft(y=y, **stft_kwargs)
|
||||
|
||||
|
||||
def _amp_to_db(x, hparams):
    """Convert amplitudes to decibels.

    Amplitudes are floored at the linear value that corresponds to
    ``hparams.min_level_db`` before ``20 * log10`` is taken.
    """
    floor = np.exp(hparams.min_level_db / 20 * np.log(10))
    clipped = np.maximum(floor, x)
    return 20 * np.log10(clipped)
|
||||
|
||||
def _normalize(S, hparams):
    """Rescale decibel values into ``[0, hparams.max_abs_value]``.

    ``hparams.min_db`` maps to 0 and 0 dB maps to ``max_abs_value``; values
    outside that span are clipped.
    """
    fraction = (S - hparams.min_db) / (-hparams.min_db)
    return hparams.max_abs_value * np.clip(fraction, 0, 1)
|
||||
|
||||
def _db_to_amp(x):
    """Convert decibel values to linear amplitude, i.e. ``10 ** (x / 20)``."""
    exponent = (x) * 0.05
    return np.power(10.0, exponent)
|
||||
|
||||
|
||||
def _stft(y, hparams):
    """Short-time Fourier transform of ``y`` via ``librosa.stft``.

    NOTE: this re-declares ``_stft``; an identical definition appears
    earlier in this module and is superseded by this one at import time.
    """
    n_fft = hparams.n_fft
    hop = hparams.hop_length
    win = hparams.win_size
    return librosa.stft(y=y, n_fft=n_fft, hop_length=hop, win_length=win)
|
||||
|
||||
|
||||
def _istft(y, hparams):
    """Invert an STFT matrix with ``librosa.istft`` using the hop length and
    window length (``win_size``) stored on ``hparams``."""
    hop = hparams.hop_length
    win = hparams.win_size
    return librosa.istft(y, hop_length=hop, win_length=win)
|
||||
|
||||
|
||||
def melspectrogram(wav, hparams):
    """Compute a normalized mel spectrogram of ``wav``.

    Pipeline: STFT magnitude -> mel projection -> decibels -> subtract
    ``hparams.ref_level_db`` -> ``_normalize``.
    """
    magnitude = np.abs(_stft(wav, hparams))
    mel_db = _amp_to_db(_linear_to_mel(magnitude, hparams), hparams)
    referenced = mel_db - hparams.ref_level_db
    return _normalize(referenced, hparams)
|
||||
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torch.fft as fft
|
||||
import numpy as np
|
||||
import librosa as li
|
||||
import math
|
||||
from scipy.signal import get_window
|
||||
|
||||
def safe_log(x):
    """Natural logarithm of ``x`` offset by 1e-7 so that zeros stay finite."""
    shifted = x + 1e-7
    return torch.log(shifted)
|
||||
|
||||
|
||||
@torch.no_grad()
def mean_std_loudness(dataset):
    """Average the per-item loudness mean and standard deviation.

    Args:
        dataset: Iterable of 3-tuples whose third element is a tensor; the
            first two elements are ignored.

    Returns:
        ``(mean, std)``: running averages of each item's ``.mean()`` and
        ``.std()``. Both are 0 for an empty dataset.
    """
    avg_mean = 0
    avg_std = 0
    for count, (_, _, loudness) in enumerate(dataset, start=1):
        # Incremental average: avoids storing every per-item statistic.
        avg_mean += (loudness.mean().item() - avg_mean) / count
        avg_std += (loudness.std().item() - avg_std) / count
    return avg_mean, avg_std
|
||||
|
||||
|
||||
def multiscale_fft(signal, scales, overlap):
    """Return STFT magnitudes of ``signal`` at several resolutions.

    Args:
        signal: Tensor accepted by ``torch.stft``.
        scales: Iterable of FFT sizes; each is also used as the window
            length of a Hann window.
        overlap: Fraction of each window shared with the next one; the hop
            is ``int(scale * (1 - overlap))``.

    Returns:
        List with one magnitude tensor per entry of ``scales``.
    """
    def _magnitude(n_fft):
        spectrum = torch.stft(
            signal,
            n_fft=n_fft,
            hop_length=int(n_fft * (1 - overlap)),
            win_length=n_fft,
            window=torch.hann_window(n_fft).to(signal),
            center=True,
            normalized=True,
            return_complex=True,
        )
        return spectrum.abs()

    return [_magnitude(scale) for scale in scales]
|
||||
|
||||
|
||||
def resample(x, factor: int):
    """Upsample ``x`` along its second axis by ``factor``.

    ``x`` is unpacked as ``(batch, frame, channel)``. Samples are placed
    every ``factor`` steps in a zero buffer, the last output sample is set to
    the last input sample, and the buffer is smoothed by convolving with a
    Hann window of length ``2 * factor``.

    Returns:
        Tensor of shape ``(batch, frame * factor, channel)``.
    """
    batch, frame, channel = x.shape
    flat = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)

    kernel = torch.hann_window(
        factor * 2, dtype=flat.dtype, device=flat.device
    ).reshape(1, 1, -1)

    sparse = torch.zeros(flat.shape[0], flat.shape[1], factor * flat.shape[2]).to(flat)
    sparse[..., ::factor] = flat
    sparse[..., -1:] = flat[..., -1:]

    padded = torch.nn.functional.pad(sparse, [factor, factor])
    # The convolution yields one sample too many; drop the trailing one.
    smooth = torch.nn.functional.conv1d(padded, kernel)[..., :-1]

    return smooth.reshape(batch, channel, factor * frame).permute(0, 2, 1)
|
||||
|
||||
|
||||
def upsample(signal, factor):
    """Stretch ``signal`` along axis 1 by ``factor``.

    The last two axes are swapped, ``nn.functional.interpolate`` (default
    mode) resizes the trailing axis to ``factor`` times its length, and the
    axes are swapped back.
    """
    channels_first = signal.permute(0, 2, 1)
    new_length = channels_first.shape[-1] * factor
    stretched = nn.functional.interpolate(channels_first, size=new_length)
    return stretched.permute(0, 2, 1)
|
||||
|
||||
|
||||
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    """Attenuate harmonics whose frequency reaches ``sampling_rate / 2``.

    The last axis of ``amplitudes`` indexes harmonics ``1..n``. Harmonics
    strictly below the Nyquist frequency are scaled by ``1 + 1e-4``; the
    others by ``1e-4``.
    """
    harmonic_count = amplitudes.shape[-1]
    multipliers = torch.arange(1, harmonic_count + 1).to(pitch)
    harmonic_freqs = pitch * multipliers
    keep = (harmonic_freqs < sampling_rate / 2).float() + 1e-4
    return amplitudes * keep
|
||||
|
||||
|
||||
def scale_function(x):
    """Map ``x`` to a positive value: ``2 * sigmoid(x) ** ln(10) + 1e-7``."""
    squashed = torch.sigmoid(x)
    return 2 * squashed ** (math.log(10)) + 1e-7
|
||||
|
||||
|
||||
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    """Compute a per-frame A-weighted log-magnitude loudness curve.

    Args:
        signal: Audio samples handed to ``librosa.stft``.
        sampling_rate: Sampling rate used to derive the FFT bin frequencies.
        block_size: STFT hop length in samples.
        n_fft: FFT size, also used as the window length.

    Returns:
        1-D array holding, per frame, the mean over frequency bins of
        ``log(|STFT| + 1e-7)`` plus the A-weighting of each bin. The last
        frame is dropped.
    """
    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    S = np.log(abs(S) + 1e-7)
    # Keyword arguments: librosa >= 0.10 no longer accepts `sr`/`n_fft`
    # positionally here; older releases accept the same keywords.
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)

    # Broadcast the per-bin weighting across all frames.
    S = S + a_weight.reshape(-1, 1)

    S = np.mean(S, 0)[..., :-1]

    return S
|
||||
|
||||
|
||||
def extract_pitch(signal, sampling_rate, block_size):
    """Estimate one pitch value per block of ``block_size`` samples with CREPE.

    Args:
        signal: Audio samples; the last axis is time.
        sampling_rate: Sampling rate of ``signal``.
        block_size: Number of samples per output frame.

    Returns:
        1-D array of ``signal.shape[-1] // block_size`` values. The second
        element of the ``crepe.predict`` result is used (presumably the
        frequency track; confirm against the crepe documentation); it is
        linearly interpolated when its length differs from the expected
        frame count.
    """
    # Imported here because this function is the only user of `crepe` and the
    # module's import block does not bring it into scope; without this the
    # call below raises NameError.
    import crepe

    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        step_size=int(1000 * block_size / sampling_rate),  # hop in whole ms
        verbose=1,
        center=True,
        viterbi=True,
    )
    f0 = f0[1].reshape(-1)[:-1]

    if f0.shape[-1] != length:
        # The millisecond step is truncated to an integer, so the frame count
        # can differ from `length`; resample onto the expected grid.
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )

    return f0
|
||||
|
||||
|
||||
def mlp(in_size, hidden_size, n_layers):
    """Build a stack of ``n_layers`` Linear -> LayerNorm -> LeakyReLU groups.

    The first Linear maps ``in_size`` to ``hidden_size``; every later one
    maps ``hidden_size`` to ``hidden_size``.

    Returns:
        ``nn.Sequential`` holding ``3 * n_layers`` modules.
    """
    widths = [in_size] + (n_layers) * [hidden_size]
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers.extend([
            nn.Linear(fan_in, fan_out),
            nn.LayerNorm(fan_out),
            nn.LeakyReLU(),
        ])
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def gru(n_input, hidden_size):
    """Create a batch-first ``nn.GRU`` whose input width is
    ``n_input * hidden_size`` and whose state width is ``hidden_size``."""
    input_width = n_input * hidden_size
    return nn.GRU(input_width, hidden_size, batch_first=True)
|
||||
|
||||
|
||||
def harmonic_synth(pitch, amplitudes, sampling_rate):
    """Sum a bank of sinusoids at integer multiples of ``pitch``.

    The phase is the cumulative sum along axis 1 of
    ``2 * pi * pitch / sampling_rate``; harmonic ``k`` (1-based, one per
    entry of the last axis of ``amplitudes``) uses ``k`` times that phase.

    Returns:
        Tensor with a trailing axis of size 1 holding the summed signal.
    """
    harmonic_count = amplitudes.shape[-1]
    phase = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
    harmonic_phases = phase * torch.arange(1, harmonic_count + 1).to(phase)
    partials = torch.sin(harmonic_phases) * amplitudes
    return partials.sum(-1, keepdim=True)
|
||||
|
||||
|
||||
def amp_to_impulse_response(amp, target_size):
    """Turn real frequency-domain amplitudes into a windowed impulse response.

    ``amp`` is treated as a zero-phase spectrum (imaginary part 0) and passed
    through ``irfft``. The result is rotated by half its length, multiplied
    by a Hann window, right-padded with zeros to ``target_size`` samples and
    rotated back.
    """
    zero_phase = torch.view_as_complex(
        torch.stack([amp, torch.zeros_like(amp)], -1)
    )
    response = fft.irfft(zero_phase)

    filter_size = response.shape[-1]

    # Centre the response so the window tapers both of its tails.
    response = torch.roll(response, filter_size // 2, -1)
    window = torch.hann_window(
        filter_size, dtype=response.dtype, device=response.device
    )
    response = response * window

    extra = int(target_size) - int(filter_size)
    response = nn.functional.pad(response, (0, extra))
    return torch.roll(response, -filter_size // 2, -1)
|
||||
|
||||
|
||||
def fft_convolve(signal, kernel):
    """Convolve ``signal`` with ``kernel`` along the last axis using FFTs.

    ``signal`` is zero-padded on the right to twice its length and ``kernel``
    on the left to twice its length; the spectra are multiplied and the
    second half of the inverse transform is returned.
    """
    signal_len = signal.shape[-1]
    kernel_len = kernel.shape[-1]
    padded_signal = nn.functional.pad(signal, (0, signal_len))
    padded_kernel = nn.functional.pad(kernel, (kernel_len, 0))

    product = fft.rfft(padded_signal) * fft.rfft(padded_kernel)
    full = fft.irfft(product)
    return full[..., full.shape[-1] // 2:]
|
||||
|
||||
|
||||
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build windowed Fourier kernels for a convolution-based (i)STFT.

    Args:
        win_len: Window length; also the number of rows kept from the
            Fourier basis.
        win_inc: Unused by this function.
        fft_len: FFT size used to build the Fourier basis.
        win_type: Window name for ``scipy.signal.get_window``; ``None`` or
            the string ``'None'`` selects a rectangular window.
        invers: When true, use the transposed pseudo-inverse of the kernel.

    Returns:
        ``(kernel, window)`` as float32 tensors: ``kernel`` has a singleton
        middle axis, ``window`` has shape ``(1, win_len, 1)``.
    """
    rectangular = win_type == 'None' or win_type is None
    if rectangular:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)

    basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    # Stack real and imaginary parts, then put the frequency axis first.
    kernel = np.concatenate([np.real(basis), np.imag(basis)], 1).T

    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = (kernel * window)[:, None, :]
    kernel_tensor = torch.from_numpy(kernel.astype(np.float32))
    window_tensor = torch.from_numpy(window[None, :, None].astype(np.float32))
    return kernel_tensor, window_tensor
|
||||
|
|
@ -5,182 +5,187 @@ import scipy
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
import modules.commons as commons
|
||||
from modules.commons import init_weights, get_padding
|
||||
|
||||
from modules.transforms import piecewise_rational_quadratic_transform
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
    """Layer normalization applied over axis 1 of the input.

    Learnable ``gamma`` (initialised to ones) and ``beta`` (zeros) each hold
    ``channels`` values.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # F.layer_norm normalizes trailing axes, so move axis 1 to the end
        # and restore the layout afterwards.
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(
            channels_last, (self.channels,), self.gamma, self.beta, self.eps
        )
        return normed.transpose(1, -1)
|
||||
|
||||
|
||||
class ConvReluNorm(nn.Module):
    """Residual stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers.

    The first convolution maps ``in_channels`` to ``hidden_channels``; the
    remaining ``n_layers - 1`` keep ``hidden_channels``. A 1x1 projection,
    initialised to zero, maps to ``out_channels`` and is added to the input.

    Raises:
        AssertionError: If ``n_layers`` is not greater than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message now states the bound that is actually enforced.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero-initialised projection: the module starts as the identity on
        # the masked input.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the stack to ``x`` and return ``(x + proj(h)) * x_mask``."""
        x_org = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            x = conv(x * x_mask)
            x = norm(x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
|
||||
|
||||
|
||||
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            # Dilation grows as kernel_size ** layer index; the padding keeps
            # the sequence length unchanged.
            dilation = kernel_size ** layer_idx
            padding = (kernel_size * dilation - dilation) // 2
            depthwise = nn.Conv1d(channels, channels, kernel_size,
                                  groups=channels, dilation=dilation, padding=padding)
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the residual layers on ``x`` (plus ``g`` when given) and mask
        the result with ``x_mask``."""
        if g is not None:
            x = x + g
        stages = zip(self.convs_sep, self.norms_1, self.convs_1x1, self.norms_2)
        for depthwise, norm_a, pointwise, norm_b in stages:
            y = F.gelu(norm_a(depthwise(x * x_mask)))
            y = F.gelu(norm_b(pointwise(y)))
            x = x + self.drop(y)
        return x * x_mask
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
|
||||
super(WN, self).__init__()
|
||||
assert(kernel_size % 2 == 1)
|
||||
self.hidden_channels =hidden_channels
|
||||
self.kernel_size = kernel_size,
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
self.p_dropout = p_dropout
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, n_speakers=0, spk_channels=0,
|
||||
p_dropout=0):
|
||||
super(WN, self).__init__()
|
||||
assert (kernel_size % 2 == 1)
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size,
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.n_speakers = n_speakers
|
||||
self.spk_channels = spk_channels
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
if gin_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
||||
if n_speakers > 0:
|
||||
cond_layer = torch.nn.Conv1d(spk_channels, 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
||||
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate ** i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
|
||||
dilation=dilation, padding=padding)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate ** i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
|
||||
dilation=dilation, padding=padding)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x, x_mask, g=None, **kwargs):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
def forward(self, x, x_mask, g=None, **kwargs):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x_in,
|
||||
g_l,
|
||||
n_channels_tensor)
|
||||
acts = self.drop(acts)
|
||||
acts = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x_in,
|
||||
g_l,
|
||||
n_channels_tensor)
|
||||
acts = self.drop(acts)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:,:self.hidden_channels,:]
|
||||
x = (x + res_acts) * x_mask
|
||||
output = output + res_skip_acts[:,self.hidden_channels:,:]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:, :self.hidden_channels, :]
|
||||
x = (x + res_acts) * x_mask
|
||||
output = output + res_skip_acts[:, self.hidden_channels:, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
|
||||
def remove_weight_norm(self):
|
||||
if self.gin_channels != 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
def remove_weight_norm(self):
|
||||
if self.n_speakers > 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
|
@ -256,87 +261,193 @@ class ResBlock2(torch.nn.Module):
|
|||
|
||||
|
||||
class Log(nn.Module):
    """Flow step taking the element-wise log (forward) or exp (reverse)."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Forward: return ``(y, logdet)`` with ``y = log(max(x, 1e-5)) *
        x_mask`` and ``logdet = sum(-y)`` over axes 1 and 2.
        Reverse: return ``exp(x) * x_mask``."""
        if reverse:
            return torch.exp(x) * x_mask
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
|
||||
|
||||
|
||||
class Flip(nn.Module):
    """Flow step that reverses the order of axis 1."""

    def forward(self, x, *args, reverse=False, **kwargs):
        """Return the flipped tensor; in the forward direction also return a
        zero log-determinant with one entry per batch item."""
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        logdet = torch.zeros(flipped.size(0)).to(dtype=flipped.dtype, device=flipped.device)
        return flipped, logdet
|
||||
|
||||
|
||||
class ElementwiseAffine(nn.Module):
    """Flow step applying a learnable per-channel shift ``m`` and
    log-scale ``logs``, both initialised to zero."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Forward: return ``((m + exp(logs) * x) * x_mask, logdet)`` where
        ``logdet`` sums ``logs * x_mask`` over axes 1 and 2.
        Reverse: return ``(x - m) * exp(-logs) * x_mask``."""
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        y = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
|
||||
|
||||
|
||||
class ResidualCouplingLayer(nn.Module):
|
||||
def __init__(self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
p_dropout=0,
|
||||
gin_channels=0,
|
||||
mean_only=False):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.half_channels = channels // 2
|
||||
self.mean_only = mean_only
|
||||
def __init__(self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
p_dropout=0,
|
||||
n_speakers=0,
|
||||
spk_channels=0,
|
||||
mean_only=False):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.half_channels = channels // 2
|
||||
self.mean_only = mean_only
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
self.post.weight.data.zero_()
|
||||
self.post.bias.data.zero_()
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, n_speakers=n_speakers,
|
||||
spk_channels=spk_channels)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
self.post.weight.data.zero_()
|
||||
self.post.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
|
||||
h = self.pre(x0) * x_mask
|
||||
h = self.enc(h, x_mask, g=g)
|
||||
stats = self.post(h) * x_mask
|
||||
if not self.mean_only:
|
||||
m, logs = torch.split(stats, [self.half_channels]*2, 1)
|
||||
else:
|
||||
m = stats
|
||||
logs = torch.zeros_like(m)
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0) * x_mask
|
||||
h = self.enc(h, x_mask, g=g)
|
||||
stats = self.post(h) * x_mask
|
||||
if not self.mean_only:
|
||||
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
||||
else:
|
||||
m = stats
|
||||
logs = torch.zeros_like(m)
|
||||
|
||||
if not reverse:
|
||||
x1 = m + x1 * torch.exp(logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
logdet = torch.sum(logs, [1,2])
|
||||
return x, logdet
|
||||
else:
|
||||
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
return x
|
||||
if not reverse:
|
||||
x1 = m + x1 * torch.exp(logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
logdet = torch.sum(logs, [1, 2])
|
||||
return x, logdet
|
||||
else:
|
||||
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
    """Chain of ``n_flows`` (ResidualCouplingLayer, Flip) pairs.

    Each coupling layer is built with ``mean_only=True`` and receives
    ``gin_channels`` as its ``spk_channels``.
    """

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 n_speakers=0,
                 gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = ResidualCouplingLayer(
                channels, hidden_channels, kernel_size, dilation_rate, n_layers,
                n_speakers=n_speakers, spk_channels=gin_channels, mean_only=True)
            self.flows.append(coupling)
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Pass ``x`` through every flow, in reverse order when ``reverse``
        is true. The log-determinants returned in the forward direction are
        discarded; only the transformed tensor is returned."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
|
||||
|
||||
|
||||
class ConvFlow(nn.Module):
    """Coupling flow whose element-wise transform is a rational-quadratic spline.

    The first half of the channels parameterises (via ``pre -> convs -> proj``)
    the spline applied to the second half. ``proj`` is zero-initialised.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
        # Per channel: num_bins widths + num_bins heights + (num_bins - 1) derivatives.
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Transform the second channel half of ``x`` with the predicted spline.

        Returns ``(x, logdet)`` when ``reverse`` is false and ``x`` alone otherwise.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)

        params = self.pre(x0)
        params = self.convs(params, x_mask, g=g)
        params = self.proj(params) * x_mask

        batch, chans, frames = x0.shape
        # [b, cx?, t] -> [b, c, t, ?]
        params = params.reshape(batch, chans, -1, frames).permute(0, 1, 3, 2)

        scale = math.sqrt(self.filter_channels)
        n_bins = self.num_bins
        unnormalized_widths = params[..., :n_bins] / scale
        unnormalized_heights = params[..., n_bins:2 * n_bins] / scale
        unnormalized_derivatives = params[..., 2 * n_bins:]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails='linear',
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return x
        return x, logdet
|
||||
|
||||
|
||||
class ResStack(nn.Module):
    """Stack of ``nums`` residual units built from weight-normalised 1-D convolutions.

    Unit ``i`` is LeakyReLU -> conv (dilation and padding ``base ** i``) ->
    LeakyReLU -> conv (dilation 1, padding 1); its output is added to its input.
    """

    def __init__(self, channel, kernel_size=3, base=3, nums=4):
        super().__init__()

        units = []
        for i in range(nums):
            dilation = base ** i
            dilated_conv = nn.utils.weight_norm(nn.Conv1d(channel, channel,
                                                          kernel_size=kernel_size,
                                                          dilation=dilation, padding=dilation))
            plain_conv = nn.utils.weight_norm(nn.Conv1d(channel, channel,
                                                        kernel_size=kernel_size,
                                                        dilation=1, padding=1))
            units.append(nn.Sequential(nn.LeakyReLU(), dilated_conv, nn.LeakyReLU(), plain_conv))
        self.layers = nn.ModuleList(units)

    def forward(self, x):
        """Apply every residual unit in turn and return the result."""
        for unit in self.layers:
            x = x + unit(x)
        return x
|
||||
|
|
|
@ -0,0 +1,512 @@
|
|||
from librosa.util import pad_center, tiny
|
||||
from scipy.signal import get_window
|
||||
from torch import Tensor
|
||||
from torch.autograd import Variable
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import librosa.util as librosa_util
|
||||
import math
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import warnings
|
||||
|
||||
|
||||
def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None
) -> Tensor:
    r"""Build a triangular mel filter bank for converting linear-frequency bins to mel bins.

    Args:
        n_freqs (int): Number of linear frequency bins.
        f_min (float): Lowest filter edge (Hz).
        f_max (float): Highest filter edge (Hz).
        n_mels (int): Number of mel filters.
        sample_rate (int): Sample rate of the audio waveform.
        norm (Optional[str]): ``'slaney'`` scales each filter by ``2 / bandwidth``
            (area normalization); ``None`` applies no scaling. (Default: ``None``)

    Returns:
        Tensor: Filter bank of size (``n_freqs``, ``n_mels``), one filter per column.
        Values are clamped to ``[1e-6, 1]`` before the optional normalization, so
        entries outside a triangle are ``1e-6`` rather than zero. For a matrix ``A``
        of size (..., ``n_freqs``) the mel projection is
        ``torch.matmul(A, create_fb_matrix(A.size(-1), ...))``.

    Raises:
        ValueError: If ``norm`` is neither ``None`` nor ``'slaney'``.
    """
    if not (norm is None or norm == "slaney"):
        raise ValueError("norm must be one of None or 'slaney'")

    def _hz_to_mel(hz):
        return 2595.0 * math.log10(1.0 + (hz / 700.0))

    # Linear frequency of every STFT bin, 0 .. sample_rate // 2.
    stft_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # n_mels + 2 filter edges, evenly spaced on the mel axis, converted back to Hz.
    mel_edges = torch.linspace(_hz_to_mel(f_min), _hz_to_mel(f_max), n_mels + 2)
    hz_edges = 700.0 * (10 ** (mel_edges / 2595.0) - 1.0)

    band_widths = hz_edges[1:] - hz_edges[:-1]  # (n_mels + 1)
    offsets = hz_edges.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)

    # Each triangle is the minimum of its rising and falling edge.
    rising = (-1.0 * offsets[:, :-2]) / band_widths[:-1]  # (n_freqs, n_mels)
    falling = offsets[:, 2:] / band_widths[1:]  # (n_freqs, n_mels)
    fb = torch.clamp(torch.min(rising, falling), 1e-6, 1)

    if norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel.
        enorm = 2.0 / (hz_edges[2:n_mels + 2] - hz_edges[:n_mels])
        fb = fb * enorm.unsqueeze(0)
    return fb
|
||||
|
||||
|
||||
def lfilter(
        waveform: Tensor,
        a_coeffs: Tensor,
        b_coeffs: Tensor,
        clamp: bool = True,
) -> Tensor:
    r"""Apply an IIR filter by evaluating its difference equation sample by sample.

    Args:
        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
        a_coeffs (Tensor): denominator coefficients of dimension ``(n_order + 1)``,
            lowest delay first, e.g. ``[a0, a1, a2, ...]``.
            Must be the same size as ``b_coeffs`` (pad with 0's as necessary).
        b_coeffs (Tensor): numerator coefficients of dimension ``(n_order + 1)``,
            lowest delay first, e.g. ``[b0, b1, b2, ...]``.
            Must be the same size as ``a_coeffs`` (pad with 0's as necessary).
        clamp (bool, optional): If ``True``, clamp the output signal to [-1, 1]. (Default: ``True``)

    Returns:
        Tensor: Waveform with dimension of ``(..., time)``.
    """
    # Flatten all leading dimensions into a single channel axis.
    in_shape = waveform.size()
    waveform = waveform.reshape(-1, in_shape[-1])

    assert (a_coeffs.size(0) == b_coeffs.size(0))
    assert (len(waveform.size()) == 2)
    assert (waveform.device == a_coeffs.device)
    assert (b_coeffs.device == a_coeffs.device)

    device, dtype = waveform.device, waveform.dtype
    n_channel, n_sample = waveform.size()
    n_order = a_coeffs.size(0)
    n_padded = n_sample + n_order - 1
    assert (n_order > 0)

    # Left-pad with n_order - 1 zeros so every output sample has a full history.
    padded_in = torch.zeros(n_channel, n_padded, dtype=dtype, device=device)
    padded_in[:, (n_order - 1):] = waveform
    padded_out = torch.zeros(n_channel, n_padded, dtype=dtype, device=device)

    # Coefficients are applied oldest-sample-first, hence the flip.
    a_flipped = a_coeffs.flip(0)
    b_flipped = b_coeffs.flip(0)

    # Flat indices into padded_in, shape (n_channel, n_order, n_sample):
    # entry [c, k, n] addresses sample n + k of channel c.
    sample_idx = torch.arange(n_sample, device=device).unsqueeze(0)
    tap_idx = torch.arange(n_order, device=device).unsqueeze(1)
    channel_offset = torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_padded
    window_idxs = ((sample_idx + tap_idx).unsqueeze(0) + channel_offset).long()

    # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
    feedforward = torch.matmul(b_flipped, torch.take(padded_in, window_idxs))

    # Normalize by a0.
    feedforward = feedforward / a_coeffs[0]
    a_flipped = a_flipped / a_coeffs[0]

    # The feedback part is inherently sequential: each output depends on earlier outputs.
    for i_sample in range(n_sample):
        history = padded_out[:, i_sample:(i_sample + n_order)]
        current = torch.addmv(feedforward[:, i_sample], history, a_flipped, alpha=-1)
        padded_out[:, i_sample + n_order - 1] = current

    output = padded_out[:, (n_order - 1):]
    if clamp:
        output = torch.clamp(output, min=-1., max=1.)

    # Restore the original leading dimensions.
    return output.reshape(in_shape[:-1] + output.shape[-1:])
|
||||
|
||||
|
||||
|
||||
def biquad(
        waveform: Tensor,
        b0: float,
        b1: float,
        b2: float,
        a0: float,
        a1: float,
        a2: float
) -> Tensor:
    r"""Run a second-order (biquad) filter over ``waveform`` via :func:`lfilter`.

    https://en.wikipedia.org/wiki/Digital_biquad_filter

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        b0 (float): numerator coefficient of current input, x[n]
        b1 (float): numerator coefficient of input one time step ago x[n-1]
        b2 (float): numerator coefficient of input two time steps ago x[n-2]
        a0 (float): denominator coefficient of current output y[n], typically 1
        a1 (float): denominator coefficient of output one time step ago y[n-1]
        a2 (float): denominator coefficient of output two time steps ago y[n-2]

    Returns:
        Tensor: Waveform with dimension of `(..., time)`
    """
    # Build the coefficient vectors with the waveform's dtype and device, as lfilter requires.
    tensor_kwargs = dict(dtype=waveform.dtype, device=waveform.device)
    denominator = torch.tensor([a0, a1, a2], **tensor_kwargs)
    numerator = torch.tensor([b0, b1, b2], **tensor_kwargs)
    return lfilter(waveform, denominator, numerator)
|
||||
|
||||
|
||||
|
||||
def _dB2Linear(x: float) -> float:
    """Convert a gain in decibels to a linear amplitude ratio (``10 ** (x / 20)``)."""
    exponent = x * math.log(10) / 20.0
    return math.exp(exponent)
|
||||
|
||||
|
||||
def highpass_biquad(
        waveform: Tensor,
        sample_rate: int,
        cutoff_freq: float,
        Q: float = 0.707
) -> Tensor:
    r"""Design a biquad highpass filter and apply it. Similar to SoX implementation.

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        cutoff_freq (float): filter cutoff frequency
        Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)

    Returns:
        Tensor: Waveform of dimension `(..., time)`
    """
    # Normalized angular cutoff frequency.
    w0 = 2 * math.pi * cutoff_freq / sample_rate
    cos_w0 = math.cos(w0)
    alpha = math.sin(w0) / 2. / Q

    numerator_edge = (1 + cos_w0) / 2  # b0 and b2 are equal
    numerator_mid = -1 - cos_w0
    return biquad(
        waveform,
        numerator_edge, numerator_mid, numerator_edge,
        1 + alpha, -2 * cos_w0, 1 - alpha,
    )
|
||||
|
||||
|
||||
|
||||
def lowpass_biquad(
        waveform: Tensor,
        sample_rate: int,
        cutoff_freq: float,
        Q: float = 0.707
) -> Tensor:
    r"""Design a biquad lowpass filter and apply it. Similar to SoX implementation.

    Args:
        waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        cutoff_freq (float): filter cutoff frequency
        Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)

    Returns:
        Tensor: Waveform of dimension of `(..., time)`
    """
    # Normalized angular cutoff frequency.
    w0 = 2 * math.pi * cutoff_freq / sample_rate
    cos_w0 = math.cos(w0)
    alpha = math.sin(w0) / 2 / Q

    numerator_edge = (1 - cos_w0) / 2  # b0 and b2 are equal
    numerator_mid = 1 - cos_w0
    return biquad(
        waveform,
        numerator_edge, numerator_mid, numerator_edge,
        1 + alpha, -2 * cos_w0, 1 - alpha,
    )
|
||||
|
||||
|
||||
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. If `None`, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : [optional]
        Passed through to `librosa.util.normalize` before the window is squared.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    # `size` must be passed by keyword: recent librosa releases made it
    # keyword-only, and older releases accept the keyword form as well.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope by overlap-adding the squared window at every hop
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
|
||||
|
||||
|
||||
class MelScale(torch.nn.Module):
    r"""Project a linear-frequency spectrogram onto a mel scale with triangular filters.

    The filter bank is stored in the buffer ``fb``; it is built in ``__init__``
    when ``n_stft`` is given, otherwise lazily on the first ``forward`` call.

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``24000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given. (Default: ``None``)
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 24000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None) -> None:
        super().__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        if f_max is None:
            f_max = float(sample_rate // 2)
        self.f_max = f_max
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)

        if n_stft is None:
            # Placeholder; the real filter bank is created on first use.
            fb = torch.empty(0)
        else:
            fb = create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # Collapse all leading dimensions into one batch axis.
        in_shape = specgram.size()
        flat = specgram.reshape(-1, in_shape[-2], in_shape[-1])

        if self.fb.numel() == 0:
            lazy_fb = create_fb_matrix(flat.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # The buffer is filled in place rather than reassigned.
            self.fb.resize_(lazy_fb.size())
            self.fb.copy_(lazy_fb)

        # (batch, time, freq) @ (freq, n_mels) -> (batch, time, n_mels) -> (batch, n_mels, time)
        mel = torch.matmul(flat.transpose(1, 2), self.fb).transpose(1, 2)

        # Restore the original leading dimensions.
        return mel.reshape(in_shape[:-2] + mel.shape[-2:])
|
||||
|
||||
|
||||
class TorchSTFT(torch.nn.Module):
    """STFT front end built on ``torch.stft`` with a Hann window.

    Args:
        fft_size: FFT size passed to ``torch.stft``.
        hop_size: hop length passed to ``torch.stft``.
        win_size: window length; also the size of the Hann window.
        normalized: forwarded to ``torch.stft``.
        domain: ``'linear'`` (magnitude), ``'log'`` (normalized dB magnitude)
            or ``'double'`` (both, concatenated along dim 1).
        mel_scale: if true, magnitudes are passed through a ``MelScale`` with
            ``fft_size // 2 + 1`` mel bins.
        ref_level_db: reference level subtracted in the dB conversion.
        min_level_db: floor used to map dB values into ``[0, 1]``.
    """

    def __init__(self, fft_size, hop_size, win_size,
                 normalized=False, domain='linear',
                 mel_scale=False, ref_level_db=20, min_level_db=-100):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_size = win_size
        self.ref_level_db = ref_level_db
        self.min_level_db = min_level_db
        self.window = torch.hann_window(win_size)
        self.normalized = normalized
        self.domain = domain
        self.mel_scale = MelScale(n_mels=(fft_size // 2 + 1),
                                  n_stft=(fft_size // 2 + 1)) if mel_scale else None

    def _stft(self, x):
        """Return the real and imaginary parts of the STFT of ``x``."""
        # The window is a plain attribute (not a buffer), so it is moved to the
        # dtype/device of the signal actually being transformed on every call.
        window = self.window.type_as(x)
        # return_complex=True is required by current PyTorch for real input
        # (needs torch >= 1.7); .real/.imag match the old trailing-dim layout.
        spec = torch.stft(x, self.fft_size, self.hop_size, self.win_size,
                          window, normalized=self.normalized, return_complex=True)
        return spec.real, spec.imag

    def _to_normalized_db(self, mag):
        """Convert magnitudes to dB relative to ``ref_level_db`` and map them into [0, 1]."""
        db = 20 * torch.log10(mag) - self.ref_level_db
        return torch.clamp((db - self.min_level_db) / -self.min_level_db, 0, 1)

    def transform(self, x):
        """Return ``(magnitude, phase)`` of ``x`` in the configured ``domain``.

        The input is cast to float32 before the STFT.

        Raises:
            ValueError: if ``self.domain`` is not ``'log'``, ``'linear'`` or ``'double'``.
        """
        real, imag = self._stft(x.to(torch.float32))
        # Clamp before the sqrt so the magnitude (and its log) stays finite.
        mag = torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7))
        phase = torch.atan2(imag, real)

        if self.mel_scale is not None:
            mag = self.mel_scale(mag)

        if self.domain == 'log':
            return self._to_normalized_db(mag), phase
        if self.domain == 'linear':
            return mag, phase
        if self.domain == 'double':
            return torch.cat((mag, self._to_normalized_db(mag)), dim=1), phase
        raise ValueError("unknown domain {!r}; expected 'log', 'linear' or 'double'".format(self.domain))

    def complex(self, x):
        """Return the ``(real, imag)`` parts of the STFT of ``x`` (no dtype cast is applied)."""
        return self._stft(x)
|
||||
|
||||
|
||||
|
||||
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    STFT / inverse STFT implemented as 1-D (transposed) convolutions against a
    fixed Fourier basis stored in the buffers ``forward_basis`` and
    ``inverse_basis``.

    Args:
        filter_length: FFT size (length of each analysis frame).
        hop_length: number of samples between successive frames.
        win_length: window length; must not exceed ``filter_length``.
        window: window specification for ``scipy.signal.get_window``, or
            ``None`` to leave the bases unwindowed.
    """

    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant half of the spectrum: real rows stacked on imaginary rows.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` must be passed by keyword: recent librosa releases made it
            # keyword-only, and older releases accept the keyword form as well.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Return ``(magnitude, phase)`` of ``input_data``.

        ``input_data`` is indexed as ``(batch, samples)``. The phase is
        detached from the autograd graph. The sample count is remembered in
        ``self.num_samples``.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input by half a frame on each side
        half = int(self.filter_length / 2)
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(input_data, (half, half), mode='reflect')

        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a waveform of shape ``(batch, 1, samples)`` from magnitude and phase."""
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            # Follow the input's device instead of assuming the current CUDA device.
            device = magnitude.device
            approx_nonzero_indices = approx_nonzero_indices.to(device)
            window_sum = torch.from_numpy(window_sum).to(device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the half-frame padding that transform() added on both sides.
        half = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half:]
        inverse_transform = inverse_transform[:, :, :-half]

        return inverse_transform

    def forward(self, input_data):
        """Analyse then resynthesise ``input_data``; magnitude and phase are kept on ``self``."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Default for the `min_bin_width` parameter of the spline functions in this module.
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
||||
# Default for the `min_bin_height` parameter of the spline functions in this module.
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||
# Default for the `min_derivative` parameter of the spline functions in this module.
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||
|
||||
|
||||
def piecewise_rational_quadratic_transform(inputs,
                                           unnormalized_widths,
                                           unnormalized_heights,
                                           unnormalized_derivatives,
                                           inverse=False,
                                           tails=None,
                                           tail_bound=1.,
                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                           min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Dispatch to the bounded or the tail-extended rational-quadratic spline.

    With ``tails=None`` the inputs go to ``rational_quadratic_spline`` using
    its default domain; otherwise they go to
    ``unconstrained_rational_quadratic_spline`` together with ``tails`` and
    ``tail_bound``.

    Returns:
        ``(outputs, logabsdet)`` as produced by the selected spline function.
    """
    shared_kwargs = dict(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    if tails is None:
        outputs, logabsdet = rational_quadratic_spline(**shared_kwargs)
    else:
        outputs, logabsdet = unconstrained_rational_quadratic_spline(
            tails=tails, tail_bound=tail_bound, **shared_kwargs)
    return outputs, logabsdet
|
||||
|
||||
|
||||
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin it falls into.

    The index is the number of bin edges (along the last dim of
    ``bin_locations``) that are ``<=`` the input, minus one. The last edge is
    raised by ``eps`` for the comparison so that an input equal to it lands in
    the final bin instead of one past it.

    ``bin_locations`` is not modified; the shifted edge lives in a copy.

    Args:
        bin_locations: tensor of bin edges along its last dimension.
        inputs: values to locate; broadcast against the leading dims of ``bin_locations``.
        eps: amount added to the last edge for the comparison.

    Returns:
        Integer tensor of bin indices with the shape of ``inputs``.
    """
    last_edge = bin_locations[..., -1:] + eps
    edges = torch.cat([bin_locations[..., :-1], last_edge], dim=-1)
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1
|
||||
|
||||
|
||||
def unconstrained_rational_quadratic_spline(inputs,
                                            unnormalized_widths,
                                            unnormalized_heights,
                                            unnormalized_derivatives,
                                            inverse=False,
                                            tails='linear',
                                            tail_bound=1.,
                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Rational-quadratic spline on ``[-tail_bound, tail_bound]`` with identity tails.

    Inputs outside the interval are passed through unchanged with a zero
    log-determinant; inputs inside it are transformed by
    ``rational_quadratic_spline`` on the square
    ``[-tail_bound, tail_bound] x [-tail_bound, tail_bound]``.

    Returns:
        ``(outputs, logabsdet)``, both with the shape of ``inputs``.

    Raises:
        RuntimeError: if ``tails`` is anything other than ``'linear'``.
    """
    if tails != 'linear':
        raise RuntimeError('{} tails are not implemented.'.format(tails))

    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    # Pad one derivative on each side and pin both to a constant chosen so that
    # min_derivative + softplus(constant) == 1, i.e. the spline's boundary slope
    # matches the identity tails.
    unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
    constant = np.log(np.exp(1 - min_derivative) - 1)
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant

    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    logabsdet[outside_interval_mask] = 0

    # When every input lies in the tails there is nothing for the spline to do;
    # calling it on empty tensors would fail in its torch.min / torch.max check.
    if inside_interval_mask.any():
        outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative
        )

    return outputs, logabsdet
|
||||
|
||||
def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Monotonic rational-quadratic spline mapping ``[left, right]`` onto ``[bottom, top]``.

    Bin widths and heights come from softmaxes over the unnormalized
    parameters (each bin at least ``min_bin_width`` / ``min_bin_height`` of the
    unit interval before rescaling); knot derivatives are
    ``min_derivative + softplus(unnormalized_derivatives)``.

    Returns:
        ``(outputs, logabsdet)``; with ``inverse=True`` the spline is inverted
        and the log-determinant is negated.

    Raises:
        ValueError: if any input lies outside ``[left, right]``, or if the
            minimum bin width/height is too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError('Input to a transform is not within its domain')

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    def _knots(unnormalized, min_size, low, high):
        # Softmax -> bin sizes with a floor -> cumulative knot positions in [low, high].
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        edges = torch.cumsum(sizes, dim=-1)
        edges = F.pad(edges, pad=(1, 0), mode='constant', value=0.0)
        edges = (high - low) * edges + low
        # Pin the end points exactly, then recompute the sizes from the knots.
        edges[..., 0] = low
        edges[..., -1] = high
        return edges, edges[..., 1:] - edges[..., :-1]

    cumwidths, widths = _knots(unnormalized_widths, min_bin_width, left, right)
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    cumheights, heights = _knots(unnormalized_heights, min_bin_height, bottom, top)

    # The inverse looks the input up on the output axis, the forward pass on the input axis.
    search_edges = cumheights if inverse else cumwidths
    bin_idx = searchsorted(search_edges, inputs)[..., None]

    def _select(per_bin):
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = _select(cumwidths)
    input_bin_widths = _select(widths)
    input_cumheights = _select(cumheights)
    input_delta = _select(heights / widths)
    input_derivatives = _select(derivatives)
    input_derivatives_plus_one = _select(derivatives[..., 1:])
    input_heights = _select(heights)

    if inverse:
        # Solve the quadratic a*root^2 + b*root + c = 0 for the position inside the bin.
        offset = inputs - input_cumheights
        a = ((offset * (input_derivatives
                        + input_derivatives_plus_one
                        - 2 * input_delta)
              + input_heights * (input_delta - input_derivatives)))
        b = (input_heights * input_derivatives
             - offset * (input_derivatives
                         + input_derivatives_plus_one
                         - 2 * input_delta))
        c = - input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                     * theta_one_minus_theta)
        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
                                                     + 2 * input_delta * theta_one_minus_theta
                                                     + input_derivatives * (1 - root).pow(2))
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet

    theta = (inputs - input_cumwidths) / input_bin_widths
    theta_one_minus_theta = theta * (1 - theta)

    numerator = input_heights * (input_delta * theta.pow(2)
                                 + input_derivatives * theta_one_minus_theta)
    denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                 * theta_one_minus_theta)
    outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
                                                 + 2 * input_delta * theta_one_minus_theta
                                                 + input_derivatives * (1 - theta).pow(2))
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    return outputs, logabsdet
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue