Update preprocess_hubert_f0.py

This commit is contained in:
Stardust·减 2023-07-22 14:30:54 +08:00 committed by GitHub
parent 225f7fdfa5
commit a915e47024
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 11 additions and 9 deletions

View File

@ -1,6 +1,6 @@
import argparse
import logging
import multiprocessing
import torch.multiprocessing as mp
import os
import random
from concurrent.futures import ProcessPoolExecutor
@ -106,10 +106,14 @@ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
    """Load a speech encoder in this worker and run process_one on each file.

    The device is chosen per worker: when CUDA is available, the worker's
    rank modulo the number of CUDA devices selects the GPU; otherwise the
    encoder is loaded on the CPU.

    Args:
        file_chunk: Iterable of file names handled by this worker.
        f0p: Forwarded unchanged to process_one.
        diff: Forwarded unchanged to process_one.
        mel_extractor: Forwarded unchanged to process_one.
    """
    print("Loading speech encoder for content...")
    # NOTE(review): _identity is a private attribute of the process object;
    # presumably it is empty in the main process and holds the worker index
    # first for pool workers -- confirm against the multiprocessing version
    # in use.
    identity = mp.current_process()._identity
    rank = identity[0] if identity else 0
    if torch.cuda.is_available():
        gpu_id = rank % torch.cuda.device_count()
        device = torch.device(f"cuda:{gpu_id}")
    else:
        # Explicit fallback so `device` is always bound on CPU-only machines.
        device = "cpu"
    # The f-prefix is required; without it the placeholders print literally.
    print(f"Rank {rank} uses device {device}")
    hmodel = utils.get_speech_encoder(speech_encoder, device=device)
    print("Loaded speech encoder.")
    for filename in tqdm(file_chunk):
        process_one(filename, hmodel, f0p, diff, mel_extractor)
@ -120,8 +124,7 @@ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
start = int(i * len(filenames) / num_processes)
end = int((i + 1) * len(filenames) / num_processes)
file_chunk = filenames[start:end]
tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
tasks.append(executor.map(process_batch, file_chunk, f0p, diff, mel_extractor))
for task in tqdm(tasks):
task.result()
@ -139,7 +142,6 @@ if __name__ == "__main__":
parser.add_argument(
'--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores'
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_args()
f0p = args.f0_predictor
print(speech_encoder)
@ -148,16 +150,16 @@ if __name__ == "__main__":
if args.use_diff:
print("use_diff")
print("Loading Mel Extractor...")
mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = device)
mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = "cuda:0")
print("Loaded Mel Extractor.")
else:
mel_extractor = None
filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10]
shuffle(filenames)
multiprocessing.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force=True)
num_processes = args.num_processes
if num_processes == 0:
num_processes = os.cpu_count()
parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)
parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)