Add files via upload

This commit is contained in:
Ναρουσέ·μ·γιουμεμί·Χινακάννα 2023-03-10 19:43:07 +08:00 committed by GitHub
parent deadcc671b
commit fe1f733aff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 2305 additions and 454 deletions

View File

@ -1,62 +1,103 @@
{
"train": {
"log_interval": 200,
"eval_interval": 800,
"log_interval": 50,
"eval_interval": 1000,
"seed": 1234,
"port": 8001,
"epochs": 10000,
"learning_rate": 0.0001,
"learning_rate": 0.0002,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 6,
"accumulation_steps": 1,
"fp16_run": false,
"lr_decay": 0.999875,
"lr_decay": 0.998,
"segment_size": 10240,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"use_sr": true,
"max_speclen": 512,
"port": "8001",
"keep_ckpts": 3
"keep_ckpts":4
},
"data": {
"training_files": "filelists/train.txt",
"validation_files": "filelists/val.txt",
"data_dir": "dataset",
"dataset_type": "SingDataset",
"collate_type": "SingCollate",
"training_filelist": "filelists/train.txt",
"validation_filelist": "filelists/val.txt",
"max_wav_value": 32768.0,
"sampling_rate": 44100,
"filter_length": 2048,
"n_fft": 2048,
"fmin": 0,
"fmax": 22050,
"hop_length": 512,
"win_length": 2048,
"n_mel_channels": 80,
"mel_fmin": 0.0,
"mel_fmax": 22050
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"upsample_rates": [ 8, 8, 2, 2, 2],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16, 4, 4, 4],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 256,
"ssl_dim": 256,
"win_size": 2048,
"acoustic_dim": 80,
"c_dim": 256,
"min_level_db": -115,
"ref_level_db": 20,
"min_db": -115,
"max_abs_value": 4.0,
"n_speakers": 200
},
"model": {
"hidden_channels": 192,
"spk_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 4,
"kernel_size": 3,
"p_dropout": 0.1,
"prior_hidden_channels": 192,
"prior_filter_channels": 768,
"prior_n_heads": 2,
"prior_n_layers": 4,
"prior_kernel_size": 3,
"prior_p_dropout": 0.1,
"resblock": "1",
"use_spectral_norm": false,
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
4,
2
],
"upsample_initial_channel": 256,
"upsample_kernel_sizes": [
16,
16,
8,
4
],
"n_harmonic": 64,
"n_bands": 65
},
"spk": {
"nyaru": 0,
"jishuang": 0,
"huiyu": 1,
"nen": 2,
"paimon": 3,

View File

@ -127,9 +127,8 @@ class Svc(object):
def load_model(self):
# 获取模型配置
self.net_g_ms = SynthesizerTrn(
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
**self.hps_ms.model)
self.hps_ms
)
_ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
if "half" in self.net_g_path and torch.cuda.is_available():
_ = self.net_g_ms.half().eval().to(self.dev)
@ -167,17 +166,14 @@ class Svc(object):
cluster_infer_ratio=0,
auto_predict_f0=False,
noice_scale=0.4):
speaker_id = self.spk2id.__dict__.get(speaker)
if not speaker_id and type(speaker) is int:
if len(self.spk2id.__dict__) >= speaker:
speaker_id = speaker
speaker_id = self.spk2id[speaker]
sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
c, f0, uv = self.get_unit_f0(raw_path, tran, cluster_infer_ratio, speaker)
if "half" in self.net_g_path and torch.cuda.is_available():
c = c.half()
with torch.no_grad():
start = time.time()
audio = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale)[0,0].data.float()
audio = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale)[0][0,0].data.float()
use_time = time.time() - start
print("vits use time:{}".format(use_time))
return audio, audio.shape[-1]

99
modules/audio.py Normal file
View File

