Merge pull request #365 from svc-develop-team/4.1-Stable

To Latest
This commit is contained in:
YuriHead 2023-08-02 00:43:07 +08:00 committed by GitHub
commit 39b0befef5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 35 deletions

View File

@ -136,7 +136,7 @@ class FCPE(nn.Module):
B, N, _ = y.size()
ci = self.cent_table[None, None, :].expand(B, N, -1)
confident, max_index = torch.max(y, dim=-1, keepdim=True)
local_argmax_index = torch.arange(0,8).to(max_index.device) + (max_index - 4)
local_argmax_index = torch.arange(0,9).to(max_index.device) + (max_index - 4)
local_argmax_index[local_argmax_index<0] = 0
local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1
ci_l = torch.gather(ci,-1,local_argmax_index)

View File

@ -13,14 +13,17 @@ import diffusion.logger.utils as du
pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
def get_wav_duration(file_path):
    """Return the duration of a WAV file in seconds.

    Args:
        file_path: Path of the WAV file to inspect.

    Returns:
        The frame count divided by the frame rate, as a float.

    Raises:
        Exception: Whatever error occurred while opening or reading the
            file is logged together with the offending path and then
            re-raised unchanged.
    """
    try:
        with wave.open(file_path, 'rb') as wav_file:
            # Number of audio frames stored in the file.
            n_frames = wav_file.getnframes()
            # Sampling rate in frames per second.
            framerate = wav_file.getframerate()
            # Duration in seconds.
            return n_frames / float(framerate)
    except Exception:
        # Record which file failed before propagating; a bare ``raise``
        # re-raises the active exception with its original traceback.
        logger.error(f"Reading {file_path}")
        raise
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@ -42,18 +45,25 @@ if __name__ == "__main__":
for speaker in tqdm(os.listdir(args.source_dir)):
spk_dict[speaker] = spk_id
spk_id += 1
wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))]
new_wavs = []
for file in wavs:
if not file.endswith("wav"):
wavs = []
for file_name in os.listdir(os.path.join(args.source_dir, speaker)):
if not file_name.endswith("wav"):
continue
if not pattern.match(file):
logger.warning(f"文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
if get_wav_duration(file) < 0.3:
logger.info("Skip too short audio:" + file)
if file_name.startswith("."):
continue
new_wavs.append(file)
wavs = new_wavs
file_path = "/".join([args.source_dir, speaker, file_name])
if not pattern.match(file_name):
logger.warning("Detected non-ASCII file name: " + file_path)
if get_wav_duration(file_path) < 0.3:
logger.info("Skip too short audio: " + file_path)
continue
wavs.append(file_path)
shuffle(wavs)
train += wavs[2:]
val += wavs[:2]
@ -61,13 +71,13 @@ if __name__ == "__main__":
shuffle(train)
shuffle(val)
logger.info("Writing" + args.train_list)
logger.info("Writing " + args.train_list)
with open(args.train_list, "w") as f:
for fname in tqdm(train):
wavpath = fname
f.write(wavpath + "\n")
logger.info("Writing" + args.val_list)
logger.info("Writing " + args.val_list)
with open(args.val_list, "w") as f:
for fname in tqdm(val):
wavpath = fname

View File

@ -113,7 +113,7 @@ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu")
logger.info(f"Rank {rank} uses device {device}")
hmodel = utils.get_speech_encoder(speech_encoder, device=device)
logger.info(f"Loaded speech encoder for rank {rank}")
for filename in tqdm(file_chunk):
for filename in tqdm(file_chunk, position = rank):
process_one(filename, hmodel, f0p, device, diff, mel_extractor)
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
@ -124,7 +124,7 @@ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device)
end = int((i + 1) * len(filenames) / num_processes)
file_chunk = filenames[start:end]
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
for task in tqdm(tasks):
for task in tqdm(tasks, position = 0):
task.result()
if __name__ == "__main__":
@ -149,10 +149,10 @@ if __name__ == "__main__":
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(speech_encoder)
logger.info("Using device: ", device)
logger.info("Using device: " + str(device))
logger.info("Using SpeechEncoder: " + speech_encoder)
logger.info("Using extractor: " + f0p)
logger.info("Using diff Mode: " + str( args.use_diff))
logger.info("Using diff Mode: " + str(args.use_diff))
if args.use_diff:
print("use_diff")

0
pretrain/__init__.py Normal file
View File

View File

@ -78,10 +78,9 @@
"#@markdown\n",
"\n",
"!git clone https://github.com/svc-develop-team/so-vits-svc -b 4.1-Stable\n",
"%pip uninstall -y torchdata torchtext\n",
"%pip install --upgrade pip setuptools numpy numba\n",
"%pip install pyworld praat-parselmouth fairseq tensorboardX torchcrepe librosa==0.9.1 pyyaml pynvml pyloudnorm faiss-gpu\n",
"%pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118\n",
"%cd /content/so-vits-svc\n",
"%pip install --upgrade pip setuptools\n",
"%pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118\n",
"exit()"
]
},
@ -163,8 +162,8 @@
"#@markdown Although the pretrained model generally does not cause any copyright problems, please pay attention to it. For example, ask the author in advance, or the author has indicated the feasible use in the description clearly.\n",
"\n",
"download_pretrained_model = True #@param {type:\"boolean\"}\n",
"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\"] {allow-input: true}\n",
"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\"] {allow-input: true}\n",
"D_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_D_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_D_320000.pth\"] {allow-input: true}\n",
"G_0_URL = \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\" #@param [\"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth\", \"https://huggingface.co/1asbgdh/sovits4.0-volemb-vec768/resolve/main/clean_G_320000.pth\", \"https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/vol_emb/clean_G_320000.pth\"] {allow-input: true}\n",
"\n",
"download_pretrained_diffusion_model = True #@param {type:\"boolean\"}\n",
"diff_model_URL = \"https://huggingface.co/datasets/ms903/Diff-SVC-refactor-pre-trained-model/resolve/main/fix_pitch_add_vctk_600k/model_0.pt\" #@param {type:\"string\"}\n",
@ -317,13 +316,17 @@
"#@markdown\n",
"%cd /content/so-vits-svc\n",
"\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
"use_diff = True #@param {type:\"boolean\"}\n",
"\n",
"import os\n",
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
"\n",
"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
"\n",
"\n",
"diff_param = \"\"\n",
"if use_diff:\n",
" diff_param = \"--use_diff\"\n",
@ -624,7 +627,7 @@
"if auto_predict_f0:\n",
" apf = \" -a \"\n",
"\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\"]\n",
"f0_predictor = \"crepe\" #@param [\"crepe\", \"pm\", \"dio\", \"harvest\", \"rmvpe\", \"fcpe\"]\n",
"\n",
"enhance = False #@param {type:\"boolean\"}\n",
"ehc = \"\"\n",
@ -644,6 +647,9 @@
"if f0_predictor == \"rmvpe\" and not os.path.exists(\"./pretrain/rmvpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/rmvpe.pt -o pretrain/rmvpe.pt\n",
"\n",
"if f0_predictor == \"fcpe\" and not os.path.exists(\"./pretrain/fcpe.pt\"):\n",
" !curl -L https://huggingface.co/datasets/ylzz1997/rmvpe_pretrain_model/resolve/main/fcpe.pt -o pretrain/fcpe.pt\n",
"\n",
"if not os.path.exists(output):\n",
" !curl -L {url} -o {output}\n",
"\n",