Merge pull request #331 from svc-develop-team/add-multi-GPU-preprocess-support

Add multi gpu preprocess support
YuriHead 2023-07-22 20:17:19 +08:00 committed by GitHub
commit 59644ac428
4 changed files with 25 additions and 9 deletions


@@ -258,6 +258,15 @@ Add `--vol_aug` if you want to enable loudness embedding:
python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug
```
**Speed up preprocessing**
If your dataset is quite large, you can increase the `--num_processes` parameter accordingly:
```shell
python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug --num_processes 8
```
If you have more than one GPU, the workers will be distributed across all of them automatically; the sketch below illustrates how each worker picks its device.
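Concretely, each spawned worker derives a rank from its `torch.multiprocessing` process identity and maps it to a device via `rank % device_count`. A minimal, runnable sketch of that mapping (illustration only; `pick_device` and `worker` are hypothetical names, not part of this repository):
```python
# Minimal sketch of rank-based GPU assignment (illustration only).
import torch
import torch.multiprocessing as mp


def pick_device():
    # Pool workers carry an identity tuple such as (1,); the main process has an empty one.
    identity = mp.current_process()._identity
    rank = identity[0] if len(identity) > 0 else 0
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    return torch.device("cpu")


def worker(_):
    # Each worker reports which device it would use.
    print(f"{mp.current_process().name} -> {pick_device()}")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # required before touching CUDA in subprocesses
    with mp.Pool(4) as pool:
        pool.map(worker, range(4))
```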
After enabling loudness embedding, the trained model will match the loudness of the input source; otherwise, it will match the loudness of the training set.
#### You can modify some parameters in the generated config.json and diffusion.yaml


@@ -260,6 +260,12 @@ wavlmbase+
python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug
```
**Speed up preprocessing**
If your dataset is fairly large, you can try adding the `--num_processes` parameter:
```shell
python preprocess_flist_config.py --speech_encoder vec768l12 --vol_aug --num_processes 8
```
All workers will be assigned across multiple GPUs automatically (if you have more than one GPU).
With loudness embedding enabled, the trained model will match the loudness of the input source; otherwise, it will match the loudness of the training set.
#### You can now modify some parameters in the generated config.json and diffusion.yaml


@@ -1,6 +1,6 @@
import argparse
import logging
import multiprocessing
import torch.multiprocessing as mp
import os
import random
from concurrent.futures import ProcessPoolExecutor
@@ -106,10 +106,14 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
print("Loading speech encoder for content...")
device = "cuda" if torch.cuda.is_available() else "cpu"
rank = mp.current_process()._identity
rank = rank[0] if len(rank) > 0 else 0
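# Round-robin this worker over all visible GPUs.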
if torch.cuda.is_available():
gpu_id = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{gpu_id}")
print("Rank {rank} uses device {device}")
hmodel = utils.get_speech_encoder(speech_encoder, device=device)
print("Loaded speech encoder.")
for filename in tqdm(file_chunk):
process_one(filename, hmodel, f0p, diff, mel_extractor)
@@ -121,7 +125,6 @@ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
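# Hand each worker a contiguous, near-equal slice of the shuffled file list.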
end = int((i + 1) * len(filenames) / num_processes)
file_chunk = filenames[start:end]
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
for task in tqdm(tasks):
task.result()
@@ -139,7 +142,6 @@ if __name__ == "__main__":
parser.add_argument(
'--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores'
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_args()
f0p = args.f0_predictor
print(speech_encoder)
@@ -148,16 +150,16 @@ if __name__ == "__main__":
if args.use_diff:
print("use_diff")
print("Loading Mel Extractor...")
mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = device)
mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = "cuda:0")
print("Loaded Mel Extractor.")
else:
mel_extractor = None
filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10]
shuffle(filenames)
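# CUDA cannot be initialized in forked children, so the worker pool must use the "spawn" start method.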
multiprocessing.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force=True)
num_processes = args.num_processes
if num_processes == 0:
num_processes = os.cpu_count()
parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)


@@ -43,7 +43,6 @@ def normalize_f0(f0, x_mask, uv, random_scale=True):
if torch.isnan(f0_norm).any():
exit(0)
return f0_norm * x_mask
def plot_data_to_numpy(x, y):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG: