Merge pull request #245 from svc-develop-team/4.1-Stable

snake Onnx 导出支持,新版Onnx支持
This commit is contained in:
YuriHead 2023-06-20 00:52:39 +08:00 committed by GitHub
commit a36805174b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 243 additions and 218 deletions

2
.gitignore vendored
View File

@ -10,7 +10,7 @@ __pycache__/
# C extensions
*.so
checkpoints/
# Distribution / packaging
.Python
build/

View File

@ -453,8 +453,8 @@ class SynthesizerTrn(nn.Module):
x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
# vol proj
vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2) + vol
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
if predict_f0:
lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500

View File

@ -1,67 +1,136 @@
import torch
from onnxexport.model_onnx_speaker_mix import SynthesizerTrn
import utils
import json
def main(HubertExport, NetExport):
path = "SummerPockets"
if NetExport:
device = torch.device("cpu")
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
SVCVITS = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
_ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
_ = SVCVITS.eval().to(device)
for i in SVCVITS.parameters():
i.requires_grad = False
test_hidden_unit = torch.rand(1, 10, SVCVITS.gin_channels)
test_pitch = torch.rand(1, 10)
test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
test_uv = torch.ones(1, 10, dtype=torch.float32)
test_noise = torch.randn(1, 192, 10)
def main():
path = "crs"
export_mix = True
device = torch.device("cpu")
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
SVCVITS = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
_ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
_ = SVCVITS.eval().to(device)
for i in SVCVITS.parameters():
i.requires_grad = False
num_frames = 200
test_sid = torch.LongTensor([0])
test_hidden_unit = torch.rand(1, num_frames, SVCVITS.gin_channels)
test_pitch = torch.rand(1, num_frames)
test_vol = torch.rand(1, num_frames)
test_mel2ph = torch.LongTensor(torch.arange(0, num_frames)).unsqueeze(0)
test_uv = torch.ones(1, num_frames, dtype=torch.float32)
test_noise = torch.randn(1, 192, num_frames)
test_sid = torch.LongTensor([0])
export_mix = True
if len(hps.spk) < 2:
export_mix = False
if export_mix:
spk_mix = []
if export_mix:
n_spk = len(hps.spk)
for i in range(n_spk):
spk_mix.append(1.0/float(n_spk))
test_sid = torch.tensor(spk_mix)
SVCVITS.export_chara_mix(n_spk)
test_sid = test_sid.unsqueeze(0)
test_sid = test_sid.repeat(10, 1)
input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
output_names = ["audio", ]
SVCVITS.eval()
n_spk = len(hps.spk)
for i in range(n_spk):
spk_mix.append(1.0/float(n_spk))
test_sid = torch.tensor(spk_mix)
SVCVITS.export_chara_mix(hps.spk)
test_sid = test_sid.unsqueeze(0)
test_sid = test_sid.repeat(num_frames, 1)
SVCVITS.eval()
torch.onnx.export(SVCVITS,
(
test_hidden_unit.to(device),
test_pitch.to(device),
test_mel2ph.to(device),
test_uv.to(device),
test_noise.to(device),
test_sid.to(device)
),
f"checkpoints/{path}/model.onnx",
dynamic_axes={
"c": [0, 1],
"f0": [1],
"mel2ph": [1],
"uv": [1],
"noise": [2],
"sid":[0]
},
do_constant_folding=False,
opset_version=16,
verbose=False,
input_names=input_names,
output_names=output_names)
if export_mix:
daxes = {
"c": [0, 1],
"f0": [1],
"mel2ph": [1],
"uv": [1],
"noise": [2],
"sid":[0]
}
else:
daxes = {
"c": [0, 1],
"f0": [1],
"mel2ph": [1],
"uv": [1],
"noise": [2]
}
input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
output_names = ["audio", ]
if SVCVITS.vol_embedding:
input_names.append("vol")
vol_dadict = {"vol" : [1]}
daxes.update(vol_dadict)
test_inputs = (
test_hidden_unit.to(device),
test_pitch.to(device),
test_mel2ph.to(device),
test_uv.to(device),
test_noise.to(device),
test_sid.to(device),
test_vol.to(device)
)
else:
test_inputs = (
test_hidden_unit.to(device),
test_pitch.to(device),
test_mel2ph.to(device),
test_uv.to(device),
test_noise.to(device),
test_sid.to(device)
)
# SVCVITS = torch.jit.script(SVCVITS)
SVCVITS(test_hidden_unit.to(device),
test_pitch.to(device),
test_mel2ph.to(device),
test_uv.to(device),
test_noise.to(device),
test_sid.to(device),
test_vol.to(device))
torch.onnx.export(
SVCVITS,
test_inputs,
f"checkpoints/{path}/{path}_SoVits.onnx",
dynamic_axes=daxes,
do_constant_folding=False,
opset_version=16,
verbose=False,
input_names=input_names,
output_names=output_names
)
vec_lay = "layer-12" if SVCVITS.gin_channels == 768 else "layer-9"
spklist = []
for key in hps.spk.keys():
spklist.append(key)
MoeVSConf = {
"Folder" : f"{path}",
"Name" : f"{path}",
"Type" : "SoVits",
"Rate" : hps.data.sampling_rate,
"Hop" : hps.data.hop_length,
"Hubert": f"vec-{SVCVITS.gin_channels}-{vec_lay}",
"SoVits4": True,
"SoVits3": False,
"CharaMix": export_mix,
"Volume": SVCVITS.vol_embedding,
"HiddenSize": SVCVITS.gin_channels,
"Characters": spklist
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"checkpoints/{path}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
if __name__ == '__main__':
main(False, True)
main()

View File

@ -1,6 +1,9 @@
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F
import modules.attentions as attentions
import modules.commons as commons
import modules.modules as modules
@ -10,10 +13,8 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
import utils
from modules.commons import init_weights, get_padding
from vdecoder.hifigan.models import Generator
from utils import f0_to_coarse
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
@ -49,39 +50,6 @@ class ResidualCouplingBlock(nn.Module):
return x
class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
# print(x.shape,x_lengths.shape)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class TextEncoder(nn.Module):
def __init__(self,
out_channels,
@ -115,74 +83,10 @@ class TextEncoder(nn.Module):
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + z * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class F0Decoder(nn.Module):
def __init__(self,
out_channels,
@ -217,7 +121,7 @@ class F0Decoder(nn.Module):
def forward(self, x, norm_f0, x_mask, spk_emb=None):
x = torch.detach(x)
if spk_emb is not None:
if (spk_emb is not None):
x = x + self.cond(spk_emb)
x += self.f0_prenet(norm_f0)
x = self.prenet(x) * x_mask
@ -228,8 +132,8 @@ class F0Decoder(nn.Module):
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
@ -251,7 +155,10 @@ class SynthesizerTrn(nn.Module):
ssl_dim,
n_speakers,
sampling_rate=44100,
vol_embedding=False,
vocoder_name = "nsf-hifigan",
**kwargs):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -270,7 +177,10 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size
self.gin_channels = gin_channels
self.ssl_dim = ssl_dim
self.vol_embedding = vol_embedding
self.emb_g = nn.Embedding(n_speakers, gin_channels)
if vol_embedding:
self.emb_vol = nn.Linear(1, hidden_channels)
self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
@ -294,8 +204,18 @@ class SynthesizerTrn(nn.Module):
"upsample_kernel_sizes": upsample_kernel_sizes,
"gin_channels": gin_channels,
}
self.dec = Generator(h=hps)
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
if vocoder_name == "nsf-hifigan":
from vdecoder.hifigan.models import Generator
self.dec = Generator(h=hps)
elif vocoder_name == "nsf-snake-hifigan":
from vdecoder.hifiganwithsnake.models import Generator
self.dec = Generator(h=hps)
else:
print("[?] Unkown vocoder: use default(nsf-hifigan)")
from vdecoder.hifigan.models import Generator
self.dec = Generator(h=hps)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
self.f0_decoder = F0Decoder(
1,
@ -312,39 +232,39 @@ class SynthesizerTrn(nn.Module):
self.speaker_map = []
self.export_mix = False
def export_chara_mix(self, n_speakers_mix):
self.speaker_map = torch.zeros((n_speakers_mix, 1, 1, self.gin_channels))
for i in range(n_speakers_mix):
self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
def export_chara_mix(self, speakers_mix):
self.speaker_map = torch.zeros((len(speakers_mix), 1, 1, self.gin_channels))
i = 0
for key in speakers_mix.keys():
spkidx = speakers_mix[key]
self.speaker_map[i] = self.emb_g(torch.LongTensor([[spkidx]]))
i = i + 1
self.speaker_map = self.speaker_map.unsqueeze(0)
self.export_mix = True
def forward(self, c, f0, mel2ph, uv, noise=None, g=None, cluster_infer_ratio=0.1):
def forward(self, c, f0, mel2ph, uv, noise=None, g=None, vol = None):
decoder_inp = F.pad(c, [0, 0, 1, 0])
mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]])
c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H]
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
if self.export_mix: # [N, S] * [S, B, 1, H]
g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
g = g * self.speaker_map # [N, S, B, 1, H]
g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
else:
g = g.unsqueeze(0)
if g.dim() == 1:
g = g.unsqueeze(0)
g = self.emb_g(g).transpose(1, 2)
x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
if self.predict_f0:
lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
x_mask = torch.unsqueeze(torch.ones_like(f0), 1).to(c.dtype)
# vol proj
vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol!=None and self.vol_embedding else 0
x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise)
z = self.flow(z_p, c_mask, g=g, reverse=True)
o = self.dec(z * c_mask, g=g, f0=f0)
return o