@ -0,0 +1,99 @@
import numpy as np
from numpy import linalg as LA
import librosa
from scipy.io import wavfile
import soundfile as sf
import librosa.filters
def load_wav(wav_path, raw_sr, target_sr=16000, win_size=800, hop_size=200):
audio = librosa.core.load(wav_path, sr=raw_sr)[0]
if raw_sr != target_sr:
audio = librosa.core.resample(audio,
raw_sr,
target_sr,
res_type='kaiser_best')
target_length = (audio.size // hop_size +
win_size // hop_size) * hop_size
pad_len = (target_length - audio.size) // 2
if audio.size % 2 == 0:
audio = np.pad(audio, (pad_len, pad_len), mode='reflect')
else:
audio = np.pad(audio, (pad_len, pad_len + 1), mode='reflect')
return audio
def save_wav(wav, path, sample_rate, norm=False):
if norm:
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, sample_rate, wav.astype(np.int16))
else:
sf.write(path, wav, sample_rate)
_mel_basis = None
_inv_mel_basis = None
def _build_mel_basis(hparams):
assert hparams.fmax <= hparams.sampling_rate // 2
return librosa.filters.mel(hparams.sampling_rate,
hparams.n_fft,
n_mels=hparams.acoustic_dim,
fmin=hparams.fmin,
fmax=hparams.fmax)
def _linear_to_mel(spectogram, hparams):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis(hparams)
return np.dot(_mel_basis, spectogram)
def _mel_to_linear(mel_spectrogram, hparams):
global _inv_mel_basis
if _inv_mel_basis is None:
_inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
def _stft(y, hparams):
return librosa.stft(y=y,
n_fft=hparams.n_fft,
hop_length=hparams.hop_length,
win_length=hparams.win_size)
def _amp_to_db(x, hparams):
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _normalize(S, hparams):
return hparams.max_abs_value * np.clip(((S - hparams.min_db) /
(-hparams.min_db)), 0, 1)
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _stft(y, hparams):
return librosa.stft(y=y,
n_fft=hparams.n_fft,
hop_length=hparams.hop_length,
win_length=hparams.win_size)
def _istft(y, hparams):
return librosa.istft(y,
hop_length=hparams.hop_length,
win_length=hparams.win_size)
def melspectrogram(wav, hparams):
D = _stft(wav, hparams)
S = _amp_to_db(_linear_to_mel(np.abs(D), hparams),
hparams) - hparams.ref_level_db
return _normalize(S, hparams)

189
modules/ddsp.py Normal file
View File

@ -0,0 +1,189 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.fft as fft
import numpy as np
import librosa as li
import math
from scipy.signal import get_window
def safe_log(x):
return torch.log(x + 1e-7)
@torch.no_grad()
def mean_std_loudness(dataset):
mean = 0
std = 0
n = 0
for _, _, l in dataset:
n += 1
mean += (l.mean().item() - mean) / n
std += (l.std().item() - std) / n
return mean, std
def multiscale_fft(signal, scales, overlap):
stfts = []
for s in scales:
S = torch.stft(
signal,
s,
int(s * (1 - overlap)),
s,
torch.hann_window(s).to(signal),
True,
normalized=True,
return_complex=True,
).abs()
stfts.append(S)
return stfts
def resample(x, factor: int):
batch, frame, channel = x.shape
x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)
window = torch.hann_window(
factor * 2,
dtype=x.dtype,
device=x.device,
).reshape(1, 1, -1)
y = torch.zeros(x.shape[0], x.shape[1], factor * x.shape[2]).to(x)
y[..., ::factor] = x
y[..., -1:] = x[..., -1:]
y = torch.nn.functional.pad(y, [factor, factor])
y = torch.nn.functional.conv1d(y, window)[..., :-1]
y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1)
return y
def upsample(signal, factor):
signal = signal.permute(0, 2, 1)
signal = nn.functional.interpolate(signal, size=signal.shape[-1] * factor)
return signal.permute(0, 2, 1)
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
n_harm = amplitudes.shape[-1]
pitches = pitch * torch.arange(1, n_harm + 1).to(pitch)
aa = (pitches < sampling_rate / 2).float() + 1e-4
return amplitudes * aa
def scale_function(x):
return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
S = li.stft(
signal,
n_fft=n_fft,
hop_length=block_size,
win_length=n_fft,
center=True,
)
S = np.log(abs(S) + 1e-7)
f = li.fft_frequencies(sampling_rate, n_fft)
a_weight = li.A_weighting(f)
S = S + a_weight.reshape(-1, 1)
S = np.mean(S, 0)[..., :-1]
return S
def extract_pitch(signal, sampling_rate, block_size):
length = signal.shape[-1] // block_size
f0 = crepe.predict(
signal,
sampling_rate,
step_size=int(1000 * block_size / sampling_rate),
verbose=1,
center=True,
viterbi=True,
)
f0 = f0[1].reshape(-1)[:-1]
if f0.shape[-1] != length:
f0 = np.interp(
np.linspace(0, 1, length, endpoint=False),
np.linspace(0, 1, f0.shape[-1], endpoint=False),
f0,
)
return f0
def mlp(in_size, hidden_size, n_layers):
channels = [in_size] + (n_layers) * [hidden_size]
net = []
for i in range(n_layers):
net.append(nn.Linear(channels[i], channels[i + 1]))
net.append(nn.LayerNorm(channels[i + 1]))
net.append(nn.LeakyReLU())
return nn.Sequential(*net)
def gru(n_input, hidden_size):
return nn.GRU(n_input * hidden_size, hidden_size, batch_first=True)
def harmonic_synth(pitch, amplitudes, sampling_rate):
n_harmonic = amplitudes.shape[-1]
omega = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
signal = (torch.sin(omegas) * amplitudes).sum(-1, keepdim=True)
return signal
def amp_to_impulse_response(amp, target_size):
amp = torch.stack([amp, torch.zeros_like(amp)], -1)
amp = torch.view_as_complex(amp)
amp = fft.irfft(amp)
filter_size = amp.shape[-1]
amp = torch.roll(amp, filter_size // 2, -1)
win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)
amp = amp * win
amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
amp = torch.roll(amp, -filter_size // 2, -1)
return amp
def fft_convolve(signal, kernel):
signal = nn.functional.pad(signal, (0, signal.shape[-1]))
kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))
output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
output = output[..., output.shape[-1] // 2:]
return output
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
if win_type == 'None' or win_type is None:
window = np.ones(win_len)
else:
window = get_window(win_type, win_len, fftbins=True)#**0.5
N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T
if invers :
kernel = np.linalg.pinv(kernel).T
kernel = kernel*window
kernel = kernel[:, None, :]
return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32))

