Updata VITS2 part (Transformer Flow)

This commit is contained in:
ylzz1997 2023-08-02 01:09:46 +08:00
parent 39b0befef5
commit fc8336fffd
5 changed files with 122 additions and 6 deletions

View File

@ -54,6 +54,7 @@
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16, 4, 4, 4],
"n_layers_q": 3,
"n_layers_trans_flow": 3,
"n_flow_layer": 4,
"use_spectral_norm": false,
"gin_channels": 768,
@ -65,7 +66,8 @@
"vol_embedding":false,
"use_depthwise_conv":false,
"flow_share_parameter": false,
"use_automatic_f0_prediction": true
"use_automatic_f0_prediction": true,
"use_transformer_flow": false
},
"spk": {
"nyaru": 0,

View File

@ -54,6 +54,7 @@
"upsample_initial_channel": 400,
"upsample_kernel_sizes": [16,16, 4, 4, 4],
"n_layers_q": 3,
"n_layers_trans_flow": 3,
"n_flow_layer": 4,
"use_spectral_norm": false,
"gin_channels": 768,
@ -65,7 +66,8 @@
"vol_embedding":false,
"use_depthwise_conv":true,
"flow_share_parameter": true,
"use_automatic_f0_prediction": true
"use_automatic_f0_prediction": true,
"use_transformer_flow": false
},
"spk": {
"nyaru": 0,

View File

@ -51,6 +51,46 @@ class ResidualCouplingBlock(nn.Module):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class TransformerCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
n_flows=4,
gin_channels=0,
share_parameter=False
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
self.wn = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = self.gin_channels) if share_parameter else None
for i in range(n_flows):
self.flows.append(
modules.TransformerCouplingLayer(channels, hidden_channels, kernel_size, n_layers, n_heads, p_dropout, filter_channels, mean_only=True, wn_sharing_parameter=self.wn, gin_channels = self.gin_channels))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class Encoder(nn.Module):
def __init__(self,
@ -327,6 +367,8 @@ class SynthesizerTrn(nn.Module):
use_automatic_f0_prediction = True,
flow_share_parameter = False,
n_flow_layer = 4,
n_layers_trans_flow = 3,
use_transformer_flow = False,
**kwargs):
super().__init__()
@ -351,6 +393,7 @@ class SynthesizerTrn(nn.Module):
self.emb_g = nn.Embedding(n_speakers, gin_channels)
self.use_depthwise_conv = use_depthwise_conv
self.use_automatic_f0_prediction = use_automatic_f0_prediction
self.n_layers_trans_flow = n_layers_trans_flow
if vol_embedding:
self.emb_vol = nn.Linear(1, hidden_channels)
@ -392,7 +435,10 @@ class SynthesizerTrn(nn.Module):
self.dec = Generator(h=hps)
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
if use_transformer_flow:
self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
else:
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
if self.use_automatic_f0_prediction:
self.f0_decoder = F0Decoder(
1,

View File

@ -5,12 +5,13 @@ from torch import nn
from torch.nn import functional as F
import modules.commons as commons
from modules.DSConv import weight_norm_modules
from modules.modules import LayerNorm
class FFT(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
proximal_bias=False, proximal_init=True, **kwargs):
proximal_bias=False, proximal_init=True, isflow = False, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@ -20,7 +21,11 @@ class FFT(nn.Module):
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name='weight')
self.gin_channels = kwargs["gin_channels"]
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
@ -35,14 +40,25 @@ class FFT(nn.Module):
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
self.norm_layers_1.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
def forward(self, x, x_mask, g = None):
"""
x: decoder input
h: encoder output
"""
if g is not None:
g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
x = x * x_mask
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
x = commons.fused_add_tanh_sigmoid_multiply(
x,
g_l,
torch.IntTensor([self.hidden_channels]))
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)

View File

@ -2,6 +2,7 @@ import torch
from torch import nn
from torch.nn import functional as F
import modules.attentions as attentions
import modules.commons as commons
from modules.commons import get_padding, init_weights
from modules.DSConv import (
@ -304,3 +305,52 @@ class ResidualCouplingLayer(nn.Module):
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class TransformerCouplingLayer(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
n_layers,
n_heads,
p_dropout=0,
filter_channels=0,
mean_only=False,
wn_sharing_parameter=None,
gin_channels = 0
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels]*2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1,2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x