View File

@ -112,17 +112,18 @@ class SnakeAlias(nn.Module):
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
down_kernel_size: int = 12,
C = None):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = SnakeBeta(channels, alpha_logscale=True)
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
self.upsample = UpSample1d(up_ratio, up_kernel_size, C)
self.downsample = DownSample1d(down_ratio, down_kernel_size, C)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
def forward(self, x, C=None):
x = self.upsample(x, C)
x = self.act(x)
x = self.downsample(x)

View File

@ -64,7 +64,8 @@ class LowPassFilter1d(nn.Module):
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
kernel_size: int = 12,
C=None):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
@ -81,15 +82,26 @@ class LowPassFilter1d(nn.Module):
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
self.conv1d_block = None
if C is not None:
self.conv1d_block = (nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False), 1)
self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1))
self.conv1d_block[0].requires_grad_(False)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.conv1d_block is None:
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1),
stride=self.stride, groups=C)
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1),
stride=self.stride, groups=C)
else:
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = self.conv1d_block[0](x)
return out

View File

@ -8,7 +8,7 @@ from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
def __init__(self, ratio=2, kernel_size=None, C=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
@ -20,28 +20,49 @@ class UpSample1d(nn.Module):
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
self.conv_transpose1d_block = None
if C is not None:
self.conv_transpose1d_block = (nn.ConvTranspose1d(C,
C,
kernel_size=self.kernel_size,
stride=self.stride,
groups=C,
bias=False
), 1)
self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone())
self.conv_transpose1d_block[0].requires_grad_(False)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
def forward(self, x, C=None):
if self.conv_transpose1d_block is None:
if C is None:
_, C, _ = x.shape
# print("snake.conv_t.in:",x.shape)
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
# print("snake.conv_t.out:",x.shape)
x = x[..., self.pad_left:-self.pad_right]
else:
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * self.conv_transpose1d_block[0](x)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
def __init__(self, ratio=2, kernel_size=None, C=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
kernel_size=self.kernel_size,
C=C)
def forward(self, x):
xx = self.lowpass(x)

