From c092d25716c73d7446ed01f7bcafc6bb93a668f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=AC=AC=E7=B4=97=E7=89=B9?= <66856838+Miuzarte@users.noreply.github.com> Date: Sun, 12 Mar 2023 15:59:29 +0800 Subject: [PATCH] Increase convenience for colab training --- configs/config.json | 106 ---------------- configs_template/config_template.json | 106 ++++++++++++++++ preprocess_flist_config.py | 166 +++++++++++++------------- 3 files changed, 189 insertions(+), 189 deletions(-) create mode 100644 configs_template/config_template.json diff --git a/configs/config.json b/configs/config.json index 550059a..e69de29 100644 --- a/configs/config.json +++ b/configs/config.json @@ -1,106 +0,0 @@ -{ - "train": { - "log_interval": 50, - "eval_interval": 1000, - "seed": 1234, - "port": 8001, - "epochs": 10000, - "learning_rate": 0.0002, - "betas": [ - 0.8, - 0.99 - ], - "eps": 1e-09, - "batch_size": 6, - "accumulation_steps": 1, - "fp16_run": false, - "lr_decay": 0.998, - "segment_size": 10240, - "init_lr_ratio": 1, - "warmup_epochs": 0, - "c_mel": 45, - "keep_ckpts":4 - }, - "data": { - "data_dir": "dataset", - "dataset_type": "SingDataset", - "collate_type": "SingCollate", - "training_filelist": "filelists/train.txt", - "validation_filelist": "filelists/val.txt", - "max_wav_value": 32768.0, - "sampling_rate": 44100, - "n_fft": 2048, - "fmin": 0, - "fmax": 22050, - "hop_length": 512, - "win_size": 2048, - "acoustic_dim": 80, - "c_dim": 256, - "min_level_db": -115, - "ref_level_db": 20, - "min_db": -115, - "max_abs_value": 4.0, - "n_speakers": 200 - }, - "model": { - "hidden_channels": 192, - "spk_channels": 192, - "filter_channels": 768, - "n_heads": 2, - "n_layers": 4, - "kernel_size": 3, - "p_dropout": 0.1, - "prior_hidden_channels": 192, - "prior_filter_channels": 768, - "prior_n_heads": 2, - "prior_n_layers": 4, - "prior_kernel_size": 3, - "prior_p_dropout": 0.1, - "resblock": "1", - "use_spectral_norm": false, - "resblock_kernel_sizes": [ - 3, - 7, - 11 - ], - "resblock_dilation_sizes": [ - [ - 1, - 3, - 5 - ], - [ - 1, - 3, - 5 - ], - [ - 1, - 3, - 5 - ] - ], - "upsample_rates": [ - 8, - 8, - 4, - 2 - ], - "upsample_initial_channel": 256, - "upsample_kernel_sizes": [ - 16, - 16, - 8, - 4 - ], - "n_harmonic": 64, - "n_bands": 65 - }, - "spk": { - "jishuang": 0, - "huiyu": 1, - "nen": 2, - "paimon": 3, - "yunhao": 4 - } -} \ No newline at end of file diff --git a/configs_template/config_template.json b/configs_template/config_template.json new file mode 100644 index 0000000..edd0af5 --- /dev/null +++ b/configs_template/config_template.json @@ -0,0 +1,106 @@ +{ + "train": { + "log_interval": 50, + "eval_interval": 1000, + "seed": 1234, + "port": 8001, + "epochs": 10000, + "learning_rate": 0.0002, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 6, + "accumulation_steps": 1, + "fp16_run": false, + "lr_decay": 0.998, + "segment_size": 10240, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "keep_ckpts":4 + }, + "data": { + "data_dir": "dataset", + "dataset_type": "SingDataset", + "collate_type": "SingCollate", + "training_filelist": "filelists/train.txt", + "validation_filelist": "filelists/val.txt", + "max_wav_value": 32768.0, + "sampling_rate": 44100, + "n_fft": 2048, + "fmin": 0, + "fmax": 22050, + "hop_length": 512, + "win_size": 2048, + "acoustic_dim": 80, + "c_dim": 256, + "min_level_db": -115, + "ref_level_db": 20, + "min_db": -115, + "max_abs_value": 4.0, + "n_speakers": 200 + }, + "model": { + "hidden_channels": 192, + "spk_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 4, + "kernel_size": 3, + "p_dropout": 0.1, + "prior_hidden_channels": 192, + "prior_filter_channels": 768, + "prior_n_heads": 2, + "prior_n_layers": 4, + "prior_kernel_size": 3, + "prior_p_dropout": 0.1, + "resblock": "1", + "use_spectral_norm": false, + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 4, + 2 + ], + "upsample_initial_channel": 256, + "upsample_kernel_sizes": [ + 16, + 16, + 8, + 4 + ], + "n_harmonic": 64, + "n_bands": 65 + }, + "spk": { + "jishuang": 0, + "huiyu": 1, + "nen": 2, + "paimon": 3, + "yunhao": 4 + } +} \ No newline at end of file diff --git a/preprocess_flist_config.py b/preprocess_flist_config.py index ff5c969..a83e637 100644 --- a/preprocess_flist_config.py +++ b/preprocess_flist_config.py @@ -1,83 +1,83 @@ -import os -import argparse -import re - -from tqdm import tqdm -from random import shuffle -import json -import wave - -config_template = json.load(open("configs/config.json")) - -pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$') - -def get_wav_duration(file_path): - with wave.open(file_path, 'rb') as wav_file: - # 获取音频帧数 - n_frames = wav_file.getnframes() - # 获取采样率 - framerate = wav_file.getframerate() - # 计算时长(秒) - duration = n_frames / float(framerate) - return duration - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list") - parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list") - parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list") - parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir") - args = parser.parse_args() - - train = [] - val = [] - test = [] - idx = 0 - spk_dict = {} - spk_id = 0 - for speaker in tqdm(os.listdir(args.source_dir)): - spk_dict[speaker] = spk_id - spk_id += 1 - wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))] - new_wavs = [] - for file in wavs: - if not file.endswith("wav"): - continue - if not pattern.match(file): - print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)") - if get_wav_duration(file) < 0.3: - print("skip too short audio:", file) - continue - new_wavs.append(file) - wavs = new_wavs - shuffle(wavs) - train += wavs[2:-2] - val += wavs[:2] - test += wavs[-2:] - - shuffle(train) - shuffle(val) - shuffle(test) - - print("Writing", args.train_list) - with open(args.train_list, "w") as f: - for fname in tqdm(train): - wavpath = fname - f.write(wavpath + "\n") - - print("Writing", args.val_list) - with open(args.val_list, "w") as f: - for fname in tqdm(val): - wavpath = fname - f.write(wavpath + "\n") - - print("Writing", args.test_list) - with open(args.test_list, "w") as f: - for fname in tqdm(test): - wavpath = fname - f.write(wavpath + "\n") - - config_template["spk"] = spk_dict - print("Writing configs/config.json") - with open("configs/config.json", "w") as f: - json.dump(config_template, f, indent=2) +import os +import argparse +import re + +from tqdm import tqdm +from random import shuffle +import json +import wave + +config_template = json.load(open("configs_template/config_template.json")) + +pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$') + +def get_wav_duration(file_path): + with wave.open(file_path, 'rb') as wav_file: + # 获取音频帧数 + n_frames = wav_file.getnframes() + # 获取采样率 + framerate = wav_file.getframerate() + # 计算时长(秒) + duration = n_frames / float(framerate) + return duration + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list") + parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list") + parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list") + parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir") + args = parser.parse_args() + + train = [] + val = [] + test = [] + idx = 0 + spk_dict = {} + spk_id = 0 + for speaker in tqdm(os.listdir(args.source_dir)): + spk_dict[speaker] = spk_id + spk_id += 1 + wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))] + new_wavs = [] + for file in wavs: + if not file.endswith("wav"): + continue + if not pattern.match(file): + print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)") + if get_wav_duration(file) < 0.3: + print("skip too short audio:", file) + continue + new_wavs.append(file) + wavs = new_wavs + shuffle(wavs) + train += wavs[2:-2] + val += wavs[:2] + test += wavs[-2:] + + shuffle(train) + shuffle(val) + shuffle(test) + + print("Writing", args.train_list) + with open(args.train_list, "w") as f: + for fname in tqdm(train): + wavpath = fname + f.write(wavpath + "\n") + + print("Writing", args.val_list) + with open(args.val_list, "w") as f: + for fname in tqdm(val): + wavpath = fname + f.write(wavpath + "\n") + + print("Writing", args.test_list) + with open(args.test_list, "w") as f: + for fname in tqdm(test): + wavpath = fname + f.write(wavpath + "\n") + + config_template["spk"] = spk_dict + print("Writing configs/config.json") + with open("configs/config.json", "w") as f: + json.dump(config_template, f, indent=2)