114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torchaudio
|
||
|
||
import numpy as np
|
||
from torchaudio.transforms import Spectrogram
|
||
|
||
# Multiple Rate Dilated Convolution
|
||
class MRDConv(nn.Module):
|
||
def __init__(self, in_channels, out_channels, dilation_list = [0, 12, 19, 24, 28, 31, 34, 36]):
|
||
super().__init__()
|
||
self.dilation_list = dilation_list
|
||
self.conv_list = []
|
||
for i in range(len(dilation_list)):
|
||
self.conv_list += [nn.Conv2d(in_channels, out_channels, kernel_size = [1, 1])]
|
||
self.conv_list = nn.ModuleList(self.conv_list)
|
||
|
||
def forward(self, specgram):
|
||
# input [b x C x T x n_freq]
|
||
# output: [b x C x T x n_freq]
|
||
specgram
|
||
dilation = self.dilation_list[0]
|
||
y = self.conv_list[0](specgram)
|
||
y = F.pad(y, pad=[0, dilation])
|
||
y = y[:, :, :, dilation:]
|
||
for i in range(1, len(self.conv_list)):
|
||
dilation = self.dilation_list[i]
|
||
x = self.conv_list[i](specgram)
|
||
# => [b x T x (n_freq + dilation)]
|
||
# x = F.pad(x, pad=[0, dilation])
|
||
x = x[:, :, :, dilation:]
|
||
n_freq = x.size()[3]
|
||
y[:, :, :, :n_freq] += x
|
||
|
||
return y
|
||
|
||
# Fixed Rate Dilated Casual Convolution
|
||
class FRDConv(nn.Module):
|
||
def __init__(self, in_channels, out_channels, kernel_size=[1,3], dilation=[1, 1]) -> None:
|
||
super().__init__()
|
||
right = (kernel_size[1]-1) * dilation[1]
|
||
bottom = (kernel_size[0]-1) * dilation[0]
|
||
self.padding = nn.ZeroPad2d([0, right, 0 , bottom])
|
||
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=dilation)
|
||
|
||
def forward(self,x):
|
||
x = self.padding(x)
|
||
x = self.conv2d(x)
|
||
return x
|
||
|
||
|
||
class WaveformToLogSpecgram(nn.Module):
|
||
def __init__(self, sample_rate, n_fft, fmin, bins_per_octave, freq_bins, hop_length, logspecgram_type): #, device
|
||
super().__init__()
|
||
|
||
e = freq_bins/bins_per_octave
|
||
fmax = fmin * (2 ** e)
|
||
|
||
self.logspecgram_type = logspecgram_type
|
||
self.n_fft = n_fft
|
||
hamming_window = torch.hann_window(self.n_fft)#.to(device)
|
||
# => [1 x 1 x n_fft]
|
||
hamming_window = hamming_window[None, None, :]
|
||
self.register_buffer("hamming_window", hamming_window, persistent=False)
|
||
|
||
# torch.hann_window()
|
||
|
||
fre_resolution = sample_rate/n_fft
|
||
|
||
idxs = torch.arange(0, freq_bins) #, device=device
|
||
|
||
log_idxs = fmin * (2**(idxs/bins_per_octave)) / fre_resolution
|
||
|
||
# Linear interpolation: y_k = y_i * (k-i) + y_{i+1} * ((i+1)-k)
|
||
log_idxs_floor = torch.floor(log_idxs).long()
|
||
log_idxs_floor_w = (log_idxs - log_idxs_floor).reshape([1, 1, freq_bins])
|
||
log_idxs_ceiling = torch.ceil(log_idxs).long()
|
||
log_idxs_ceiling_w = (log_idxs_ceiling - log_idxs).reshape([1, 1, freq_bins])
|
||
self.register_buffer("log_idxs_floor", log_idxs_floor, persistent=False)
|
||
self.register_buffer("log_idxs_floor_w", log_idxs_floor_w, persistent=False)
|
||
self.register_buffer("log_idxs_ceiling", log_idxs_ceiling, persistent=False)
|
||
self.register_buffer("log_idxs_ceiling_w", log_idxs_ceiling_w, persistent=False)
|
||
|
||
self.waveform_to_specgram = torchaudio.transforms.Spectrogram(n_fft, hop_length=hop_length)#.to(device)
|
||
|
||
assert(bins_per_octave % 12 == 0)
|
||
bins_per_semitone = bins_per_octave // 12
|
||
|
||
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)
|
||
|
||
def forward(self, waveforms):
|
||
# inputs: [b x num_frames x frame_len]
|
||
# outputs: [b x num_frames x n_bins]
|
||
|
||
if(self.logspecgram_type == 'logharmgram'):
|
||
waveforms = waveforms * self.hamming_window
|
||
specgram = torch.fft.fft(waveforms)
|
||
specgram = torch.abs(specgram[:, :, :self.n_fft//2 + 1])
|
||
specgram = specgram * specgram
|
||
# => [num_frames x n_fft//2 x 1]
|
||
# specgram = torch.unsqueeze(specgram, dim=2)
|
||
|
||
# => [b x freq_bins x T]
|
||
specgram = specgram[:,:, self.log_idxs_floor] * self.log_idxs_floor_w + specgram[:, :, self.log_idxs_ceiling] * self.log_idxs_ceiling_w
|
||
|
||
specgram_db = self.amplitude_to_db(specgram)
|
||
# specgram_db = specgram_db[:, :, :-1] # remove the last frame.
|
||
# specgram_db = specgram_db.permute([0, 2, 1])
|
||
return specgram_db
|
||
|
||
|
||
|