From 74c5505b3ac12c370257e282558a98dd737ec1af Mon Sep 17 00:00:00 2001
From: 白叶 藤原 <1751842477@qq.com>
Date: Sun, 18 Jun 2023 23:32:52 +0800
Subject: [PATCH] snake ONNX export support; support the new ONNX export
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                                  |   2 +-
 models.py                                   |   4 +-
 onnx_export_speaker_mix.py                  | 181 ++++++++++++++------
 onnxexport/model_onnx_speaker_mix.py        | 166 +++++-------------
 vdecoder/hifiganwithsnake/alias/act.py      |  11 +-
 vdecoder/hifiganwithsnake/alias/filter.py   |  26 ++-
 vdecoder/hifiganwithsnake/alias/resample.py |  43 +++--
 vdecoder/hifiganwithsnake/models.py         |  28 +--
 8 files changed, 243 insertions(+), 218 deletions(-)

diff --git a/.gitignore b/.gitignore
index b48b137..bb95e1f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@ __pycache__/
 
 # C extensions
 *.so
-
+checkpoints/
 # Distribution / packaging
 .Python
 build/
diff --git a/models.py b/models.py
index ac40c3c..1f67b29 100644
--- a/models.py
+++ b/models.py
@@ -453,8 +453,8 @@ class SynthesizerTrn(nn.Module):
         x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
         # vol proj
         vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0
-
-        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol
+
+        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
 
         if predict_f0:
             lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
diff --git a/onnx_export_speaker_mix.py b/onnx_export_speaker_mix.py
index b137169..cb80735 100644
--- a/onnx_export_speaker_mix.py
+++ b/onnx_export_speaker_mix.py
@@ -1,67 +1,136 @@
 import torch
 from onnxexport.model_onnx_speaker_mix import SynthesizerTrn
 import utils
+import json
 
-def main(HubertExport, NetExport):
-    path = "SummerPockets"
-    if NetExport:
-        device = torch.device("cpu")
-        hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
-        SVCVITS = SynthesizerTrn(
-            hps.data.filter_length // 2 + 1,
-            hps.train.segment_size // hps.data.hop_length,
-            **hps.model)
-        _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
-        _ = SVCVITS.eval().to(device)
-        for i in SVCVITS.parameters():
-            i.requires_grad = False
-        test_hidden_unit = torch.rand(1, 10, SVCVITS.gin_channels)
-        test_pitch = torch.rand(1, 10)
-        test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
-        test_uv = torch.ones(1, 10, dtype=torch.float32)
-        test_noise = torch.randn(1, 192, 10)
+def main():
+    path = "crs"
 
-        export_mix = True
+    device = torch.device("cpu")
+    hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
+    SVCVITS = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        **hps.model)
+    _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
+    _ = SVCVITS.eval().to(device)
+    for i in SVCVITS.parameters():
+        i.requires_grad = False
+
+    num_frames = 200
 
-        test_sid = torch.LongTensor([0])
+    test_hidden_unit = torch.rand(1, num_frames, SVCVITS.gin_channels)
+    test_pitch = torch.rand(1, num_frames)
+    test_vol = torch.rand(1, num_frames)
+    test_mel2ph = torch.arange(0, num_frames).unsqueeze(0)
+    test_uv = torch.ones(1, num_frames, dtype=torch.float32)
+    test_noise = torch.randn(1, 192, num_frames)
+    test_sid = torch.LongTensor([0])
+    export_mix = True
+    if len(hps.spk) < 2:
+        export_mix = False
+
+    if export_mix:
+        spk_mix = []
-    if export_mix:
-        n_spk = len(hps.spk)
-        for i in range(n_spk):
-            spk_mix.append(1.0/float(n_spk))
-        test_sid = torch.tensor(spk_mix)
-        SVCVITS.export_chara_mix(n_spk)
-        test_sid = test_sid.unsqueeze(0)
-        test_sid = test_sid.repeat(10, 1)
-
-    input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
-    output_names = ["audio", ]
-    SVCVITS.eval()
+        n_spk = len(hps.spk)
+        for i in range(n_spk):
+            spk_mix.append(1.0/float(n_spk))
+        test_sid = torch.tensor(spk_mix)
+        SVCVITS.export_chara_mix(hps.spk)
+        test_sid = test_sid.unsqueeze(0)
+        test_sid = test_sid.repeat(num_frames, 1)
+
+    SVCVITS.eval()
 
-    torch.onnx.export(SVCVITS,
-                      (
-                          test_hidden_unit.to(device),
-                          test_pitch.to(device),
-                          test_mel2ph.to(device),
-                          test_uv.to(device),
-                          test_noise.to(device),
-                          test_sid.to(device)
-                      ),
-                      f"checkpoints/{path}/model.onnx",
-                      dynamic_axes={
-                          "c": [0, 1],
-                          "f0": [1],
-                          "mel2ph": [1],
-                          "uv": [1],
-                          "noise": [2],
-                          "sid":[0]
-                      },
-                      do_constant_folding=False,
-                      opset_version=16,
-                      verbose=False,
-                      input_names=input_names,
-                      output_names=output_names)
+    if export_mix:
+        daxes = {
+            "c": [0, 1],
+            "f0": [1],
+            "mel2ph": [1],
+            "uv": [1],
+            "noise": [2],
+            "sid": [0]
+        }
+    else:
+        daxes = {
+            "c": [0, 1],
+            "f0": [1],
+            "mel2ph": [1],
+            "uv": [1],
+            "noise": [2]
+        }
+
+    input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
+    output_names = ["audio", ]
+
+    if SVCVITS.vol_embedding:
+        input_names.append("vol")
+        vol_dadict = {"vol": [1]}
+        daxes.update(vol_dadict)
+        test_inputs = (
+            test_hidden_unit.to(device),
+            test_pitch.to(device),
+            test_mel2ph.to(device),
+            test_uv.to(device),
+            test_noise.to(device),
+            test_sid.to(device),
+            test_vol.to(device)
+        )
+    else:
+        test_inputs = (
+            test_hidden_unit.to(device),
+            test_pitch.to(device),
+            test_mel2ph.to(device),
+            test_uv.to(device),
+            test_noise.to(device),
+            test_sid.to(device)
+        )
+
+    # SVCVITS = torch.jit.script(SVCVITS)
+    SVCVITS(test_hidden_unit.to(device),
+            test_pitch.to(device),
+            test_mel2ph.to(device),
+            test_uv.to(device),
+            test_noise.to(device),
+            test_sid.to(device),
+            test_vol.to(device))
+
+    torch.onnx.export(
+        SVCVITS,
+        test_inputs,
+        f"checkpoints/{path}/{path}_SoVits.onnx",
+        dynamic_axes=daxes,
+        do_constant_folding=False,
+        opset_version=16,
+        verbose=False,
+        input_names=input_names,
+        output_names=output_names
+    )
+
+    vec_lay = "layer-12" if SVCVITS.gin_channels == 768 else "layer-9"
+    spklist = []
+    for key in hps.spk.keys():
+        spklist.append(key)
+
+    MoeVSConf = {
+        "Folder": f"{path}",
+        "Name": f"{path}",
+        "Type": "SoVits",
+        "Rate": hps.data.sampling_rate,
+        "Hop": hps.data.hop_length,
+        "Hubert": f"vec-{SVCVITS.gin_channels}-{vec_lay}",
+        "SoVits4": True,
+        "SoVits3": False,
+        "CharaMix": export_mix,
+        "Volume": SVCVITS.vol_embedding,
+        "HiddenSize": SVCVITS.gin_channels,
+        "Characters": spklist
+    }
+
+
+    with open(f"checkpoints/{path}.json", 'w') as MoeVsConfFile:
+        json.dump(MoeVSConf, MoeVsConfFile, indent=4)
 
 if __name__ == '__main__':
-    main(False, True)
+    main()
diff --git a/onnxexport/model_onnx_speaker_mix.py b/onnxexport/model_onnx_speaker_mix.py
index 355e590..f819363 100644
--- a/onnxexport/model_onnx_speaker_mix.py
+++ b/onnxexport/model_onnx_speaker_mix.py
@@ -1,6 +1,9 @@
+import copy
+import math
 import torch
 from torch import nn
 from torch.nn import functional as F
+
 import modules.attentions as attentions
 import modules.commons as commons
 import modules.modules as modules
@@ -10,10 +13,8 @@ from torch.nn.utils import weight_norm,
remove_weight_norm, spectral_norm import utils from modules.commons import init_weights, get_padding -from vdecoder.hifigan.models import Generator from utils import f0_to_coarse - class ResidualCouplingBlock(nn.Module): def __init__(self, channels, @@ -49,39 +50,6 @@ class ResidualCouplingBlock(nn.Module): return x -class Encoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - # print(x.shape,x_lengths.shape) - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - class TextEncoder(nn.Module): def __init__(self, out_channels, @@ -115,74 +83,10 @@ class TextEncoder(nn.Module): stats = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) z = (m + z * torch.exp(logs)) * x_mask + return z, m, logs, x_mask -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ]) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - 
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - class F0Decoder(nn.Module): def __init__(self, out_channels, @@ -217,7 +121,7 @@ class F0Decoder(nn.Module): def forward(self, x, norm_f0, x_mask, spk_emb=None): x = torch.detach(x) - if spk_emb is not None: + if (spk_emb is not None): x = x + self.cond(spk_emb) x += self.f0_prenet(norm_f0) x = self.prenet(x) * x_mask @@ -228,8 +132,8 @@ class F0Decoder(nn.Module): class SynthesizerTrn(nn.Module): """ - Synthesizer for Training - """ + Synthesizer for Training + """ def __init__(self, spec_channels, @@ -251,7 +155,10 @@ class SynthesizerTrn(nn.Module): ssl_dim, n_speakers, sampling_rate=44100, + vol_embedding=False, + vocoder_name = "nsf-hifigan", **kwargs): + super().__init__() self.spec_channels = spec_channels self.inter_channels = inter_channels @@ -270,7 +177,10 @@ class SynthesizerTrn(nn.Module): self.segment_size = segment_size self.gin_channels = gin_channels self.ssl_dim = ssl_dim + self.vol_embedding = vol_embedding self.emb_g = nn.Embedding(n_speakers, gin_channels) + if vol_embedding: + self.emb_vol = nn.Linear(1, hidden_channels) self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) @@ -294,8 +204,18 @@ class SynthesizerTrn(nn.Module): "upsample_kernel_sizes": upsample_kernel_sizes, "gin_channels": gin_channels, } - self.dec = Generator(h=hps) - self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + + if vocoder_name == "nsf-hifigan": + from vdecoder.hifigan.models import Generator + self.dec = Generator(h=hps) + elif vocoder_name == "nsf-snake-hifigan": + from vdecoder.hifiganwithsnake.models import Generator + self.dec = Generator(h=hps) + else: + print("[?] 
Unknown vocoder: use default (nsf-hifigan)")
+            from vdecoder.hifigan.models import Generator
+            self.dec = Generator(h=hps)
+
         self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
         self.f0_decoder = F0Decoder(
             1,
@@ -312,39 +232,39 @@ class SynthesizerTrn(nn.Module):
         self.speaker_map = []
         self.export_mix = False
 
-    def export_chara_mix(self, n_speakers_mix):
-        self.speaker_map = torch.zeros((n_speakers_mix, 1, 1, self.gin_channels))
-        for i in range(n_speakers_mix):
-            self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
+    def export_chara_mix(self, speakers_mix):
+        self.speaker_map = torch.zeros((len(speakers_mix), 1, 1, self.gin_channels))
+        i = 0
+        for key in speakers_mix.keys():
+            spkidx = speakers_mix[key]
+            self.speaker_map[i] = self.emb_g(torch.LongTensor([[spkidx]]))
+            i = i + 1
         self.speaker_map = self.speaker_map.unsqueeze(0)
         self.export_mix = True
 
-    def forward(self, c, f0, mel2ph, uv, noise=None, g=None, cluster_infer_ratio=0.1):
+    def forward(self, c, f0, mel2ph, uv, noise=None, g=None, vol=None):
         decoder_inp = F.pad(c, [0, 0, 1, 0])
         mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]])
         c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2)  # [B, T, H]
-        c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
-
         if self.export_mix:   # [N, S]  *  [S, B, 1, H]
             g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1))  # [N, S, B, 1, 1]
             g = g * self.speaker_map  # [N, S, B, 1, H]
             g = torch.sum(g, dim=1)  # [N, 1, B, 1, H]
             g = g.transpose(0, -1).transpose(0, -2).squeeze(0)  # [B, H, N]
         else:
-            g = g.unsqueeze(0)
+            if g.dim() == 1:
+                g = g.unsqueeze(0)
             g = self.emb_g(g).transpose(1, 2)
-
-        x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
-        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
-
-        if self.predict_f0:
-            lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) 
/ 500 - norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False) - pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) - f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) - + + x_mask = torch.unsqueeze(torch.ones_like(f0), 1).to(c.dtype) + # vol proj + vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0 + + x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol + z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise) z = self.flow(z_p, c_mask, g=g, reverse=True) o = self.dec(z * c_mask, g=g, f0=f0) return o + diff --git a/vdecoder/hifiganwithsnake/alias/act.py b/vdecoder/hifiganwithsnake/alias/act.py index 308344f..1465d1c 100644 --- a/vdecoder/hifiganwithsnake/alias/act.py +++ b/vdecoder/hifiganwithsnake/alias/act.py @@ -112,17 +112,18 @@ class SnakeAlias(nn.Module): up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, - down_kernel_size: int = 12): + down_kernel_size: int = 12, + C = None): super().__init__() self.up_ratio = up_ratio self.down_ratio = down_ratio self.act = SnakeBeta(channels, alpha_logscale=True) - self.upsample = UpSample1d(up_ratio, up_kernel_size) - self.downsample = DownSample1d(down_ratio, down_kernel_size) + self.upsample = UpSample1d(up_ratio, up_kernel_size, C) + self.downsample = DownSample1d(down_ratio, down_kernel_size, C) # x: [B,C,T] - def forward(self, x): - x = self.upsample(x) + def forward(self, x, C=None): + x = self.upsample(x, C) x = self.act(x) x = self.downsample(x) diff --git a/vdecoder/hifiganwithsnake/alias/filter.py b/vdecoder/hifiganwithsnake/alias/filter.py index 7ad6ea8..d2ccf1a 100644 --- a/vdecoder/hifiganwithsnake/alias/filter.py +++ b/vdecoder/hifiganwithsnake/alias/filter.py @@ -64,7 +64,8 @@ class LowPassFilter1d(nn.Module): stride: int = 1, padding: bool = True, padding_mode: str = 'replicate', - kernel_size: int = 12): + kernel_size: int = 12, + C=None): # kernel_size should be even number for stylegan3 setup, # in this implementation, odd number is also possible. 
         super().__init__()
@@ -81,15 +82,26 @@
         self.padding_mode = padding_mode
         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
         self.register_buffer("filter", filter)
+        self.conv1d_block = None
+        if C is not None:
+            self.conv1d_block = (nn.Conv1d(C, C, kernel_size, stride=self.stride, groups=C, bias=False), 1)  # tuple-wrapped so the conv is not registered as a submodule
+            self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone())
+            self.conv1d_block[0].requires_grad_(False)
 
     #input [B, C, T]
     def forward(self, x):
-        _, C, _ = x.shape
+        if self.conv1d_block is None:
+            _, C, _ = x.shape
 
-        if self.padding:
-            x = F.pad(x, (self.pad_left, self.pad_right),
-                      mode=self.padding_mode)
-        out = F.conv1d(x, self.filter.expand(C, -1, -1),
-                       stride=self.stride, groups=C)
+            if self.padding:
+                x = F.pad(x, (self.pad_left, self.pad_right),
+                          mode=self.padding_mode)
+            out = F.conv1d(x, self.filter.expand(C, -1, -1),
+                           stride=self.stride, groups=C)
+        else:
+            if self.padding:
+                x = F.pad(x, (self.pad_left, self.pad_right),
+                          mode=self.padding_mode)
+            out = self.conv1d_block[0](x)
 
         return out
\ No newline at end of file
diff --git a/vdecoder/hifiganwithsnake/alias/resample.py b/vdecoder/hifiganwithsnake/alias/resample.py
index 750e6c3..53773c7 100644
--- a/vdecoder/hifiganwithsnake/alias/resample.py
+++ b/vdecoder/hifiganwithsnake/alias/resample.py
@@ -8,7 +8,7 @@ from .filter import kaiser_sinc_filter1d
 
 
 class UpSample1d(nn.Module):
-    def __init__(self, ratio=2, kernel_size=None):
+    def __init__(self, ratio=2, kernel_size=None, C=None):
         super().__init__()
         self.ratio = ratio
         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
@@ -20,28 +20,49 @@ class UpSample1d(nn.Module):
                                       half_width=0.6 / ratio,
                                       kernel_size=self.kernel_size)
         self.register_buffer("filter", filter)
+        self.conv_transpose1d_block = None
+        if C is not None:
+            self.conv_transpose1d_block = (nn.ConvTranspose1d(C,
+                                                              C,
+                                                              kernel_size=self.kernel_size,
+                                                              stride=self.stride,
+                                                              groups=C,
+                                                              bias=False
+                                                              ), 1)  # tuple-wrapped: kept out of the module registry
+            self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone())
+            self.conv_transpose1d_block[0].requires_grad_(False)
+
+
 
     # x: [B, C, T]
-    def forward(self, x):
-        _, C, _ = x.shape
-
-        x = F.pad(x, (self.pad, self.pad), mode='replicate')
-        x = self.ratio * F.conv_transpose1d(
-            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
-        x = x[..., self.pad_left:-self.pad_right]
-
+    def forward(self, x, C=None):
+        if self.conv_transpose1d_block is None:
+            if C is None:
+                _, C, _ = x.shape
+            # print("snake.conv_t.in:",x.shape)
+            x = F.pad(x, (self.pad, self.pad), mode='replicate')
+            x = self.ratio * F.conv_transpose1d(
+                x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+            # print("snake.conv_t.out:",x.shape)
+            x = x[..., self.pad_left:-self.pad_right]
+        else:
+            x = F.pad(x, (self.pad, self.pad), mode='replicate')
+            x = self.ratio * self.conv_transpose1d_block[0](x)
+            x = x[..., self.pad_left:-self.pad_right]
         return x
 
 
 class DownSample1d(nn.Module):
-    def __init__(self, ratio=2, kernel_size=None):
+    def __init__(self, ratio=2, kernel_size=None, C=None):
         super().__init__()
         self.ratio = ratio
         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
         self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                        half_width=0.6 / ratio,
                                        stride=ratio,
-                                       kernel_size=self.kernel_size)
+                                       kernel_size=self.kernel_size,
+                                       C=C)
+
 
     def forward(self, x):
         xx = self.lowpass(x)
diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py
index 64f0e4d..4d9ae7a
100644 --- a/vdecoder/hifiganwithsnake/models.py +++ b/vdecoder/hifiganwithsnake/models.py @@ -33,7 +33,7 @@ def load_model(model_path, device='cuda'): class ResBlock1(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None): super(ResBlock1, self).__init__() self.h = h self.convs1 = nn.ModuleList([ @@ -58,15 +58,15 @@ class ResBlock1(torch.nn.Module): self.num_layers = len(self.convs1) + len(self.convs2) self.activations = nn.ModuleList([ - SnakeAlias(channels) for _ in range(self.num_layers) + SnakeAlias(channels, C=C) for _ in range(self.num_layers) ]) - def forward(self, x): + def forward(self, x, DIM=None): acts1, acts2 = self.activations[::2], self.activations[1::2] for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): - xt = a1(x) + xt = a1(x, DIM) xt = c1(xt) - xt = a2(xt) + xt = a2(xt, DIM) xt = c2(xt) x = xt + x return x @@ -79,7 +79,7 @@ class ResBlock1(torch.nn.Module): class ResBlock2(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None): super(ResBlock2, self).__init__() self.h = h self.convs = nn.ModuleList([ @@ -92,12 +92,12 @@ class ResBlock2(torch.nn.Module): self.num_layers = len(self.convs) self.activations = nn.ModuleList([ - SnakeAlias(channels) for _ in range(self.num_layers) + SnakeAlias(channels, C=C) for _ in range(self.num_layers) ]) - def forward(self, x): + def forward(self, x, DIM=None): for c,a in zip(self.convs, self.activations): - xt = a(x) + xt = a(x, DIM) xt = c(xt) x = xt + x return x @@ -315,14 +315,14 @@ class Generator(torch.nn.Module): self.snakes = nn.ModuleList() for i in range(len(self.ups)): ch = h["upsample_initial_channel"] // (2 ** (i + 1)) - self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)))) + self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i)) for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): - self.resblocks.append(resblock(h, ch, k, d)) + self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1))) self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) self.ups.apply(init_weights) self.conv_post.apply(init_weights) - self.snake_post = SnakeAlias(ch) + self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups)) self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) def forward(self, x, f0, g=None): @@ -335,8 +335,9 @@ class Generator(torch.nn.Module): x = x + self.cond(g) # print(124,x.shape,har_source.shape) for i in range(self.num_upsamples): + # print(f"self.snakes.{i}.pre:", x.shape) x = self.snakes[i](x) - # print(3,x.shape) + # print(f"self.snakes.{i}.after:", x.shape) x = self.ups[i](x) x_source = self.noise_convs[i](har_source) # print(4,x_source.shape,har_source.shape,x.shape) @@ -347,6 +348,7 @@ class Generator(torch.nn.Module): xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) + # print(f"self.resblocks.{i}.after:", xs.shape) x = xs / self.num_kernels x = self.snake_post(x) x = self.conv_post(x)
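
-- 

Note: once exported, the graph can be driven from onnxruntime. Below is a
minimal sketch, assuming a voice named "crs" exported with the CharaMix path
above; input names, dtypes, and shapes follow the dummy tensors in
onnx_export_speaker_mix.py, while the hidden size and speaker count are
placeholders for the concrete model:

    import numpy as np
    import onnxruntime as ort

    T = 200          # number of frames to synthesize
    hidden = 256     # content-feature dim (gin_channels); 768 for vec-768 models
    n_spk = 2        # speakers baked into speaker_map at export time

    sess = ort.InferenceSession("checkpoints/crs/crs_SoVits.onnx")

    feeds = {
        "c": np.random.rand(1, T, hidden).astype(np.float32),       # content features
        "f0": np.full((1, T), 220.0, dtype=np.float32),             # f0 curve in Hz
        "mel2ph": np.arange(T, dtype=np.int64)[None, :],            # frame -> feature alignment
        "uv": np.ones((1, T), dtype=np.float32),                    # voiced/unvoiced flags
        "noise": np.random.randn(1, 192, T).astype(np.float32),     # prior noise z
        "sid": np.full((T, n_spk), 1.0 / n_spk, dtype=np.float32),  # per-frame mix weights
    }
    # models exported with vol_embedding additionally take
    # "vol": np.ones((1, T), dtype=np.float32)
    audio = sess.run(["audio"], feeds)[0]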
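For reference, the CharaMix path exported above replaces the emb_g embedding
lookup with a per-frame weighted sum over a precomputed speaker-embedding
table, which is why "sid" becomes a float matrix instead of an integer id. A
standalone sketch of the same computation (shapes mirror the comments in
forward; all sizes are illustrative):

    import torch

    n_frames, n_spk, hidden = 200, 2, 256
    speaker_map = torch.randn(1, n_spk, 1, 1, hidden)  # built by export_chara_mix from emb_g
    sid = torch.full((n_frames, n_spk), 1.0 / n_spk)   # per-frame speaker weights, rows sum to 1

    g = sid.reshape(n_frames, n_spk, 1, 1, 1)           # [N, S, 1, 1, 1]
    g = torch.sum(g * speaker_map, dim=1)               # weighted sum over S -> [N, 1, 1, H]
    g = g.transpose(0, -1).transpose(0, -2).squeeze(0)  # [1, H, N]: one condition vector per frame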
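The C argument threaded through SnakeAlias, UpSample1d, and LowPassFilter1d
appears to exist so that, when the channel count is known at construction
time, the kaiser-sinc filtering runs through a real Conv1d/ConvTranspose1d
module with a frozen, pre-expanded weight rather than an F.conv1d call whose
weight is expanded at run time, presumably giving the ONNX exporter a static
convolution node; the tuple wrapping keeps the helper conv out of the module
registry (and so out of checkpoints). The two formulations are numerically
identical, as this self-contained check suggests (sizes are arbitrary):

    import torch
    import torch.nn.functional as F

    C, T, k = 8, 64, 12
    filt = torch.randn(1, 1, k)   # stand-in for the kaiser_sinc_filter1d output
    x = torch.randn(1, C, T)

    # depthwise conv with the same taps in every channel...
    conv = torch.nn.Conv1d(C, C, k, groups=C, bias=False)
    conv.weight = torch.nn.Parameter(filt.expand(C, -1, -1).clone())
    conv.requires_grad_(False)

    # ...matches the functional form with a runtime-expanded weight
    assert torch.allclose(conv(x), F.conv1d(x, filt.expand(C, -1, -1), groups=C), atol=1e-6)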