View File

@ -33,7 +33,7 @@ def load_model(model_path, device='cuda'):
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
@ -58,15 +58,15 @@ class ResBlock1(torch.nn.Module):
self.num_layers = len(self.convs1) + len(self.convs2)
self.activations = nn.ModuleList([
SnakeAlias(channels) for _ in range(self.num_layers)
SnakeAlias(channels, C=C) for _ in range(self.num_layers)
])
def forward(self, x):
def forward(self, x, DIM=None):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = a1(x, DIM)
xt = c1(xt)
xt = a2(xt)
xt = a2(xt, DIM)
xt = c2(xt)
x = xt + x
return x
@ -79,7 +79,7 @@ class ResBlock1(torch.nn.Module):
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
@ -92,12 +92,12 @@ class ResBlock2(torch.nn.Module):
self.num_layers = len(self.convs)
self.activations = nn.ModuleList([
SnakeAlias(channels) for _ in range(self.num_layers)
SnakeAlias(channels, C=C) for _ in range(self.num_layers)
])
def forward(self, x):
def forward(self, x, DIM=None):
for c,a in zip(self.convs, self.activations):
xt = a(x)
xt = a(x, DIM)
xt = c(xt)
x = xt + x
return x
@ -315,14 +315,14 @@ class Generator(torch.nn.Module):
self.snakes = nn.ModuleList()
for i in range(len(self.ups)):
ch = h["upsample_initial_channel"] // (2 ** (i + 1))
self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i))))
self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i))
for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])):
self.resblocks.append(resblock(h, ch, k, d))
self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1)))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.snake_post = SnakeAlias(ch)
self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups))
self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)
def forward(self, x, f0, g=None):
@ -335,8 +335,9 @@ class Generator(torch.nn.Module):
x = x + self.cond(g)
# print(124,x.shape,har_source.shape)
for i in range(self.num_upsamples):
# print(f"self.snakes.{i}.pre:", x.shape)
x = self.snakes[i](x)
# print(3,x.shape)
# print(f"self.snakes.{i}.after:", x.shape)
x = self.ups[i](x)
x_source = self.noise_convs[i](har_source)
# print(4,x_source.shape,har_source.shape,x.shape)
@ -347,6 +348,7 @@ class Generator(torch.nn.Module):
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
# print(f"self.resblocks.{i}.after:", xs.shape)
x = xs / self.num_kernels
x = self.snake_post(x)
x = self.conv_post(x)