Merge pull request #292 from svc-develop-team/4.1-Latest

4.1 latest
YuriHead 2023-07-13 00:23:44 +08:00 committed by GitHub
commit 1919391b31
18 changed files with 681 additions and 13 deletions

View File

@@ -145,6 +145,8 @@ While the pretrained model typically does not pose copyright concerns, it is ess
#### **Optional (Select as Required)**
##### NSF-HIFIGAN
If you are using the `NSF-HIFIGAN enhancer` or `shallow diffusion`, you will need to download the pre-trained NSF-HIFIGAN model.
- Pre-trained NSF-HIFIGAN Vocoder: [nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip)
@@ -158,6 +160,13 @@ unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
# URL: https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
```
##### RMVPE
If you are using the `rmvpe` F0 Predictor, you will need to download the pre-trained RMVPE model.
- Download the model: [rmvpe.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt)
- Place it under the `pretrain` directory
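A download sketch mirroring the `wget` commands above (the `resolve/main` direct-download form of the blob link is assumed):
```
wget -P pretrain/ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/rmvpe.pt
```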
## 📊 Dataset Preparation
Simply place the dataset in the `dataset_raw` directory with the following file structure:
@@ -278,13 +287,14 @@ nsf-snake-hifigan
python preprocess_hubert_f0.py --f0_predictor dio
```
f0_predictor has four options
f0_predictor has the following options
```
crepe
dio
pm
harvest
rmvpe
```
If the training set is too noisy, it is recommended to use `crepe` to handle f0
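For example, to run f0 preprocessing with the newly added predictor (assuming `rmvpe.pt` is already placed under `pretrain/`):
```
python preprocess_hubert_f0.py --f0_predictor rmvpe
```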
@@ -336,7 +346,7 @@ Required parameters:
Optional parameters: see the next section
- `-lg` | `--linear_gradient`: The cross-fade length between two audio slices, in seconds. Adjust this value if forced slicing produces discontinuous voice; otherwise keep the default of 0.
- `-f0p` | `--f0_predictor`: Select an F0 predictor; options are `crepe`, `pm`, `dio`, `harvest`; the default is `pm` (note: f0 mean pooling is enabled when using `crepe`)
- `-f0p` | `--f0_predictor`: Select an F0 predictor; options are `crepe`, `pm`, `dio`, `harvest`, `rmvpe`; the default is `pm` (note: f0 mean pooling is enabled when using `crepe`)
- `-a` | `--auto_predict_f0`: automatic pitch prediction; do not enable this when converting singing voices, as it can cause serious pitch issues.
- `-cm` | `--cluster_model_path`: Cluster model or feature retrieval index path; if left blank, it is automatically set to each scheme's default model path. If you have not trained clustering or feature retrieval, any value will do.
- `-cr` | `--cluster_infer_ratio`: The proportion of the clustering scheme or feature retrieval, ranging from 0 to 1. If you have not trained a clustering model or feature retrieval, keep the default of 0.
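For example, an inference call selecting the new predictor might look like the following (model, config, input file, and speaker name are illustrative):
```
python inference_main.py -m "logs/44k/G_30400.pth" -c "configs/config.json" -n "example.wav" -t 0 -s "speaker" -f0p rmvpe
```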
@@ -474,6 +484,7 @@ Note: For Hubert Onnx models, please use the models provided by MoeSS. Currently
|[aes35-000039](https://www.aes.org/e-lib/online/browse.cfm?elib=15165) | Dio (F0 Predictor) | Fast and reliable F0 estimation method based on the period extraction of vocal fold vibration of singing voice and speech | [mmorise/World/dio](https://github.com/mmorise/World/blob/master/src/dio.cpp) |
|[8461329](https://ieeexplore.ieee.org/document/8461329) | Crepe (F0 Predictor) | Crepe: A Convolutional Representation for Pitch Estimation | [maxrmorrison/torchcrepe](https://github.com/maxrmorrison/torchcrepe) |
|[DOI:10.1016/j.wocn.2018.07.001](https://doi.org/10.1016/j.wocn.2018.07.001) | Parselmouth (F0 Predictor) | Introducing Parselmouth: A Python interface to Praat | [YannickJadoul/Parselmouth](https://github.com/YannickJadoul/Parselmouth) |
|[2306.15412v2](https://arxiv.org/abs/2306.15412v2) | RMVPE (F0 Predictor) | RMVPE: A Robust Model for Vocal Pitch Estimation in Polyphonic Music | [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE) |
|[2010.05646](https://arxiv.org/abs/2010.05646) | HIFIGAN (Vocoder) | HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | [jik876/hifi-gan](https://github.com/jik876/hifi-gan) |
|[1810.11946](https://arxiv.org/abs/1810.11946.pdf) | NSF (Vocoder) | Neural source-filter-based waveform model for statistical parametric speech synthesis | [openvpi/DiffSinger/modules/nsf_hifigan](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) |
|[2006.08195](https://arxiv.org/abs/2006.08195) | Snake (Vocoder) | Neural Networks Fail to Learn Periodic Functions and How to Fix It | [EdwardDixon/snake](https://github.com/EdwardDixon/snake) |

View File

@@ -145,6 +145,8 @@ wget -P pretrain/ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/mai
#### **Optional (Select as Required)**
##### NSF-HIFIGAN
If you are using the `NSF-HIFIGAN enhancer` or `shallow diffusion`, you will need to download the pre-trained NSF-HIFIGAN model; if not, you can skip this.
+ Pre-trained NSF-HIFIGAN vocoder: [nsf_hifigan_20221211.zip](https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip)
@@ -158,6 +160,14 @@ unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
# URL: https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
```
##### RMVPE
If you are using the `rmvpe` F0 predictor, you will need to download the pre-trained RMVPE model.
+ Download the model: [rmvpe.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt)
+ Place it under the `pretrain` directory
## 📊 Dataset Preparation
Simply place the dataset in the `dataset_raw` directory with the following file structure:
@@ -280,13 +290,14 @@ nsf-snake-hifigan
python preprocess_hubert_f0.py --f0_predictor dio
```
f0_predictor has four options
f0_predictor has the following options
```
crepe
dio
pm
harvest
rmvpe
```
If the training set is too noisy, use `crepe` to process f0
@@ -338,7 +349,7 @@ python inference_main.py -m "logs/44k/G_30400.pth" -c "configs/config.json" -n "
Optional parameters: see the next section for details
+ `-lg` | `--linear_gradient`: The cross-fade length between two audio slices, in seconds. Adjust this value if forced slicing produces discontinuous voice; otherwise keep the default of 0.
+ `-f0p` | `--f0_predictor`: Select an F0 predictor; options are crepe, pm, dio, harvest; the default is pm (note: crepe applies mean filtering to the raw F0)
+ `-f0p` | `--f0_predictor`: Select an F0 predictor; options are crepe, pm, dio, harvest, rmvpe; the default is pm (note: crepe applies mean filtering to the raw F0)
+ `-a` | `--auto_predict_f0`: automatic pitch prediction for voice conversion; do not enable this when converting singing voices, as it will go badly out of tune
+ `-cm` | `--cluster_model_path`: Cluster model or feature retrieval index path; if left blank, it is automatically set to each scheme's default model path. If you have not trained clustering or feature retrieval, any value will do
+ `-cr` | `--cluster_infer_ratio`: The proportion of the clustering scheme or feature retrieval, ranging from 0 to 1. If you have not trained a clustering model or feature retrieval, keep the default of 0
@@ -474,6 +485,7 @@ python compress_model.py -c="configs/config.json" -i="logs/44k/G_30400.pth" -o="
|[aes35-000039](https://www.aes.org/e-lib/online/browse.cfm?elib=15165) | Dio (F0 Predictor) | Fast and reliable F0 estimation method based on the period extraction of vocal fold vibration of singing voice and speech | [mmorise/World/dio](https://github.com/mmorise/World/blob/master/src/dio.cpp) |
|[8461329](https://ieeexplore.ieee.org/document/8461329) | Crepe (F0 Predictor) | Crepe: A Convolutional Representation for Pitch Estimation | [maxrmorrison/torchcrepe](https://github.com/maxrmorrison/torchcrepe) |
|[DOI:10.1016/j.wocn.2018.07.001](https://doi.org/10.1016/j.wocn.2018.07.001) | Parselmouth (F0 Predictor) | Introducing Parselmouth: A Python interface to Praat | [YannickJadoul/Parselmouth](https://github.com/YannickJadoul/Parselmouth) |
|[2306.15412v2](https://arxiv.org/abs/2306.15412v2) | RMVPE (F0 Predictor) | RMVPE: A Robust Model for Vocal Pitch Estimation in Polyphonic Music | [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE) |
|[2010.05646](https://arxiv.org/abs/2010.05646) | HIFIGAN (Vocoder) | HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | [jik876/hifi-gan](https://github.com/jik876/hifi-gan) |
|[1810.11946](https://arxiv.org/abs/1810.11946.pdf) | NSF (Vocoder) | Neural source-filter-based waveform model for statistical parametric speech synthesis | [openvpi/DiffSinger/modules/nsf_hifigan](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) |
|[2006.08195](https://arxiv.org/abs/2006.08195) | Snake (Vocoder) | Neural Networks Fail to Learn Periodic Functions and How to Fix It | [EdwardDixon/snake](https://github.com/EdwardDixon/snake) |

View File

@@ -34,6 +34,7 @@ def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: s
new_dict_g = copyStateDict(state_dict_g)
keys = []
for k, v in new_dict_g['model'].items():
if "enc_q" in k: continue # noqa: E701
keys.append(k)
new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys}

View File

@@ -29,7 +29,7 @@ def main():
parser.add_argument('-cm', '--cluster_model_path', type=str, default="", help='聚类模型或特征检索索引路径,留空则自动设为各方案模型的默认路径,如果没有训练聚类或特征检索则随便填')
parser.add_argument('-cr', '--cluster_infer_ratio', type=float, default=0, help='聚类方案或特征检索占比范围0-1若没有训练聚类模型或特征检索则默认0即可')
parser.add_argument('-lg', '--linear_gradient', type=float, default=0, help='两段音频切片的交叉淡入长度如果强制切片后出现人声不连贯可调整该数值如果连贯建议采用默认值0单位为秒')
parser.add_argument('-f0p', '--f0_predictor', type=str, default="pm", help='选择F0预测器,可选择crepe,pm,dio,harvest,默认为pm(注意crepe为原F0使用均值滤波器)')
parser.add_argument('-f0p', '--f0_predictor', type=str, default="pm", help='选择F0预测器,可选择crepe,pm,dio,harvest,rmvpe,默认为pm(注意crepe为原F0使用均值滤波器)')
parser.add_argument('-eh', '--enhance', action='store_true', default=False, help='是否使用NSF_HIFIGAN增强器,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭')
parser.add_argument('-shd', '--shallow_diffusion', action='store_true', default=False, help='是否使用浅层扩散使用后可解决一部分电音问题默认关闭该选项打开时NSF_HIFIGAN增强器将会被禁止')
parser.add_argument('-usm', '--use_spk_mix', action='store_true', default=False, help='是否使用角色融合')

View File

@@ -0,0 +1,106 @@
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
from modules.F0Predictor.F0Predictor import F0Predictor
from .rmvpe import RMVPE
class RMVPEF0Predictor(F0Predictor):
def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
self.hop_length = hop_length
self.f0_min = f0_min
self.f0_max = f0_max
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
self.threshold = threshold
self.sampling_rate = sampling_rate
self.dtype = dtype
def repeat_expand(
self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
):
ndim = content.ndim
if content.ndim == 1:
content = content[None, None]
elif content.ndim == 2:
content = content[None]
assert content.ndim == 3
is_np = isinstance(content, np.ndarray)
if is_np:
content = torch.from_numpy(content)
results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
if is_np:
results = results.numpy()
        if ndim == 1:
            return results[0, 0]
        elif ndim == 2:
            return results[0]
        return results  # input was already 3-D: return unchanged shape
def post_process(self, x, sampling_rate, f0, pad_to):
if isinstance(f0, np.ndarray):
f0 = torch.from_numpy(f0).float().to(x.device)
if pad_to is None:
return f0
f0 = self.repeat_expand(f0, pad_to)
vuv_vector = torch.zeros_like(f0)
vuv_vector[f0 > 0.0] = 1.0
vuv_vector[f0 <= 0.0] = 0.0
        # drop zero-frequency (unvoiced) frames, then interpolate linearly
nzindex = torch.nonzero(f0).squeeze()
f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
if f0.shape[0] <= 0:
return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
if f0.shape[0] == 1:
return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
        # this could probably be rewritten with torch
f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
#vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
return f0,vuv_vector.cpu().numpy()
def compute_f0(self,wav,p_len=None):
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
if p_len is None:
p_len = x.shape[0]//self.hop_length
else:
assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
if torch.all(f0 == 0):
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
            return rtn  # compute_f0 returns only the F0 track; compute_f0_uv below also returns voicing
return self.post_process(x,self.sampling_rate,f0,p_len)[0]
def compute_f0_uv(self,wav,p_len=None):
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
if p_len is None:
p_len = x.shape[0]//self.hop_length
else:
assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
if torch.all(f0 == 0):
rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
return rtn,rtn
return self.post_process(x,self.sampling_rate,f0,p_len)
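

if __name__ == "__main__":
    import librosa

    # A minimal usage sketch, assuming `pretrain/rmvpe.pt` is in place; the input path is
    # hypothetical. Run from the repo root via `python -m modules.F0Predictor.RMVPEF0Predictor`.
    wav, sr = librosa.load("raw/example.wav", sr=44100)
    predictor = RMVPEF0Predictor(sampling_rate=sr, device="cpu")
    f0, uv = predictor.compute_f0_uv(wav)
    print(f0.shape, uv.shape)  # one F0 value and one voiced/unvoiced flag per 512-sample hop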

View File

@@ -0,0 +1,10 @@
from .constants import * # noqa: F403
from .inference import RMVPE # noqa: F401
from .model import E2E, E2E0 # noqa: F401
from .spec import MelSpectrogram # noqa: F401
from .utils import ( # noqa: F401
cycle,
summary,
to_local_average_cents,
to_viterbi_cents,
)

View File

@@ -0,0 +1,9 @@
SAMPLE_RATE = 16000
N_CLASS = 360
N_MELS = 128
MEL_FMIN = 30
MEL_FMAX = SAMPLE_RATE // 2
WINDOW_LENGTH = 1024
CONST = 1997.3794084376191
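

if __name__ == "__main__":
    # A worked sketch of the pitch grid these constants define: N_CLASS bins of 20 cents
    # each, offset by CONST cents above a 10 Hz reference, so bin i maps to
    # 10 * 2 ** ((20 * i + CONST) / 1200) Hz (the conversion `decode` in inference.py applies).
    for i in (0, 100, N_CLASS - 1):
        print(f"bin {i}: {10 * 2 ** ((20 * i + CONST) / 1200):.1f} Hz")
    # -> roughly 31.7 Hz, 100.6 Hz and 2005.6 Hz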

View File

@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
from .constants import N_MELS
class ConvBlockRes(nn.Module):
def __init__(self, in_channels, out_channels, momentum=0.01):
super(ConvBlockRes, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias=False),
nn.BatchNorm2d(out_channels, momentum=momentum),
nn.ReLU(),
nn.Conv2d(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias=False),
nn.BatchNorm2d(out_channels, momentum=momentum),
nn.ReLU(),
)
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
self.is_shortcut = True
else:
self.is_shortcut = False
def forward(self, x):
if self.is_shortcut:
return self.conv(x) + self.shortcut(x)
else:
return self.conv(x) + x
class ResEncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
super(ResEncoderBlock, self).__init__()
self.n_blocks = n_blocks
self.conv = nn.ModuleList()
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
for i in range(n_blocks - 1):
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
self.kernel_size = kernel_size
if self.kernel_size is not None:
self.pool = nn.AvgPool2d(kernel_size=kernel_size)
def forward(self, x):
for i in range(self.n_blocks):
x = self.conv[i](x)
if self.kernel_size is not None:
return x, self.pool(x)
else:
return x
class ResDecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
super(ResDecoderBlock, self).__init__()
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
self.n_blocks = n_blocks
self.conv1 = nn.Sequential(
nn.ConvTranspose2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=stride,
padding=(1, 1),
output_padding=out_padding,
bias=False),
nn.BatchNorm2d(out_channels, momentum=momentum),
nn.ReLU(),
)
self.conv2 = nn.ModuleList()
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
for i in range(n_blocks-1):
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
def forward(self, x, concat_tensor):
x = self.conv1(x)
x = torch.cat((x, concat_tensor), dim=1)
for i in range(self.n_blocks):
x = self.conv2[i](x)
return x
class Encoder(nn.Module):
def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
super(Encoder, self).__init__()
self.n_encoders = n_encoders
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
self.layers = nn.ModuleList()
self.latent_channels = []
for i in range(self.n_encoders):
self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
self.latent_channels.append([out_channels, in_size])
in_channels = out_channels
out_channels *= 2
in_size //= 2
self.out_size = in_size
self.out_channel = out_channels
def forward(self, x):
concat_tensors = []
x = self.bn(x)
for i in range(self.n_encoders):
_, x = self.layers[i](x)
concat_tensors.append(_)
return x, concat_tensors
class Intermediate(nn.Module):
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
super(Intermediate, self).__init__()
self.n_inters = n_inters
self.layers = nn.ModuleList()
self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
for i in range(self.n_inters-1):
self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
def forward(self, x):
for i in range(self.n_inters):
x = self.layers[i](x)
return x
class Decoder(nn.Module):
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
super(Decoder, self).__init__()
self.layers = nn.ModuleList()
self.n_decoders = n_decoders
for i in range(self.n_decoders):
out_channels = in_channels // 2
self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
in_channels = out_channels
def forward(self, x, concat_tensors):
for i in range(self.n_decoders):
x = self.layers[i](x, concat_tensors[-1-i])
return x
class TimbreFilter(nn.Module):
def __init__(self, latent_rep_channels):
super(TimbreFilter, self).__init__()
self.layers = nn.ModuleList()
for latent_rep in latent_rep_channels:
self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
def forward(self, x_tensors):
out_tensors = []
for i, layer in enumerate(self.layers):
out_tensors.append(layer(x_tensors[i]))
return out_tensors
class DeepUnet(nn.Module):
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
super(DeepUnet, self).__init__()
self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
self.tf = TimbreFilter(self.encoder.latent_channels)
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
def forward(self, x):
x, concat_tensors = self.encoder(x)
x = self.intermediate(x)
concat_tensors = self.tf(concat_tensors)
x = self.decoder(x, concat_tensors)
return x
class DeepUnet0(nn.Module):
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
super(DeepUnet0, self).__init__()
self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
self.tf = TimbreFilter(self.encoder.latent_channels)
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
def forward(self, x):
x, concat_tensors = self.encoder(x)
x = self.intermediate(x)
x = self.decoder(x, concat_tensors)
return x

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F
from torchaudio.transforms import Resample
from .constants import * # noqa: F403
from .model import E2E0
from .spec import MelSpectrogram
from .utils import to_local_average_cents, to_viterbi_cents
class RMVPE:
def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160):
self.resample_kernel = {}
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
model = E2E0(4, 1, (2, 2))
        ckpt = torch.load(model_path, map_location='cpu')  # load on CPU first; the model is moved to the target device below
model.load_state_dict(ckpt['model'])
model = model.to(dtype).to(self.device)
model.eval()
self.model = model
self.dtype = dtype
self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405
self.resample_kernel = {}
def mel2hidden(self, mel):
with torch.no_grad():
n_frames = mel.shape[-1]
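            # pad the frame axis to a multiple of 32 so the U-Net's five 2x poolings divide evenly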
mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect')
hidden = self.model(mel)
return hidden[:, :n_frames]
def decode(self, hidden, thred=0.03, use_viterbi=False):
if use_viterbi:
cents_pred = to_viterbi_cents(hidden, thred=thred)
else:
cents_pred = to_local_average_cents(hidden, thred=thred)
f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device)
return f0
def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False):
audio = audio.unsqueeze(0).to(self.dtype).to(self.device)
if sample_rate == 16000:
audio_res = audio
else:
key_str = str(sample_rate)
if key_str not in self.resample_kernel:
self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
audio_res = self.resample_kernel[key_str](audio)
mel_extractor = self.mel_extractor.to(self.device)
mel = mel_extractor(audio_res, center=True).to(self.dtype)
hidden = self.mel2hidden(mel)
f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi)
return f0
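

if __name__ == "__main__":
    import librosa

    # A minimal sketch, assuming `pretrain/rmvpe.pt` is in place; the input path is
    # hypothetical. Run from the repo root via `python -m modules.F0Predictor.rmvpe.inference`.
    wav, sr = librosa.load("raw/example.wav", sr=None)
    rmvpe = RMVPE("pretrain/rmvpe.pt", device="cpu")
    f0 = rmvpe.infer_from_audio(torch.from_numpy(wav), sample_rate=sr, use_viterbi=False)
    print(f0.shape)  # one pitch value per 160-sample hop at 16 kHz (0 marks unvoiced frames)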

View File

@@ -0,0 +1,67 @@
from torch import nn
from .constants import * # noqa: F403
from .deepunet import DeepUnet, DeepUnet0
from .seq import BiGRU
from .spec import MelSpectrogram
class E2E(nn.Module):
def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
en_out_channels=16):
super(E2E, self).__init__()
self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405
self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
if n_gru:
self.fc = nn.Sequential(
BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405
nn.Linear(512, N_CLASS), # noqa: F405
nn.Dropout(0.25),
nn.Sigmoid()
)
else:
self.fc = nn.Sequential(
nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405
nn.Dropout(0.25),
nn.Sigmoid()
)
def forward(self, x):
mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1)
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
# x = self.fc(x)
hidden_vec = 0
if len(self.fc) == 4:
for i in range(len(self.fc)):
x = self.fc[i](x)
if i == 0:
hidden_vec = x
return hidden_vec, x
class E2E0(nn.Module):
def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
en_out_channels=16):
super(E2E0, self).__init__()
self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
if n_gru:
self.fc = nn.Sequential(
BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405
nn.Linear(512, N_CLASS), # noqa: F405
nn.Dropout(0.25),
nn.Sigmoid()
)
else:
self.fc = nn.Sequential(
nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405
nn.Dropout(0.25),
nn.Sigmoid()
)
def forward(self, mel):
mel = mel.transpose(-1, -2).unsqueeze(1)
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
x = self.fc(x)
return x
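

if __name__ == "__main__":
    import torch

    # A shape-check sketch (an editor's assumption; run from the repo root via
    # `python -m modules.F0Predictor.rmvpe.model`): E2E0 with the arguments used in
    # inference.py, i.e. n_blocks=4, n_gru=1, kernel_size=(2, 2). The frame count must be
    # a multiple of 32 because the U-Net pools five times by a factor of 2.
    model = E2E0(4, 1, (2, 2))
    mel = torch.randn(1, N_MELS, 64)  # (batch, n_mels, frames)  # noqa: F405
    with torch.no_grad():
        salience = model(mel)
    print(salience.shape)  # torch.Size([1, 64, 360]): per-frame probabilities over the N_CLASS pitch bins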

View File

@@ -0,0 +1,20 @@
import torch.nn as nn
class BiGRU(nn.Module):
def __init__(self, input_features, hidden_features, num_layers):
super(BiGRU, self).__init__()
self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
def forward(self, x):
return self.gru(x)[0]
class BiLSTM(nn.Module):
def __init__(self, input_features, hidden_features, num_layers):
super(BiLSTM, self).__init__()
self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
def forward(self, x):
return self.lstm(x)[0]

View File

@@ -0,0 +1,67 @@
import numpy as np
import torch
import torch.nn.functional as F
from librosa.filters import mel
class MelSpectrogram(torch.nn.Module):
def __init__(
self,
n_mel_channels,
sampling_rate,
win_length,
hop_length,
n_fft=None,
mel_fmin=0,
mel_fmax=None,
clamp = 1e-5
):
super().__init__()
n_fft = win_length if n_fft is None else n_fft
self.hann_window = {}
mel_basis = mel(
sr=sampling_rate,
n_fft=n_fft,
n_mels=n_mel_channels,
fmin=mel_fmin,
fmax=mel_fmax,
htk=True)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
self.n_fft = win_length if n_fft is None else n_fft
self.hop_length = hop_length
self.win_length = win_length
self.sampling_rate = sampling_rate
self.n_mel_channels = n_mel_channels
self.clamp = clamp
def forward(self, audio, keyshift=0, speed=1, center=True):
factor = 2 ** (keyshift / 12)
n_fft_new = int(np.round(self.n_fft * factor))
win_length_new = int(np.round(self.win_length * factor))
hop_length_new = int(np.round(self.hop_length * speed))
keyshift_key = str(keyshift)+'_'+str(audio.device)
if keyshift_key not in self.hann_window:
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
fft = torch.stft(
audio,
n_fft=n_fft_new,
hop_length=hop_length_new,
win_length=win_length_new,
window=self.hann_window[keyshift_key],
center=center,
return_complex=True)
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
if keyshift != 0:
size = self.n_fft // 2 + 1
resize = magnitude.size(1)
if resize < size:
magnitude = F.pad(magnitude, (0, 0, 0, size-resize))
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
mel_output = torch.matmul(self.mel_basis, magnitude)
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
return log_mel_spec
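

if __name__ == "__main__":
    # A minimal sketch using the values RMVPE's front end passes in (see inference.py):
    # 128 mel bins, 16 kHz audio, a 1024-sample window, a 160-sample hop, fmin 30 Hz,
    # fmax 8 kHz. This file has no relative imports, so it runs standalone.
    extractor = MelSpectrogram(128, 16000, 1024, 160, None, 30, 8000)
    audio = torch.randn(1, 16000)  # one second of noise
    log_mel = extractor(audio)
    print(log_mel.shape)  # torch.Size([1, 128, 101]): (batch, n_mels, frames)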

View File

@@ -0,0 +1,107 @@
import sys
from functools import reduce
import librosa
import numpy as np
import torch
from torch.nn.modules.module import _addindent
from .constants import * # noqa: F403
def cycle(iterable):
while True:
for item in iterable:
yield item
def summary(model, file=sys.stdout):
def repr(model):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = model.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
total_params = 0
for key, module in model._modules.items():
mod_str, num_params = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
total_params += num_params
lines = extra_lines + child_lines
for name, p in model._parameters.items():
if hasattr(p, 'shape'):
total_params += reduce(lambda x, y: x * y, p.shape)
main_str = model._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
if file is sys.stdout:
main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
else:
main_str += ', {:,} params'.format(total_params)
return main_str, total_params
string, count = repr(model)
if file is not None:
if isinstance(file, str):
file = open(file, 'w')
print(string, file=file)
file.flush()
return count
def to_local_average_cents(salience, center=None, thred=0.05):
"""
find the weighted average cents near the argmax bin
"""
if not hasattr(to_local_average_cents, 'cents_mapping'):
# the bin number-to-cents mapping
to_local_average_cents.cents_mapping = (
20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405
if salience.ndim == 1:
if center is None:
center = int(torch.argmax(salience))
start = max(0, center - 4)
end = min(len(salience), center + 5)
salience = salience[start:end]
product_sum = torch.sum(
salience * to_local_average_cents.cents_mapping[start:end])
weight_sum = torch.sum(salience)
return product_sum / weight_sum if torch.max(salience) > thred else 0
if salience.ndim == 2:
return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in
range(salience.shape[0])]).to(salience.device)
raise Exception("label should be either 1d or 2d ndarray")
def to_viterbi_cents(salience, thred=0.05):
# Create viterbi transition matrix
if not hasattr(to_viterbi_cents, 'transition'):
        xx, yy = torch.meshgrid(torch.arange(N_CLASS), torch.arange(N_CLASS), indexing='ij')  # noqa: F405
        transition = torch.clamp(30 - torch.abs(xx - yy), min=0).float()
        transition = transition / transition.sum(dim=1, keepdim=True)
        to_viterbi_cents.transition = transition.numpy()  # librosa's viterbi expects a NumPy transition matrix
# Convert to probability
prob = salience.T
prob = prob / prob.sum(axis=0)
# Perform viterbi decoding
path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64)
return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in
range(len(path))]).to(salience.device)

View File

@@ -134,7 +134,7 @@ if __name__ == "__main__":
'--use_diff',action='store_true', help='Whether to use the diffusion model'
)
parser.add_argument(
'--f0_predictor', type=str, default="dio", help='Select F0 predictor, can select crepe,pm,dio,harvest, default pm(note: crepe is original F0 using mean filter)'
'--f0_predictor', type=str, default="dio", help='Select F0 predictor, can select crepe,pm,dio,harvest,rmvpe, default pm(note: crepe is original F0 using mean filter)'
)
parser.add_argument(
'--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores'

View File

@@ -317,9 +317,12 @@
"#@markdown\n",
"%cd /content/so-vits-svc\n",
"\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\"]\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
"use_diff = True #@param {type:\"boolean\"}\n",
"\n",
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
"\n",
"diff_param = \"\"\n",
"if use_diff:\n",
" diff_param = \"--use_diff\"\n",
@@ -602,7 +605,7 @@
"if auto_predict_f0:\n",
" apf = \" -a \"\n",
"\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\"]\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
"\n",
"enhance = False #@param {type:\"boolean\"}\n",
"ehc = \"\"\n",
@@ -618,6 +621,10 @@
"url, output = get_speech_encoder(config_path)\n",
"\n",
"import os\n",
"\n",
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
"\n",
"if not os.path.exists(output):\n",
" !curl -L {url} -o {output}\n",
"\n",

View File

@@ -140,7 +140,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if writers is not None:
writer, writer_eval = writers
half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16
half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16
# train_loader.batch_sampler.set_epoch(epoch)
global global_step

View File

@@ -96,7 +96,10 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
elif f0_predictor == "dio":
from modules.F0Predictor.DioF0Predictor import DioF0Predictor
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
elif f0_predictor == "rmvpe":
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
else:
raise Exception("Unknown f0 predictor")
return f0_predictor_object
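
# A call sketch (keyword names taken from the rmvpe branch above; values are illustrative):
#   f0p = get_f0_predictor("rmvpe", hop_length=512, sampling_rate=44100,
#                          device=None, threshold=0.05)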

View File

@@ -4,6 +4,7 @@ import logging
import os
import re
import subprocess
import sys
import time
import traceback
from itertools import chain
@@ -224,9 +225,9 @@ def vc_fn2(_text, _lang, _gender, _rate, _volume, sid, output_format, vc_transfo
_volume = f"+{int(_volume*100)}%" if _volume >= 0 else f"{int(_volume*100)}%"
if _lang == "Auto":
_gender = "Male" if _gender == "" else "Female"
subprocess.run([r"python", "edgetts/tts.py", _text, _lang, _rate, _volume, _gender])
subprocess.run([sys.executable, "edgetts/tts.py", _text, _lang, _rate, _volume, _gender])
else:
subprocess.run([r"python", "edgetts/tts.py", _text, _lang, _rate, _volume])
subprocess.run([sys.executable, "edgetts/tts.py", _text, _lang, _rate, _volume])
target_sr = 44100
y, sr = librosa.load("tts.wav")
resampled_y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
@@ -322,7 +323,7 @@
<font size=2> 推理设置</font>
""")
auto_f0 = gr.Checkbox(label="自动f0预测配合聚类模型f0预测效果更好,会导致变调功能失效(仅限转换语音,歌声勾选此项会究极跑调)", value=False)
f0_predictor = gr.Dropdown(label="选择F0预测器,可选择crepe,pm,dio,harvest,默认为pm(注意crepe为原F0使用均值滤波器)", choices=["pm","dio","harvest","crepe"], value="pm")
f0_predictor = gr.Dropdown(label="选择F0预测器,可选择crepe,pm,dio,harvest,rmvpe,默认为pm(注意crepe为原F0使用均值滤波器)", choices=["pm","dio","harvest","crepe","rmvpe"], value="pm")
vc_transform = gr.Number(label="变调整数可以正负半音数量升高八度就是12", value=0)
cluster_ratio = gr.Number(label="聚类模型/特征检索混合比例0-1之间0即不启用聚类/特征检索。使用聚类/特征检索能提升音色相似度但会导致咬字下降如果使用建议0.5左右)", value=0)
slice_db = gr.Number(label="切片阈值", value=-40)