Updata Depthwise Separable Conv1D to Infer Speed Up

This commit is contained in:
YuriHead 2023-06-26 00:58:52 +08:00
parent 7c0d113eae
commit 3691bbf5f3
10 changed files with 249 additions and 122 deletions

View File

@ -60,7 +60,10 @@
"vocoder_name":"nsf-hifigan",
"speech_encoder":"vec768l12",
"speaker_embedding":false,
"vol_embedding":false
"vol_embedding":false,
"use_depthwise_conv":false,
"use_depthwise_transposeconv":false,
"use_automatic_f0_prediction": true
},
"spk": {
"nyaru": 0,

View File

@ -321,6 +321,9 @@ class SynthesizerTrn(nn.Module):
sampling_rate=44100,
vol_embedding=False,
vocoder_name = "nsf-hifigan",
use_depthwise_conv = False,
use_depthwise_transposeconv = False,
use_automatic_f0_prediction = True,
**kwargs):
super().__init__()
@ -343,6 +346,8 @@ class SynthesizerTrn(nn.Module):
self.ssl_dim = ssl_dim
self.vol_embedding = vol_embedding
self.emb_g = nn.Embedding(n_speakers, gin_channels)
self.use_depthwise_conv = use_depthwise_conv
self.use_automatic_f0_prediction = use_automatic_f0_prediction
if vol_embedding:
self.emb_vol = nn.Linear(1, hidden_channels)
@ -367,9 +372,12 @@ class SynthesizerTrn(nn.Module):
"upsample_initial_channel": upsample_initial_channel,
"upsample_kernel_sizes": upsample_kernel_sizes,
"gin_channels": gin_channels,
"use_depthwise_conv":use_depthwise_conv,
"use_depthwise_transposeconv":use_depthwise_transposeconv
}
modules.set_Conv1dModel(self.use_depthwise_conv)
if vocoder_name == "nsf-hifigan":
from vdecoder.hifigan.models import Generator
self.dec = Generator(h=hps)
@ -383,16 +391,17 @@ class SynthesizerTrn(nn.Module):
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
self.f0_decoder = F0Decoder(
1,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
spk_channels=gin_channels
)
if self.use_automatic_f0_prediction:
self.f0_decoder = F0Decoder(
1,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
spk_channels=gin_channels
)
self.emb_uv = nn.Embedding(2, hidden_channels)
self.character_mix = False
@ -412,12 +421,16 @@ class SynthesizerTrn(nn.Module):
# ssl prenet
x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol
# f0 predict
lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
norm_lf0 = utils.normalize_f0(lf0, x_mask, uv)
pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
if self.use_automatic_f0_prediction:
lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
norm_lf0 = utils.normalize_f0(lf0, x_mask, uv)
pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
else:
lf0 = 0
norm_lf0 = 0
pred_lf0 = 0
# encoder
z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
@ -431,6 +444,7 @@ class SynthesizerTrn(nn.Module):
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0
@torch.no_grad()
def infer(self, c, f0, uv, g=None, noice_scale=0.35, seed=52468, predict_f0=False, vol = None):
if c.device == torch.device("cuda"):
@ -453,10 +467,10 @@ class SynthesizerTrn(nn.Module):
x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
# vol proj
vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol
if predict_f0:
if self.use_automatic_f0_prediction and predict_f0:
lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)

76
modules/DSConv.py Normal file
View File

@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm
class Depthwise_Separable_Conv1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
bias = True,
padding_mode = 'zeros', # TODO: refine this type
device=None,
dtype=None
):
super().__init__()
self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
def forward(self, input):
return self.point_conv(self.depth_conv(input))
def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
self.point_conv = weight_norm(self.point_conv, name = 'weight')
def remove_weight_norm(self):
self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight')
self.point_conv = remove_weight_norm(self.point_conv, name = 'weight')
class Depthwise_Separable_TransposeConv1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride = 1,
padding = 0,
output_padding = 0,
bias = True,
dilation = 1,
padding_mode = 'zeros', # TODO: refine this type
device=None,
dtype=None
):
super().__init__()
self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
def forward(self, input):
return self.point_conv(self.depth_conv(input))
def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
self.point_conv = weight_norm(self.point_conv, name = 'weight')
def remove_weight_norm(self):
remove_weight_norm(self.depth_conv, name = 'weight')
remove_weight_norm(self.point_conv, name = 'weight')
def weight_norm_modules(module, name = 'weight', dim = 0):
if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
module.weight_norm()
return module
else:
return weight_norm(module,name,dim)
def remove_weight_norm_modules(module, name = 'weight'):
if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
module.remove_weight_norm()
else:
remove_weight_norm(module,name)

View File

