Compare commits

...

10 Commits

Author                        SHA1        Message                                                                      Date
YuriHead                      64591bd664  Merge pull request #372 from svc-develop-team/4.1-Latest (Synchronization)   2023-08-03 18:01:44 +08:00
YuriHead                      fec31b364a  Merge pull request #371 from svc-develop-team/4.1-Stable (ruff fix)          2023-08-03 17:55:38 +08:00
ylzz1997                      947a5ccc00  ruff fix                                                                     2023-08-03 17:54:15 +08:00
YuriHead                      1cb33c3e17  Merge pull request #370 from svc-develop-team/4.1-Stable (To Latest)         2023-08-03 17:53:42 +08:00
YuriHead                      6c02ae44c3  Merge branch '4.1-Latest' into 4.1-Stable                                    2023-08-03 17:49:52 +08:00
ylzz1997                      5201200848  ruff fix                                                                     2023-08-03 17:45:12 +08:00
Ναρουσέ·μ·γιουμεμί·Χινακάννα  7c4d3a2036  Transformer Flow Onnx Export                                                 2023-08-02 16:22:02 +08:00
Ναρουσέ·μ·γιουμεμί·Χινακάννα  d71ffbf1b7  Add files via upload                                                         2023-08-02 16:19:33 +08:00
ylzz1997                      fc8336fffd  Updata VITS2 part (Transformer Flow)                                         2023-08-02 01:09:46 +08:00
YuriHead                      39b0befef5  Merge pull request #365 from svc-develop-team/4.1-Stable (To Latest)         2023-08-02 00:43:07 +08:00
9 changed files with 130 additions and 10 deletions

View File

@@ -8,6 +8,8 @@
 [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/svc-develop-team/so-vits-svc/blob/4.1-Stable/sovits4_for_colab.ipynb)
 [![Licence](https://img.shields.io/badge/LICENSE-AGPL3.0-green.svg?style=for-the-badge)](https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/LICENSE)
+
+This round of limited-time updates is coming to an end; the repository will be archived soon. Please take note.
 </div>

 > ✨ A studio that contains visible f0 editor, speaker mix timeline editor and other features (Where the Onnx models are used) : [MoeVoiceStudio](https://github.com/NaruseMioShirakana/MoeVoiceStudio)

View File

@@ -8,6 +8,8 @@
 [![Open in Google Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/svc-develop-team/so-vits-svc/blob/4.1-Stable/sovits4_for_colab.ipynb)
 [![LICENSE](https://img.shields.io/badge/LICENSE-AGPL3.0-green.svg?style=for-the-badge)](https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/LICENSE)
+
+This round of limited-time updates is coming to an end; the repository will be archived soon. Please take note.
 </div>

View File

@@ -54,6 +54,7 @@
     "upsample_initial_channel": 512,
     "upsample_kernel_sizes": [16,16, 4, 4, 4],
     "n_layers_q": 3,
+    "n_layers_trans_flow": 3,
     "n_flow_layer": 4,
     "use_spectral_norm": false,
     "gin_channels": 768,
@@ -65,7 +66,8 @@
     "vol_embedding":false,
     "use_depthwise_conv":false,
     "flow_share_parameter": false,
-    "use_automatic_f0_prediction": true
+    "use_automatic_f0_prediction": true,
+    "use_transformer_flow": false
   },
   "spk": {
     "nyaru": 0,

View File

@@ -54,6 +54,7 @@
     "upsample_initial_channel": 400,
     "upsample_kernel_sizes": [16,16, 4, 4, 4],
     "n_layers_q": 3,
+    "n_layers_trans_flow": 3,
     "n_flow_layer": 4,
     "use_spectral_norm": false,
     "gin_channels": 768,
@@ -65,7 +66,8 @@
     "vol_embedding":false,
     "use_depthwise_conv":true,
     "flow_share_parameter": true,
-    "use_automatic_f0_prediction": true
+    "use_automatic_f0_prediction": true,
+    "use_transformer_flow": false
  },
   "spk": {
     "nyaru": 0,

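Both config files gain the same two model keys: n_layers_trans_flow sets the depth of the transformer used inside each coupling layer, and use_transformer_flow toggles the VITS2-style flow (it defaults to false, so existing configs keep the residual coupling flow). A minimal sketch of how these keys surface on the hparams object, assuming the utils.get_hparams_from_file loader the project already uses; the config path is illustrative:

import utils

# Hedged sketch: point this at your own generated config rather than
# the template shipped with the repository.
hps = utils.get_hparams_from_file("configs/config.json")
print(hps.model.n_layers_trans_flow)   # 3, per the diffs above
print(hps.model.use_transformer_flow)  # False: ResidualCouplingBlock is used
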
View File

@@ -51,6 +51,46 @@ class ResidualCouplingBlock(nn.Module):
                 x = flow(x, x_mask, g=g, reverse=reverse)
         return x

+
+class TransformerCouplingBlock(nn.Module):
+    def __init__(self,
+                 channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout,
+                 n_flows=4,
+                 gin_channels=0,
+                 share_parameter=False
+                 ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+
+        self.wn = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow=True, gin_channels=self.gin_channels) if share_parameter else None
+
+        for i in range(n_flows):
+            self.flows.append(
+                modules.TransformerCouplingLayer(channels, hidden_channels, kernel_size, n_layers, n_heads, p_dropout, filter_channels, mean_only=True, wn_sharing_parameter=self.wn, gin_channels=self.gin_channels))
+            self.flows.append(modules.Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+

 class Encoder(nn.Module):
     def __init__(self,
@@ -327,6 +367,8 @@ class SynthesizerTrn(nn.Module):
                  use_automatic_f0_prediction = True,
                  flow_share_parameter = False,
                  n_flow_layer = 4,
+                 n_layers_trans_flow = 3,
+                 use_transformer_flow = False,
                  **kwargs):

         super().__init__()
@@ -351,6 +393,7 @@ class SynthesizerTrn(nn.Module):
         self.emb_g = nn.Embedding(n_speakers, gin_channels)
         self.use_depthwise_conv = use_depthwise_conv
         self.use_automatic_f0_prediction = use_automatic_f0_prediction
+        self.n_layers_trans_flow = n_layers_trans_flow
         if vol_embedding:
             self.emb_vol = nn.Linear(1, hidden_channels)
@@ -392,7 +435,10 @@ class SynthesizerTrn(nn.Module):
         self.dec = Generator(h=hps)
         self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
-        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter=flow_share_parameter)
+        if use_transformer_flow:
+            self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter=flow_share_parameter)
+        else:
+            self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter=flow_share_parameter)
         if self.use_automatic_f0_prediction:
             self.f0_decoder = F0Decoder(
                 1,

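The new TransformerCouplingBlock mirrors the interface of ResidualCouplingBlock (one forward method with a reverse flag), which is why SynthesizerTrn can select either implementation behind the same self.flow attribute. A minimal smoke test of the block, assuming illustrative sizes, the (batch, channels, time) tensor convention used throughout models.py, and that it is run from the repository root where models.py lives:

import torch
from models import TransformerCouplingBlock

# Sizes are made up for the test; eval() disables dropout so the
# forward and reverse passes are exact inverses of each other.
block = TransformerCouplingBlock(
    channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=3, kernel_size=5, p_dropout=0.1,
    n_flows=4, gin_channels=768).eval()

x = torch.randn(2, 192, 100)      # latent frames to transform
x_mask = torch.ones(2, 1, 100)    # no padding in this toy batch
g = torch.randn(2, 768, 1)        # per-utterance speaker embedding

with torch.no_grad():
    z = block(x, x_mask, g=g)                    # forward direction
    x_rec = block(z, x_mask, g=g, reverse=True)  # closed-form inverse
print(torch.allclose(x, x_rec, atol=1e-4))       # True: the flow round-trips
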
View File

@@ -5,12 +5,13 @@ from torch import nn
 from torch.nn import functional as F

 import modules.commons as commons
+from modules.DSConv import weight_norm_modules
 from modules.modules import LayerNorm


 class FFT(nn.Module):
     def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
-                 proximal_bias=False, proximal_init=True, **kwargs):
+                 proximal_bias=False, proximal_init=True, isflow=False, **kwargs):
         super().__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -20,7 +21,11 @@ class FFT(nn.Module):
         self.p_dropout = p_dropout
         self.proximal_bias = proximal_bias
         self.proximal_init = proximal_init
+        if isflow:
+            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1)
+            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
+            self.cond_layer = weight_norm_modules(cond_layer, name='weight')
+            self.gin_channels = kwargs["gin_channels"]
         self.drop = nn.Dropout(p_dropout)
         self.self_attn_layers = nn.ModuleList()
         self.norm_layers_0 = nn.ModuleList()
@@ -35,14 +40,25 @@ class FFT(nn.Module):
                 FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
             self.norm_layers_1.append(LayerNorm(hidden_channels))

-    def forward(self, x, x_mask):
+    def forward(self, x, x_mask, g=None):
         """
         x: decoder input
         h: encoder output
         """
+        if g is not None:
+            g = self.cond_layer(g)
+
         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
         x = x * x_mask
         for i in range(self.n_layers):
+            if g is not None:
+                x = self.cond_pre(x)
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+                x = commons.fused_add_tanh_sigmoid_multiply(
+                    x,
+                    g_l,
+                    torch.IntTensor([self.hidden_channels]))
             y = self.self_attn_layers[i](x, x, self_attn_mask)
             y = self.drop(y)
             x = self.norm_layers_0[i](x + y)

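The conditioning path added here doubles the channel count with cond_pre, slices the layer-specific chunk g_l out of the projected speaker embedding, and gates the two halves back down to hidden_channels with commons.fused_add_tanh_sigmoid_multiply. For reference, this is the WaveNet-style gated-activation helper as it appears in the upstream VITS code this repository vendors (reproduced here as a sketch, not the authoritative copy):

import torch

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Sum the input with the conditioning signal, then gate:
    # tanh on the first half of the channels, sigmoid on the second
    # half, multiplied together, halving the channel count.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act

This is why cond_pre maps hidden_channels to 2*hidden_channels and cond_layer produces 2*hidden_channels*n_layers channels: each layer consumes one 2*hidden_channels slice of g, and the gate halves it back to hidden_channels before self-attention.
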
View File

@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+import modules.attentions as attentions
 import modules.commons as commons
 from modules.commons import get_padding, init_weights
 from modules.DSConv import (
@@ -304,3 +305,52 @@ class ResidualCouplingLayer(nn.Module):
             x1 = (x1 - m) * torch.exp(-logs) * x_mask
             x = torch.cat([x0, x1], 1)
             return x
+
+
+class TransformerCouplingLayer(nn.Module):
+    def __init__(self,
+                 channels,
+                 hidden_channels,
+                 kernel_size,
+                 n_layers,
+                 n_heads,
+                 p_dropout=0,
+                 filter_channels=0,
+                 mean_only=False,
+                 wn_sharing_parameter=None,
+                 gin_channels=0
+                 ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow=True, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x

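TransformerCouplingLayer is a standard affine coupling step: x0 passes through unchanged and parameterizes a per-channel affine transform of x1, so the inverse exists in closed form and the log-determinant is simply the sum of logs. (With mean_only=True, as TransformerCouplingBlock uses it, logs is zero and the step reduces to a learned shift.) A toy check of the identity with made-up numbers:

import torch

# stats that self.post would predict from x0 (random stand-ins here)
m, logs = torch.randn(4), torch.randn(4)
x1 = torch.randn(4)

y1 = m + x1 * torch.exp(logs)          # forward: affine-transform x1
x1_rec = (y1 - m) * torch.exp(-logs)   # reverse: closed-form inverse
print(torch.allclose(x1, x1_rec))      # True: the coupling is invertible
# logdet = logs.sum(): the Jacobian d(y1)/d(x1) is diag(exp(logs)).

The zero-initialized post convolution makes every coupling layer start as the identity map (m = 0, logs = 0), a common trick for stable early training of normalizing flows.
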
View File

@@ -1,8 +1,11 @@
+import argparse
 import json
 import torch
 import utils
 from onnxexport.model_onnx_speaker_mix import SynthesizerTrn
-import argparse

 parser = argparse.ArgumentParser(description='SoVitsSvc OnnxExport')

 def OnnxExport(path=None):

View File

@@ -1,14 +1,11 @@
 import torch
 from torch import nn
-from torch.nn import Conv1d, Conv2d
 from torch.nn import functional as F
-from torch.nn.utils import spectral_norm, weight_norm

 import modules.attentions as attentions
 import modules.commons as commons
 import modules.modules as modules
 import utils
-from modules.commons import get_padding
 from utils import f0_to_coarse