diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py index 9b64f9c..08bbda9 100644 --- a/vdecoder/hifiganwithsnake/models.py +++ b/vdecoder/hifiganwithsnake/models.py @@ -141,6 +141,7 @@ class SineGen(torch.nn.Module): self.sampling_rate = samp_rate self.voiced_threshold = voiced_threshold self.flag_for_pulse = flag_for_pulse + self.onnx = False def _f02uv(self, f0): # generate uv signal @@ -206,35 +207,82 @@ class SineGen(torch.nn.Module): sines = torch.cos(i_phase * 2 * np.pi) return sines - def forward(self, f0): + def forward(self, f0, upp=None): """ sine_tensor, uv = forward(f0) input F0: tensor(batchsize=1, length, dim=1) f0 for unvoiced steps should be 0 output sine_tensor: tensor(batchsize=1, length, dim) output uv: tensor(batchsize=1, length, 1) """ - with torch.no_grad(): - # fundamental component - fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ### the %1 means the n_har multiplication cannot be optimized in post-processing + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 ##### applying %1 here means the later cumsum can no longer be optimized + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - 
tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) - # generate sine waveforms - sine_waves = self._f02sine(fn) * self.sine_amp + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp - # generate uv signal - # uv = torch.ones(f0.shape) - # uv = uv * (f0 > self.voiced_threshold) - uv = self._f02uv(f0) + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) - # noise: for unvoiced should be similar to sine_amp - # std = self.sine_amp/3 -> max value ~ self.sine_amp - # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) - # first: set the unvoiced part to 0 by uv - # then: additive noise - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise class SourceModuleHnNSF(torch.nn.Module): @@ -270,7 +318,7 @@ class SourceModuleHnNSF(torch.nn.Module): self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) self.l_tanh = torch.nn.Tanh() - def forward(self, x): + def forward(self, x, upp=None): """ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) F0_sampled (batchsize, length, 1) @@ -278,7 +326,7 @@ class SourceModuleHnNSF(torch.nn.Module): noise_source (batchsize, length 1) """ # source for harmonic branch - sine_wavs, uv, _ = self.l_sin_gen(x) + sine_wavs, uv, _ = self.l_sin_gen(x, upp) sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) # source for noise branch, in the same shape as uv @@ -325,12 +373,19 @@ class Generator(torch.nn.Module): self.conv_post.apply(init_weights) self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups)) self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True def forward(self, x, f0, g=None): # print(1,x.shape,f0.shape,f0[:, None].shape) - f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t # print(2,f0.shape) - har_source, noi_source, uv = self.m_source(f0) + har_source, noi_source, uv = self.m_source(f0, self.upp) har_source = har_source.transpose(1, 2) x = self.conv_pre(x) x = x + self.cond(g)