View File

@ -5,182 +5,187 @@ import scipy
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from typing import Any, Optional, Tuple
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm
import modules.commons as commons
from modules.commons import init_weights, get_padding
from modules.transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers-1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
"""
Dialted and Depth-Separable Convolution
"""
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
groups=channels, dilation=dilation, padding=padding
))
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
groups=channels, dilation=dilation, padding=padding
))
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
self.kernel_size = kernel_size,
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, n_speakers=0, spk_channels=0,
p_dropout=0):
super(WN, self).__init__()
assert (kernel_size % 2 == 1)
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size,
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_speakers = n_speakers
self.spk_channels = spk_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
if n_speakers > 0:
cond_layer = torch.nn.Conv1d(spk_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
else:
g_l = torch.zeros_like(x_in)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(
x_in,
g_l,
n_channels_tensor)
acts = self.drop(acts)
acts = commons.fused_add_tanh_sigmoid_multiply(
x_in,
g_l,
n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:,self.hidden_channels:,:]
else:
output = output + res_skip_acts
return output * x_mask
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, :self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels:, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
def remove_weight_norm(self):
if self.n_speakers > 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
@ -256,87 +261,193 @@ class ResBlock2(torch.nn.Module):
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels,1))
self.logs = nn.Parameter(torch.zeros(channels,1))
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1,2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
n_speakers=0,
spk_channels=0,
mean_only=False):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, n_speakers=n_speakers,
spk_channels=spk_channels)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels]*2, 1)
else:
m = stats
logs = torch.zeros_like(m)
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1,2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
n_speakers=0,
gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
n_speakers=n_speakers, spk_channels=gin_channels, mean_only=True))
self.flows.append(Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins:]
x1, logabsdet = piecewise_rational_quadratic_transform(x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails='linear',
tail_bound=self.tail_bound
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class ResStack(nn.Module):
def __init__(self, channel, kernel_size=3, base=3, nums=4):
super(ResStack, self).__init__()
self.layers = nn.ModuleList([
nn.Sequential(
nn.LeakyReLU(),
nn.utils.weight_norm(nn.Conv1d(channel, channel,
kernel_size=kernel_size, dilation=base ** i, padding=base ** i)),
nn.LeakyReLU(),
nn.utils.weight_norm(nn.Conv1d(channel, channel,
kernel_size=kernel_size, dilation=1, padding=1)),
)
for i in range(nums)
])
def forward(self, x):
for layer in self.layers:
x = x + layer(x)
return x

512
modules/stft.py Normal file
View File

@ -0,0 +1,512 @@
from librosa.util import pad_center, tiny
from scipy.signal import get_window
from torch import Tensor
from torch.autograd import Variable
from typing import Optional, Tuple
import librosa
import librosa.util as librosa_util
import math
import numpy as np
import scipy
import torch
import torch.nn.functional as F
import warnings
def create_fb_matrix(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None
) -> Tensor:
r"""Create a frequency bin conversion matrix.
Args:
n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency (Hz)
f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
meaning number of frequencies to highlight/apply to x the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be
``A * create_fb_matrix(A.size(-1), ...)``.
"""
if norm is not None and norm != "slaney":
raise ValueError("norm must be one of None or 'slaney'")
# freq bins
# Equivalent filterbank construction by Librosa
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
# calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_pts = torch.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
# create overlapping triangles
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
fb = torch.min(down_slopes, up_slopes)
fb = torch.clamp(fb, 1e-6, 1)
if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
fb *= enorm.unsqueeze(0)
return fb
def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of ``(..., time)``.
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
assert (a_coeffs.size(0) == b_coeffs.size(0))
assert (len(waveform.size()) == 2)
assert (waveform.device == a_coeffs.device)
assert (b_coeffs.device == a_coeffs.device)
device = waveform.device
dtype = waveform.dtype
n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(0)
n_sample_padded = n_sample + n_order - 1
assert (n_order > 0)
# Pad the input and create output
padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
padded_waveform[:, (n_order - 1):] = waveform
padded_output_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
# Set up the coefficients matrix
# Flip coefficients' order
a_coeffs_flipped = a_coeffs.flip(0)
b_coeffs_flipped = b_coeffs.flip(0)
# calculate windowed_input_signal in parallel
# create indices of original with shape (n_channel, n_order, n_sample)
window_idxs = torch.arange(n_sample, device=device).unsqueeze(0) + torch.arange(n_order, device=device).unsqueeze(1)
window_idxs = window_idxs.repeat(n_channel, 1, 1)
window_idxs += (torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
window_idxs = window_idxs.long()
# (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
input_signal_windows = torch.matmul(b_coeffs_flipped, torch.take(padded_waveform, window_idxs))
input_signal_windows.div_(a_coeffs[0])
a_coeffs_flipped.div_(a_coeffs[0])
for i_sample, o0 in enumerate(input_signal_windows.t()):
windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
padded_output_waveform[:, i_sample + n_order - 1] = o0
output = padded_output_waveform[:, (n_order - 1):]
if clamp:
output = torch.clamp(output, min=-1., max=1.)
# unpack batch
output = output.reshape(shape[:-1] + output.shape[-1:])
return output
def biquad(
waveform: Tensor,
b0: float,
b1: float,
b2: float,
a0: float,
a1: float,
a2: float
) -> Tensor:
r"""Perform a biquad filter of input tensor. Initial conditions set to 0.
https://en.wikipedia.org/wiki/Digital_biquad_filter
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2]
a0 (float): denominator coefficient of current output y[n], typically 1
a1 (float): denominator coefficient of current output y[n-1]
a2 (float): denominator coefficient of current output y[n-2]
Returns:
Tensor: Waveform with dimension of `(..., time)`
"""
device = waveform.device
dtype = waveform.dtype
output_waveform = lfilter(
waveform,
torch.tensor([a0, a1, a2], dtype=dtype, device=device),
torch.tensor([b0, b1, b2], dtype=dtype, device=device)
)
return output_waveform
def _dB2Linear(x: float) -> float:
return math.exp(x * math.log(10) / 20.0)
def highpass_biquad(
waveform: Tensor,
sample_rate: int,
cutoff_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
Tensor: Waveform dimension of `(..., time)`
"""
w0 = 2 * math.pi * cutoff_freq / sample_rate
alpha = math.sin(w0) / 2. / Q
b0 = (1 + math.cos(w0)) / 2
b1 = -1 - math.cos(w0)
b2 = b0
a0 = 1 + alpha
a1 = -2 * math.cos(w0)
a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def lowpass_biquad(
waveform: Tensor,
sample_rate: int,
cutoff_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
Tensor: Waveform of dimension of `(..., time)`
"""
w0 = 2 * math.pi * cutoff_freq / sample_rate
alpha = math.sin(w0) / 2 / Q
b0 = (1 - math.cos(w0)) / 2
b1 = 1 - math.cos(w0)
b2 = b0
a0 = 1 + alpha
a1 = -2 * math.cos(w0)
a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
n_fft=800, dtype=np.float32, norm=None):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
win_sq = librosa_util.pad_center(win_sq, n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
return x
class MelScale(torch.nn.Module):
r"""Turn a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).
Args:
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
"""
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
def __init__(self,
n_mels: int = 128,
sample_rate: int = 24000,
f_min: float = 0.,
f_max: Optional[float] = None,
n_stft: Optional[int] = None) -> None:
super(MelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb)
def forward(self, specgram: Tensor) -> Tensor:
r"""
Args:
specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
Returns:
Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
"""
# pack batch
shape = specgram.size()
specgram = specgram.reshape(-1, shape[-2], shape[-1])
if self.fb.numel() == 0:
tmp_fb = create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
# Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb)
# (channel, frequency, time).transpose(...) dot (frequency, n_mels)
# -> (channel, time, n_mels).transpose(...)
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
# unpack batch
mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
return mel_specgram
class TorchSTFT(torch.nn.Module):
def __init__(self, fft_size, hop_size, win_size,
normalized=False, domain='linear',
mel_scale=False, ref_level_db=20, min_level_db=-100):
super().__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.win_size = win_size
self.ref_level_db = ref_level_db
self.min_level_db = min_level_db
self.window = torch.hann_window(win_size)
self.normalized = normalized
self.domain = domain
self.mel_scale = MelScale(n_mels=(fft_size // 2 + 1),
n_stft=(fft_size // 2 + 1)) if mel_scale else None
def transform(self, x):
x_stft = torch.stft(x.to(torch.float32), self.fft_size, self.hop_size, self.win_size,
self.window.type_as(x), normalized=self.normalized)
real = x_stft[..., 0]
imag = x_stft[..., 1]
mag = torch.clamp(real ** 2 + imag ** 2, min=1e-7)
mag = torch.sqrt(mag)
phase = torch.atan2(imag, real)
if self.mel_scale is not None:
mag = self.mel_scale(mag)
if self.domain == 'log':
mag = 20 * torch.log10(mag) - self.ref_level_db
mag = torch.clamp((mag - self.min_level_db) / -self.min_level_db, 0, 1)
return mag, phase
elif self.domain == 'linear':
return mag, phase
elif self.domain == 'double':
log_mag = 20 * torch.log10(mag) - self.ref_level_db
log_mag = torch.clamp((log_mag - self.min_level_db) / -self.min_level_db, 0, 1)
return torch.cat((mag, log_mag), dim=1), phase
def complex(self, x):
x_stft = torch.stft(x, self.fft_size, self.hop_size, self.win_size,
self.window.type_as(x), normalized=self.normalized)
real = x_stft[..., 0]
imag = x_stft[..., 1]
return real, imag
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length=800, hop_length=200, win_length=800,
window='hann'):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(filter_length >= win_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer('forward_basis', forward_basis.float())
self.register_buffer('inverse_basis', inverse_basis.float())
def transform(self, input_data):
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = F.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode='reflect')
input_data = input_data.squeeze(1)
forward_transform = F.conv1d(
input_data,
Variable(self.forward_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
phase = torch.autograd.Variable(
torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
def inverse(self, magnitude, phase):
recombine_magnitude_phase = torch.cat(
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase,
Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
if self.window is not None:
window_sum = window_sumsquare(
self.window, magnitude.size(-1), hop_length=self.hop_length,
win_length=self.win_length, n_fft=self.filter_length,
dtype=np.float32)
# remove modulation effects
approx_nonzero_indices = torch.from_numpy(
np.where(window_sum > tiny(window_sum))[0])
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False)
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction

193
modules/transforms.py Normal file
View File

@ -0,0 +1,193 @@
import torch
from torch.nn import functional as F
import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {
'tails': tails,
'tail_bound': tail_bound
}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(
inputs[..., None] >= bin_locations,
dim=-1
) - 1
def unconstrained_rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails='linear',
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == 'linear':
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError('{} tails are not implemented.'.format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative
)
return outputs, logabsdet
def rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0., right=1., bottom=0., top=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain')
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins')
if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins')
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (((inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta)
+ input_heights * (input_delta - input_derivatives)))
b = (input_heights * input_derivatives
- (inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta))
c = - input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2))
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2)
+ input_derivatives * theta_one_minus_theta)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2))
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

File diff suppressed because it is too large Load Diff