并行执行预处理,处理速度或将提升,整理部分代码 (#230)

* 使用多线程来进行预处理,速度或将大幅提升,使用方法为 --use_thread <线程数>

* 并行执行预处理,处理速度或将提升,整理部分代码

* 修正kl_loss计算公式

* resample使用多进程,preprocess_hubert_f0使用旧的并行方式(尝试共享模型失败)
This commit is contained in:
ZSCharlie 2023-06-20 03:38:13 +08:00 committed by YuriHead
parent d8e30c40f7
commit 5159c3b543
3 changed files with 106 additions and 64 deletions

View File

@@ -55,7 +55,8 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
z_mask = z_mask.float()
#print(logs_p)
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
# kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl += 0.5 * (torch.exp(2.*logs_q)+(z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l

View File

@@ -1,27 +1,24 @@
import math
import multiprocessing
import os
import argparse
from random import shuffle
import random
import utils
import torch
import random
import librosa
import logging
import argparse
import multiprocessing
import numpy as np
import diffusion.logger.utils as du
from glob import glob
from tqdm import tqdm
from random import shuffle
from diffusion.vocoder import Vocoder
from concurrent.futures import ProcessPoolExecutor
from modules.mel_processing import spectrogram_torch
import json
import utils
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
import diffusion.logger.utils as du
from diffusion.vocoder import Vocoder
import librosa
import numpy as np
hps = utils.get_hparams_from_file("configs/config.json")
dconfig = du.load_config("configs/diffusion.yaml")
sampling_rate = hps.data.sampling_rate
@@ -42,7 +39,7 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
wav16k = torch.from_numpy(wav16k).to(device)
c = hmodel.encoder(wav16k)
torch.save(c.cpu(), soft_path)
f0_path = filename + ".f0.npy"
if not os.path.exists(f0_path):
f0_predictor = utils.get_f0_predictor(f0p,sampling_rate=sampling_rate, hop_length=hop_length,device=None,threshold=0.05)
@@ -50,8 +47,8 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
wav
)
np.save(f0_path, np.asanyarray((f0,uv),dtype=object))
spec_path = filename.replace(".wav", ".spec.pt")
if not os.path.exists(spec_path):
# Process spectrogram
@@ -64,9 +61,9 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
sr, hps.data.sampling_rate
)
)
#audio_norm = audio / hps.data.max_wav_value
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
@@ -106,28 +103,39 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
if not os.path.exists(aug_vol_path):
np.save(aug_vol_path,aug_vol.to('cpu').numpy())
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
    """Worker entry point: load a speech encoder, then preprocess every file in file_chunk.

    Each worker process loads its own encoder — the commit notes that sharing
    one model across processes failed, so this duplication is deliberate.

    file_chunk    : list of wav paths this worker is responsible for
    f0p           : name of the F0 predictor to use (e.g. "dio", "crepe")
    diff          : whether diffusion-model features are also extracted
    mel_extractor : vocoder used for mel extraction when diff is True
    """
    # NOTE: the diff rendering of this hunk interleaved the pre-refactor lines
    # (old `filenames` signature and loop) with the new `file_chunk` ones; only
    # the refactored version is kept here.
    print("Loading speech encoder for content...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hmodel = utils.get_speech_encoder(speech_encoder, device=device)
    print("Loaded speech encoder.")
    for filename in tqdm(file_chunk):
        process_one(filename, hmodel, f0p, diff, mel_extractor)
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
    """Split filenames into num_processes contiguous chunks and run process_batch
    on each chunk in a separate worker process, waiting for all to finish."""
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        total = len(filenames)
        # Chunk boundaries are computed with the same int(i * total / n) scheme
        # as before, so chunk sizes differ by at most one file.
        pending = [
            executor.submit(
                process_batch,
                filenames[int(idx * total / num_processes):int((idx + 1) * total / num_processes)],
                f0p,
                diff,
                mel_extractor,
            )
            for idx in range(num_processes)
        ]
        # Block on each future in submission order; .result() re-raises any
        # worker exception here in the parent.
        for fut in tqdm(pending):
            fut.result()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--in_dir", type=str, default="dataset/44k", help="path to input dir"
)
parser.add_argument(
parser.add_argument(
'--use_diff',action='store_true', help='Whether to use the diffusion model'
)
parser.add_argument(
parser.add_argument(
'--f0_predictor', type=str, default="dio", help='Select F0 predictor, can select crepe,pm,dio,harvest, default pm(note: crepe is original F0 using mean filter)'
)
parser.add_argument(
parser.add_argument(
'--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores'
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -135,6 +143,7 @@ if __name__ == "__main__":
f0p = args.f0_predictor
print(speech_encoder)
print(f0p)
print(args.use_diff)
if args.use_diff:
print("use_diff")
print("Loading Mel Extractor...")
@@ -145,15 +154,9 @@ if __name__ == "__main__":
filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10]
shuffle(filenames)
multiprocessing.set_start_method("spawn", force=True)
num_processes = args.num_processes
chunk_size = int(math.ceil(len(filenames) / num_processes))
chunks = [
filenames[i : i + chunk_size] for i in range(0, len(filenames), chunk_size)
]
print([len(c) for c in chunks])
processes = [
multiprocessing.Process(target=process_batch, args=(chunk,f0p,args.use_diff,mel_extractor)) for chunk in chunks
]
for p in processes:
p.start()
if num_processes == 0:
num_processes = os.cpu_count()
parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)

View File

@@ -2,34 +2,77 @@ import os
import argparse
import librosa
import numpy as np
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing import Pool, cpu_count
from scipy.io import wavfile
from tqdm import tqdm
def load_wav(wav_path):
    """Load an audio file at its native sampling rate; returns (samples, sr)."""
    samples, sr = librosa.load(wav_path, sr=None)
    return samples, sr
def trim_wav(wav, top_db=40):
    """Trim leading/trailing silence quieter than top_db; returns (trimmed, interval)
    exactly as librosa.effects.trim does."""
    trimmed = librosa.effects.trim(wav, top_db=top_db)
    return trimmed
def normalize_peak(wav, threshold=1.0):
    """Return wav rescaled to a 0.98 absolute peak when its peak exceeds
    threshold; otherwise return wav unchanged (same object)."""
    peak = np.max(np.abs(wav))
    if peak <= threshold:
        return wav
    return 0.98 * wav / peak
def resample_wav(wav, sr, target_sr):
    """Resample wav from rate sr to target_sr using librosa."""
    resampled = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    return resampled
def save_wav_to_path(wav, save_path, sr):
    """Write float samples (assumed in [-1, 1] — values outside wrap on the
    int16 cast) to save_path as 16-bit PCM at sample rate sr."""
    pcm = (wav * np.iinfo(np.int16).max).astype(np.int16)
    wavfile.write(save_path, sr, pcm)
def process(item):
    """Resample one speaker's wav file to args.sr2 and write it under args.out_dir2.

    item : (spkdir, wav_name, args) tuple, packed so the function can be
           submitted to a process pool with a single argument.

    Skips silently when the source path does not exist or is not a .wav file.
    """
    # NOTE: the diff rendering of this hunk interleaved the pre-refactor inline
    # body with the post-refactor helper calls; only the helper-based version
    # (which matches load_wav/trim_wav/normalize_peak/resample_wav/
    # save_wav_to_path defined above) is kept here.
    spkdir, wav_name, args = item
    # NOTE(review): an upstream comment claimed speakers 's5', 'p280', 'p315'
    # are excluded, but no exclusion is enforced in this function.
    speaker = spkdir.replace("\\", "/").split("/")[-1]
    wav_path = os.path.join(args.in_dir, speaker, wav_name)
    # endswith(".wav") instead of `'.wav' in wav_path`: the old substring test
    # also matched names like "x.wav.bak".
    if os.path.exists(wav_path) and wav_path.endswith(".wav"):
        os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
        wav, sr = load_wav(wav_path)
        wav, _ = trim_wav(wav)
        wav = normalize_peak(wav)
        resampled_wav = resample_wav(wav, sr, args.sr2)
        if not args.skip_loudnorm:
            # Peak-normalize to full scale unless the caller opted out.
            resampled_wav /= max(resampled_wav.max(), -resampled_wav.min())
        save_path2 = os.path.join(args.out_dir2, speaker, wav_name)
        save_wav_to_path(resampled_wav, save_path2, args.sr2)
# def process_all_speakers(speakers, args):
# process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
# with ThreadPoolExecutor(max_workers=process_count) as executor:
# for speaker in speakers:
# spk_dir = os.path.join(args.in_dir, speaker)
# if os.path.isdir(spk_dir):
# print(spk_dir)
# futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
# for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
# pass
# multi process
def process_all_speakers(speakers, args):
    """Resample every wav of every speaker in args.in_dir using a process pool.

    speakers : iterable of speaker directory names under args.in_dir
    args     : parsed CLI namespace (in_dir, out_dir2, sr2, skip_loudnorm)
    """
    # Leave headroom on big machines: cap at 30 workers, else cpu_count - 2,
    # else a single worker.
    process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
    with ProcessPoolExecutor(max_workers=process_count) as executor:
        # Submit ALL files before waiting. The previous version awaited each
        # speaker's futures before submitting the next speaker, which idled
        # the pool at every speaker boundary.
        futures = []
        for speaker in speakers:
            spk_dir = os.path.join(args.in_dir, speaker)
            if os.path.isdir(spk_dir):
                print(spk_dir)
                futures += [
                    executor.submit(process, (spk_dir, name, args))
                    for name in os.listdir(spk_dir)
                    if name.endswith("wav")
                ]
        for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            pass
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -38,12 +81,7 @@ if __name__ == "__main__":
parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it")
args = parser.parse_args()
processs = 30 if cpu_count() > 60 else (cpu_count()-2 if cpu_count() > 4 else 1)
pool = Pool(processes=processs)
for speaker in os.listdir(args.in_dir):
spk_dir = os.path.join(args.in_dir, speaker)
if os.path.isdir(spk_dir):
print(spk_dir)
for _ in tqdm(pool.imap_unordered(process, [(spk_dir, i, args) for i in os.listdir(spk_dir) if i.endswith("wav")])):
pass
print(f"CPU count: {cpu_count()}")
speakers = os.listdir(args.in_dir)
process_all_speakers(speakers, args)