Snake 的相同BUG修复

This commit is contained in:
白叶 藤原 2023-07-17 17:57:03 +08:00
parent 90c9ccc6a8
commit 85024991e2
1 changed files with 78 additions and 23 deletions

View File

@ -141,6 +141,7 @@ class SineGen(torch.nn.Module):
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.onnx = False
def _f02uv(self, f0):
# generate uv signal
@ -206,35 +207,82 @@ class SineGen(torch.nn.Module):
sines = torch.cos(i_phase * 2 * np.pi)
return sines
def forward(self, f0):
def forward(self, f0, upp=None):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
with torch.no_grad():
# fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
if self.onnx:
with torch.no_grad():
f0 = f0[:, None].transpose(1, 2)
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in np.arange(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor=upp,
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values = F.interpolate(
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
).transpose(
2, 1
) #######
tmp_over_one %= 1
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0)
uv = F.interpolate(
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
).transpose(2, 1)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
else:
with torch.no_grad():
# fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
# generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp
# generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp
# generate uv signal
# uv = torch.ones(f0.shape)
# uv = uv * (f0 > self.voiced_threshold)
uv = self._f02uv(f0)
# generate uv signal
# uv = torch.ones(f0.shape)
# uv = uv * (f0 > self.voiced_threshold)
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
@ -270,7 +318,7 @@ class SourceModuleHnNSF(torch.nn.Module):
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
def forward(self, x, upp=None):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
@ -278,7 +326,7 @@ class SourceModuleHnNSF(torch.nn.Module):
noise_source (batchsize, length 1)
"""
# source for harmonic branch
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype)))
# source for noise branch, in the same shape as uv
@ -325,12 +373,19 @@ class Generator(torch.nn.Module):
self.conv_post.apply(init_weights)
self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups))
self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)
self.upp = np.prod(h["upsample_rates"])
self.onnx = False
def OnnxExport(self):
self.onnx = True
self.m_source.l_sin_gen.onnx = True
def forward(self, x, f0, g=None):
# print(1,x.shape,f0.shape,f0[:, None].shape)
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
if not self.onnx:
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
# print(2,f0.shape)
har_source, noi_source, uv = self.m_source(f0)
har_source, noi_source, uv = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)
x = self.conv_pre(x)
x = x + self.cond(g)