@ -24,10 +24,12 @@ def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
if "Depthwise_Separable" in classname:
m.depth_conv.weight.data.normal_(mean, std)
m.point_conv.weight.data.normal_(mean, std)
elif classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)

View File

@ -1,20 +1,20 @@
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm
from modules.DSConv import weight_norm_modules, remove_weight_norm_modules, Depthwise_Separable_Conv1D
import modules.commons as commons
from modules.commons import init_weights, get_padding
LRELU_SLOPE = 0.1
Conv1dModel = nn.Conv1d
def set_Conv1dModel(use_depthwise_conv):
global Conv1dModel
Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
@ -44,13 +44,13 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers-1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.conv_layers.append(Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
@ -124,14 +124,14 @@ class WN(torch.nn.Module):
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
self.cond_layer = weight_norm_modules(cond_layer, name='weight')
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
in_layer = Conv1dModel(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
in_layer = weight_norm_modules(in_layer, name='weight')
self.in_layers.append(in_layer)
# last one is not necessary
@ -141,7 +141,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
res_skip_layer = weight_norm_modules(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
@ -176,32 +176,32 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
remove_weight_norm_modules(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
remove_weight_norm_modules(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
remove_weight_norm_modules(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
@ -223,18 +223,18 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_weight_norm_modules(l)
for l in self.convs2:
remove_weight_norm(l)
remove_weight_norm_modules(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
@ -252,7 +252,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_weight_norm_modules(l)
class Log(nn.Module):

View File

@ -209,7 +209,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_lf0 = F.mse_loss(pred_lf0, lf0)
loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
@ -241,13 +241,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
"all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
pred_lf0[0, 0, :].detach().cpu().numpy()),
"all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
norm_lf0[0, 0, :].detach().cpu().numpy())
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())
}
if net_g.module.use_automatic_f0_prediction:
image_dict.module.update({
"all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
pred_lf0[0, 0, :].detach().cpu().numpy()),
"all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
norm_lf0[0, 0, :].detach().cpu().numpy())
})
utils.summarize(
writer=writer,
global_step=global_step,
@ -328,4 +332,4 @@ def evaluate(hps, generator, eval_loader, writer_eval):
if __name__ == "__main__":
main()
main()

View File

@ -6,11 +6,23 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from torch.nn.utils import weight_norm,spectral_norm
from .utils import init_weights, get_padding
from modules.DSConv import weight_norm_modules, remove_weight_norm_modules, Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D
LRELU_SLOPE = 0.1
Conv1dModel = nn.Conv1d
ConvTranspose1dModel = nn.ConvTranspose1d
def set_Conv1dModel(use_depthwise_conv):
global Conv1dModel
Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
def set_ConvTranspose1dModel(use_depthwise_transposeconv):
global ConvTranspose1dModel
ConvTranspose1dModel = Depthwise_Separable_TransposeConv1D if use_depthwise_transposeconv else nn.ConvTranspose1d
def load_model(model_path, device='cuda'):
config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
@ -36,21 +48,21 @@ class ResBlock1(torch.nn.Module):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
@ -66,9 +78,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_weight_norm_modules(l)
for l in self.convs2:
remove_weight_norm(l)
remove_weight_norm_modules(l)
class ResBlock2(torch.nn.Module):
@ -76,9 +88,9 @@ class ResBlock2(torch.nn.Module):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
@ -92,7 +104,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_weight_norm_modules(l)
def padDiff(x):
@ -277,7 +289,10 @@ class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
set_Conv1dModel(h["use_depthwise_conv"])
set_ConvTranspose1dModel(h["use_depthwise_transposeconv"])
self.num_kernels = len(h["resblock_kernel_sizes"])
self.num_upsamples = len(h["upsample_rates"])
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
@ -285,17 +300,17 @@ class Generator(torch.nn.Module):
sampling_rate=h["sampling_rate"],
harmonic_num=8)
self.noise_convs = nn.ModuleList()
self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
self.conv_pre = weight_norm_modules(Conv1dModel(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])):
c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
self.ups.append(weight_norm(
ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
k, u, padding=(k - u +1 ) // 2)))
self.ups.append(weight_norm_modules(
ConvTranspose1dModel(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
k, u, padding=(k - u + 1 ) // 2)))
if i + 1 < len(h["upsample_rates"]): #
stride_f0 = np.prod(h["upsample_rates"][i + 1:])
self.noise_convs.append(Conv1d(
self.noise_convs.append(Conv1dModel(
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
else:
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
@ -305,7 +320,7 @@ class Generator(torch.nn.Module):
for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.conv_post = weight_norm_modules(Conv1dModel(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)
@ -342,11 +357,11 @@ class Generator(torch.nn.Module):
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
remove_weight_norm_modules(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
remove_weight_norm_modules(self.conv_pre)
remove_weight_norm_modules(self.conv_post)
class DiscriminatorP(torch.nn.Module):

View File

@ -21,7 +21,10 @@ def plot_spectrogram(spectrogram):
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
if "Depthwise_Separable" in classname:
m.depth_conv.weight.data.normal_(mean, std)
m.point_conv.weight.data.normal_(mean, std)
elif classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)

View File

@ -6,12 +6,23 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from torch.nn.utils import weight_norm, spectral_norm
from .utils import init_weights, get_padding
from vdecoder.hifiganwithsnake.alias.act import SnakeAlias
from modules.DSConv import weight_norm_modules, remove_weight_norm_modules, Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D
LRELU_SLOPE = 0.1
Conv1dModel = nn.Conv1d
ConvTranspose1dModel = nn.ConvTranspose1d
def set_Conv1dModel(use_depthwise_conv):
global Conv1dModel
Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
def set_ConvTranspose1dModel(use_depthwise_transposeconv):
global ConvTranspose1dModel
ConvTranspose1dModel = Depthwise_Separable_TransposeConv1D if use_depthwise_transposeconv else nn.ConvTranspose1d
def load_model(model_path, device='cuda'):
config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
@ -33,79 +44,77 @@ def load_model(model_path, device='cuda'):
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2)
self.activations = nn.ModuleList([
SnakeAlias(channels, C=C) for _ in range(self.num_layers)
SnakeAlias(channels) for _ in range(self.num_layers)
])
def forward(self, x, DIM=None):
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x, DIM)
xt = a1(x)
xt = c1(xt)
xt = a2(xt, DIM)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_weight_norm_modules(l)
for l in self.convs2:
remove_weight_norm(l)
remove_weight_norm_modules(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
self.num_layers = len(self.convs)
self.activations = nn.ModuleList([
SnakeAlias(channels, C=C) for _ in range(self.num_layers)
SnakeAlias(channels) for _ in range(self.num_layers)
])
def forward(self, x, DIM=None):
def forward(self, x):
for c,a in zip(self.convs, self.activations):
xt = a(x, DIM)
xt = a(x)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_weight_norm_modules(l)
def padDiff(x):
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
@ -289,7 +298,10 @@ class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
set_Conv1dModel(h["use_depthwise_conv"])
set_ConvTranspose1dModel(h["use_depthwise_transposeconv"])
self.num_kernels = len(h["resblock_kernel_sizes"])
self.num_upsamples = len(h["upsample_rates"])
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
@ -297,32 +309,29 @@ class Generator(torch.nn.Module):
sampling_rate=h["sampling_rate"],
harmonic_num=8)
self.noise_convs = nn.ModuleList()
self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
self.conv_pre = weight_norm_modules(Conv1dModel(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])):
c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
self.ups.append(weight_norm(
ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
k, u, padding=(k - u + 1) // 2)))
self.ups.append(weight_norm_modules(
ConvTranspose1dModel(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
k, u, padding=(k - u +1 ) // 2)))
if i + 1 < len(h["upsample_rates"]): #
stride_f0 = np.prod(h["upsample_rates"][i + 1:])
self.noise_convs.append(Conv1d(
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+ 1) // 2))
self.noise_convs.append(Conv1dModel(
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
else:
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
self.snakes = nn.ModuleList()
for i in range(len(self.ups)):
ch = h["upsample_initial_channel"] // (2 ** (i + 1))
self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i))
for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])):
self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1)))
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.conv_post = weight_norm_modules(Conv1dModel(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups))
self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)
def forward(self, x, f0, g=None):
@ -335,9 +344,8 @@ class Generator(torch.nn.Module):
x = x + self.cond(g)
# print(124,x.shape,har_source.shape)
for i in range(self.num_upsamples):
# print(f"self.snakes.{i}.pre:", x.shape)
x = self.snakes[i](x)
# print(f"self.snakes.{i}.after:", x.shape)
# print(3,x.shape)
x = self.ups[i](x)
x_source = self.noise_convs[i](har_source)
# print(4,x_source.shape,har_source.shape,x.shape)
@ -348,7 +356,6 @@ class Generator(torch.nn.Module):
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
# print(f"self.resblocks.{i}.after:", xs.shape)
x = xs / self.num_kernels
x = self.snake_post(x)
x = self.conv_post(x)
@ -359,11 +366,11 @@ class Generator(torch.nn.Module):
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
remove_weight_norm_modules(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
l.remove_weight_norm_modules()
remove_weight_norm_modules(self.conv_pre)
remove_weight_norm_modules(self.conv_post)
class DiscriminatorP(torch.nn.Module):

View File

@ -21,7 +21,10 @@ def plot_spectrogram(spectrogram):
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
if "Depthwise_Separable" in classname:
m.depth_conv.weight.data.normal_(mean, std)
m.point_conv.weight.data.normal_(mean, std)
elif classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)