76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
import torch.nn as nn
|
|
from torch.nn.utils import remove_weight_norm, weight_norm
|
|
|
|
|
|
class Depthwise_Separable_Conv1D(nn.Module):
|
|
def __init__(
|
|
self,
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride = 1,
|
|
padding = 0,
|
|
dilation = 1,
|
|
bias = True,
|
|
padding_mode = 'zeros', # TODO: refine this type
|
|
device=None,
|
|
dtype=None
|
|
):
|
|
super().__init__()
|
|
self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
|
|
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
|
|
|
|
def forward(self, input):
|
|
return self.point_conv(self.depth_conv(input))
|
|
|
|
def weight_norm(self):
|
|
self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
|
|
self.point_conv = weight_norm(self.point_conv, name = 'weight')
|
|
|
|
def remove_weight_norm(self):
|
|
self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight')
|
|
self.point_conv = remove_weight_norm(self.point_conv, name = 'weight')
|
|
|
|
class Depthwise_Separable_TransposeConv1D(nn.Module):
|
|
def __init__(
|
|
self,
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride = 1,
|
|
padding = 0,
|
|
output_padding = 0,
|
|
bias = True,
|
|
dilation = 1,
|
|
padding_mode = 'zeros', # TODO: refine this type
|
|
device=None,
|
|
dtype=None
|
|
):
|
|
super().__init__()
|
|
self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
|
|
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
|
|
|
|
def forward(self, input):
|
|
return self.point_conv(self.depth_conv(input))
|
|
|
|
def weight_norm(self):
|
|
self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
|
|
self.point_conv = weight_norm(self.point_conv, name = 'weight')
|
|
|
|
def remove_weight_norm(self):
|
|
remove_weight_norm(self.depth_conv, name = 'weight')
|
|
remove_weight_norm(self.point_conv, name = 'weight')
|
|
|
|
|
|
def weight_norm_modules(module, name = 'weight', dim = 0):
|
|
if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
|
|
module.weight_norm()
|
|
return module
|
|
else:
|
|
return weight_norm(module,name,dim)
|
|
|
|
def remove_weight_norm_modules(module, name = 'weight'):
|
|
if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
|
|
module.remove_weight_norm()
|
|
else:
|
|
remove_weight_norm(module,name) |