Update: shallow-diffusion-only training

This commit is contained in:
ylzz1997 2023-06-16 01:08:20 +08:00
parent 3cb2013057
commit 2def595e02
9 changed files with 102 additions and 40 deletions

View File

@ -80,7 +80,7 @@ After conducting tests, we believe that the project runs stably on `Python 3.8.9
- Place it under the `pretrain` directory
Or download the following ContentVec, which is only 199MB in size but has the same effect:
- contentvec [hubert_base.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt)
- ContentVec: [hubert_base.pt](https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt)
- Change the file name to `checkpoint_best_legacy_500.pt` and place it in the `pretrain` directory
```shell
@ -90,7 +90,7 @@ wget -P pretrain/ http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best
```
##### **2. If hubertsoft is used as the speech encoder**
- soft vc hubert[hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt)
- soft vc hubert: [hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt)
- Place it under the `pretrain` directory
##### **3. If whisper-ppg is used as the speech encoder**
@ -155,7 +155,7 @@ If you are using the `NSF-HIFIGAN enhancer` or `shallow diffusion`, you will nee
wget -P pretrain/ https://github.com/openvpi/vocoders/releases/download/nsf-hifigan-v1/nsf_hifigan_20221211.zip
\unzip -od pretrain/nsf_hifigan pretrain/nsf_hifigan_20221211.zip
# Alternatively, you can manually download and place it in the pretrain/nsf_hifigan directory
# URLhttps://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
# URL: https://github.com/openvpi/vocoders/releases/tag/nsf-hifigan-v1
```
## 📊 Dataset Preparation
@ -245,13 +245,27 @@ After enabling loudness embedding, the trained model will match the loudness of
#### You can modify some parameters in the generated config.json and diffusion.yaml
##### config.json
* `keep_ckpts`: Keep the last `keep_ckpts` models during training. Setting it to `0` keeps them all. The default is `3`.
* `all_in_mem`, `cache_all_data`: Load the entire dataset into RAM. Enable it when the disk IO on some platforms is too slow and the system memory is **much larger** than your dataset.
* `all_in_mem`: Load the entire dataset into RAM. Enable it when the disk IO on some platforms is too slow and the system memory is **much larger** than your dataset.
* `batch_size`: The amount of data loaded to the GPU per training step; adjust it to a size that fits within your GPU memory.
* `vocoder_name` : Select a vocoder. The default is `nsf-hifigan`.
* `vocoder_name`: Select a vocoder. The default is `nsf-hifigan`.
##### diffusion.yaml
* `cache_all_data`: Load the entire dataset into RAM. Enable it when the disk IO on some platforms is too slow and the system memory is **much larger** than your dataset.
* `duration`: The duration of the audio slices used during training; it can be adjusted according to GPU memory size. **Note: this value must be less than the shortest audio duration in the training set!**
* `batch_size`: The amount of data loaded to the GPU per training step; adjust it to a size that fits within your GPU memory.
* `timesteps`: The total number of steps in the diffusion model. The default is 1000.
* `k_step_max`: Training can be limited to the first `k_step_max` diffusion steps to save training time. Note that this value must be less than `timesteps`; `0` trains the entire diffusion model. **Note: a model that does not train the entire diffusion process cannot use `only_diffusion` inference!** (See the sketch after this list.)
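A minimal sketch of how `timesteps` and `k_step_max` interact, mirroring the clamping logic this commit adds to `GaussianDiffusion` and `Unit2Mel`; the helper name `resolve_k_step` is illustrative and not part of the repository:
```python
# Minimal illustrative sketch (not repository code): how `timesteps` and
# `k_step_max` determine how many diffusion steps are actually trained,
# mirroring the clamping added to GaussianDiffusion/Unit2Mel in this commit.
def resolve_k_step(timesteps: int = 1000, k_step_max: int = 0) -> int:
    """Return the number of diffusion steps that will be trained."""
    if k_step_max is not None and 0 < k_step_max < timesteps:
        return k_step_max   # shallow-diffusion-only: train just the first k_step_max steps
    return timesteps        # 0 (or an out-of-range value) falls back to the full model


print(resolve_k_step(1000, 0))    # 1000 -> full diffusion model, diffusion-only inference allowed
print(resolve_k_step(1000, 100))  # 100  -> shallow-diffusion-only model
```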
##### **List of Vocoders**
@ -289,6 +303,12 @@ After completing the above steps, the dataset directory will contain the preproc
## 🏋️‍♀️ Training
### Sovits Model
```shell
python train.py -c configs/config.json -m 44k
```
### Diffusion Model (optional)
If shallow diffusion is needed, the diffusion model must be trained. Train it as follows:
@ -297,12 +317,6 @@ If the shallow diffusion function is needed, the diffusion model needs to be tra
python train_diff.py -c configs/diffusion.yaml
```
### Sovits Model
```shell
python train.py -c configs/config.json -m 44k
```
After training, the model files are saved under `logs/44k`, and the diffusion model is stored under `logs/44k/diffusion`.
## 🤖 Inference
@ -331,15 +345,15 @@ Optional parameters: see the next section
- `-eh` | `--enhance`: Whether to use the NSF_HIFIGAN enhancer. This option can improve sound quality for models trained on small datasets, but it has a negative effect on well-trained models, so it is off by default.
- `-shd` | `--shallow_diffusion`: Whether to use shallow diffusion, which can fix some electronic-sounding artifacts. Off by default. When this option is enabled, the NSF_HIFIGAN enhancer is disabled.
- `-usm` | `--use_spk_mix`: Whether to use dynamic voice mixing (speaker fusion).
- `-lea` | `--loudness_envelope_adjustment`The input source loudness envelope replaces the output loudness envelope fusion ratio. The closer to 1, the more the output loudness envelope is used
- `-fr` | `--feature_retrieval`Whether to use feature retrieval? If clustering model is used, it will be disabled, and cm and cr parameters will become the index path and mixing ratio of feature retrieval
- `-lea` | `--loudness_envelope_adjustment`: The fusion ratio for replacing the output loudness envelope with the input source's loudness envelope. The closer to 1, the more the output loudness envelope is used.
- `-fr` | `--feature_retrieval`: Whether to use feature retrieval. If enabled, the clustering model is disabled, and the `cm` and `cr` parameters become the feature-retrieval index path and mixing ratio.
Shallow diffusion settings:
- `-dm` | `--diffusion_model_path`: Diffusion model path
- `-dc` | `--diffusion_config_path`: Diffusion model profile path
- `-ks` | `--k_step`: The larger the number of diffusion steps, the closer the result is to that of the diffusion model. The default is 100.
- `-od` | `--only_diffusion`: Diffusion-only mode; the sovits model is not loaded, and inference uses only the diffusion model (see the sketch after this list).
- `-se` | `--second_encoding`Secondary encoding, secondary coding of the original audio before shallow diffusion, mystery options, sometimes good, sometimes bad
- `-se` | `--second_encoding`: Secondary encoding, which re-encodes the original audio before shallow diffusion. A mystery option: sometimes it helps, sometimes it hurts.
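A minimal sketch of how `-ks`/`--k_step` and `-od`/`--only_diffusion` behave with a model trained under a limited `k_step_max`, mirroring the validation this commit adds to `Unit2Mel.forward` below; the helper function and its arguments are illustrative, not the actual CLI code:
```python
# Minimal illustrative sketch (hypothetical helper, not the actual CLI code):
# guards equivalent to the checks this commit adds to Unit2Mel.forward.
def check_shallow_diffusion_args(k_step: int, k_step_max: int,
                                 timesteps: int, only_diffusion: bool) -> None:
    if k_step > k_step_max:
        # -ks must not exceed the number of steps the diffusion model was trained on
        raise ValueError("k_step is greater than the trained k_step_max!")
    if only_diffusion and k_step_max != timesteps:
        # a partially trained diffusion model cannot generate from pure noise
        raise ValueError("This model can only be used for shallow diffusion and cannot infer alone!")


# Valid: shallow diffusion with -ks 100 on a model trained with k_step_max=100
check_shallow_diffusion_args(k_step=100, k_step_max=100, timesteps=1000, only_diffusion=False)
```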
### Attention
@ -372,8 +386,8 @@ The existing steps before clustering do not need to be changed. All you need to
Introduction: Like the clustering scheme, feature retrieval reduces timbre leakage. It preserves character slightly better than clustering but lowers inference speed. Using the fusion method, the proportion of feature retrieval versus non-feature retrieval can be controlled linearly.
- Training process
First, it needs to be executed after generating hubert and f0
- Training process:
First, it needs to be executed after generating hubert and f0:
```shell
python train_index.py -c configs/config.json
@ -381,7 +395,7 @@ python train_index.py -c configs/config.json
The output of the model will be in `logs/44k/feature_and_index.pkl`
- Inference process
- Inference process:
- `--feature_retrieval` needs to be specified first; the clustering mode then automatically switches to feature-retrieval mode.
- Specify `cluster_model_path` in `inference_main.py`.
- Specify `cluster_infer_ratio` in `inference_main.py`, where `0` means not using feature retrieval at all, `1` means only using feature retrieval, and usually `0.5` is sufficient (see the sketch below).
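A minimal sketch of the linear fusion controlled by `cluster_infer_ratio`, assuming a simple frame-level blend of retrieved and original encoder features; the array shapes and the `blend_features` function are illustrative, not the repository's exact implementation:
```python
import numpy as np

# Minimal illustrative sketch (not the repository's exact implementation):
# linear fusion of retrieved features and original encoder features,
# controlled by cluster_infer_ratio.
def blend_features(original: np.ndarray, retrieved: np.ndarray, ratio: float) -> np.ndarray:
    """ratio=0 -> only the original encoder features, ratio=1 -> only retrieved features."""
    assert 0.0 <= ratio <= 1.0
    return ratio * retrieved + (1.0 - ratio) * original


frames, dim = 100, 768  # e.g. vec768l12 content features (shapes assumed for illustration)
original = np.random.randn(frames, dim).astype(np.float32)
retrieved = np.random.randn(frames, dim).astype(np.float32)
mixed = blend_features(original, retrieved, ratio=0.5)  # 0.5 is usually sufficient
print(mixed.shape)  # (100, 768)
```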

View File

@ -245,14 +245,28 @@ python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug
#### At this point, some parameters in the generated config.json and diffusion.yaml can be modified
##### config.json
* `keep_ckpts`: Keep the last few models during training; `0` keeps them all. By default only the last `3` are kept.
* `all_in_mem`, `cache_all_data`: Load the entire dataset into RAM. Enable it when the disk IO on some platforms is too slow and the system memory is **much larger** than your dataset.
* `all_in_mem`: Load the entire dataset into RAM. Enable it when the disk IO on some platforms is too slow and the system memory is **much larger** than your dataset.
* `batch_size`: The amount of data loaded to the GPU per training step; adjust it to a size that fits within your GPU memory.
* `vocoder_name`: Select a vocoder. The default is `nsf-hifigan`.
##### diffusion.yaml
* `cache_all_data`: Load the entire dataset into RAM. Enable it when the disk IO on some platforms is too slow and the system memory is **much larger** than your dataset.
* `duration`: The duration of the audio slices used during training; it can be adjusted according to GPU memory size. **Note: this value must be less than the shortest audio duration in the training set!**
* `batch_size`: The amount of data loaded to the GPU per training step; adjust it to a size that fits within your GPU memory.
* `timesteps`: The total number of steps in the diffusion model. The default is 1000.
* `k_step_max`: Training can be limited to the first `k_step_max` diffusion steps to save training time. Note that this value must be less than `timesteps`; `0` trains the entire diffusion model. **Note: a model that does not train the entire diffusion process cannot use diffusion-only inference!**
##### **List of Vocoders**
```
@ -289,6 +303,12 @@ python preprocess_hubert_f0.py --f0_predictor dio --use_diff
## 🏋️‍♀️ Training
### Main Model Training
```shell
python train.py -c configs/config.json -m 44k
```
### Diffusion Model (optional)
If shallow diffusion is needed, the diffusion model must be trained. Train it as follows:
@ -297,12 +317,6 @@ python preprocess_hubert_f0.py --f0_predictor dio --use_diff
python train_diff.py -c configs/diffusion.yaml
```
### Main Model Training
```shell
python train.py -c configs/config.json -m 44k
```
After training, the model files are saved under `logs/44k`, and the diffusion model is stored under `logs/44k/diffusion`.
## 🤖 Inference

View File

@ -17,7 +17,9 @@ model:
n_layers: 20
n_chans: 512
n_hidden: 256
use_pitch_aug: true
use_pitch_aug: true
timesteps: 1000
k_step_max: 0 # must be <= timesteps; if 0, train the full diffusion model
n_spk: 1 # max number of different speakers
device: cuda
vocoder:

View File

@ -67,6 +67,7 @@ class GaussianDiffusion(nn.Module):
max_beta=0.02,
spec_min=-12,
spec_max=2):
super().__init__()
self.denoise_fn = denoise_fn
self.out_dims = out_dims
@ -78,7 +79,7 @@ class GaussianDiffusion(nn.Module):
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.k_step = k_step
self.k_step = k_step if k_step>0 and k_step<timesteps else timesteps
self.noise_list = deque(maxlen=4)

View File

@ -33,7 +33,9 @@ def load_model_vocoder(
128,
args.model.n_layers,
args.model.n_chans,
args.model.n_hidden)
args.model.n_hidden,
args.model.timesteps,
args.model.k_step_max)
print(' [Loading] ' + model_path)
ckpt = torch.load(model_path, map_location=torch.device(device))
@ -52,8 +54,11 @@ class Unit2Mel(nn.Module):
out_dims=128,
n_layers=20,
n_chans=384,
n_hidden=256):
n_hidden=256,
timesteps=1000,
k_step_max=1000):
super().__init__()
self.unit_embed = nn.Linear(input_channel, n_hidden)
self.f0_embed = nn.Linear(1, n_hidden)
self.volume_embed = nn.Linear(1, n_hidden)
@ -64,9 +69,13 @@ class Unit2Mel(nn.Module):
self.n_spk = n_spk
if n_spk is not None and n_spk > 1:
self.spk_embed = nn.Embedding(n_spk, n_hidden)
self.timesteps = timesteps if timesteps is not None else 1000
self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max<self.timesteps else self.timesteps
# diffusion
self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden)
self.decoder = GaussianDiffusion(out_dims, n_layers, n_chans, n_hidden,self.timesteps,self.k_step_max)
self.hidden_size = n_hidden
self.speaker_map = torch.zeros((self.n_spk,1,1,n_hidden))

View File

@ -40,10 +40,12 @@ def test(args, model, vocoder, loader_test, saver):
data['f0'],
data['volume'],
data['spk_id'],
gt_spec=None,
gt_spec=None if model.k_step_max == model.timesteps else data['mel'],
infer=True,
infer_speedup=args.infer.speedup,
method=args.infer.method)
method=args.infer.method,
k_step=model.k_step_max
)
signal = vocoder.infer(mel, data['f0'])
ed_time = time.time()
@ -62,7 +64,8 @@ def test(args, model, vocoder, loader_test, saver):
data['volume'],
data['spk_id'],
gt_spec=data['mel'],
infer=False)
infer=False,
k_step=model.k_step_max)
test_loss += loss.item()
# log mel
@ -121,11 +124,11 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
# forward
if dtype == torch.float32:
loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max)
else:
with autocast(device_type=args.device, dtype=dtype):
loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False)
aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max)
# handle nan loss
if torch.isnan(loss):

View File

@ -39,7 +39,10 @@ def load_model_vocoder(
vocoder.dimension,
args.model.n_layers,
args.model.n_chans,
args.model.n_hidden)
args.model.n_hidden,
args.model.timesteps,
args.model.k_step_max
)
print(' [Loading] ' + model_path)
ckpt = torch.load(model_path, map_location=torch.device(device))
@ -58,7 +61,10 @@ class Unit2Mel(nn.Module):
out_dims=128,
n_layers=20,
n_chans=384,
n_hidden=256):
n_hidden=256,
timesteps=1000,
k_step_max=1000
):
super().__init__()
self.unit_embed = nn.Linear(input_channel, n_hidden)
self.f0_embed = nn.Linear(1, n_hidden)
@ -71,9 +77,12 @@ class Unit2Mel(nn.Module):
if n_spk is not None and n_spk > 1:
self.spk_embed = nn.Embedding(n_spk, n_hidden)
self.timesteps = timesteps if timesteps is not None else 1000
self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max<self.timesteps else self.timesteps
self.n_hidden = n_hidden
# diffusion
self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden),timesteps=self.timesteps,k_step=self.k_step_max, out_dims=out_dims)
self.input_channel = input_channel
def init_spkembed(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
@ -124,6 +133,12 @@ class Unit2Mel(nn.Module):
dict of B x n_frames x feat
'''
if not self.training and gt_spec is not None and k_step>self.k_step_max:
raise Exception("The shallow diffusion k_step is greater than the maximum diffusion k_step(k_step_max)!")
if not self.training and gt_spec is None and self.k_step_max!=self.timesteps:
raise Exception("This model can only be used for shallow diffusion and can not infer alone!")
x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
if self.n_spk is not None and self.n_spk > 1:
if spk_mix_dict is not None:

View File

@ -41,8 +41,12 @@ if __name__ == '__main__':
vocoder.dimension,
args.model.n_layers,
args.model.n_chans,
args.model.n_hidden)
args.model.n_hidden,
args.model.timesteps,
args.model.k_step_max
)
print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
# load parameters
optimizer = torch.optim.AdamW(model.parameters())

View File

@ -6,7 +6,7 @@ class SpeechEncoder(object):
def encoder(self,wav):
'''
input: wav:[batchsize,signal_length]
input: wav:[signal_length]
output: embedding:[batchsize,hidden_dim,wav_frame]
'''
pass