"""Preprocess a dataset of .wav files for training.

For every .wav under --in_dir this script caches, next to the audio file:
speech-encoder content features (.soft.pt), F0 with voiced/unvoiced flags
(.f0.npy), and a linear spectrogram (.spec.pt). When volume embedding or the
diffusion model is enabled, it also caches volume (.vol.npy) and mel features
(.mel.npy / .aug_mel.npy / .aug_vol.npy).
"""
import argparse
import logging
import multiprocessing
import os
import random
from concurrent.futures import ProcessPoolExecutor
from glob import glob
from random import shuffle

import librosa
import numpy as np
import torch
from tqdm import tqdm

import diffusion.logger.utils as du
import utils
from diffusion.vocoder import Vocoder
from modules.mel_processing import spectrogram_torch

# Silence noisy third-party loggers.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

hps = utils.get_hparams_from_file("configs/config.json")
dconfig = du.load_config("configs/diffusion.yaml")
sampling_rate = hps.data.sampling_rate
hop_length = hps.data.hop_length
speech_encoder = hps["model"]["speech_encoder"]

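
# Extract and cache all features for one audio file. Every artifact is written
# next to the source .wav and skipped if it already exists, so the script can
# be re-run incrementally.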
def process_one(filename, hmodel, f0p, diff=False, mel_extractor=None):
    wav, sr = librosa.load(filename, sr=sampling_rate)
    audio_norm = torch.FloatTensor(wav)
    audio_norm = audio_norm.unsqueeze(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
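
    # Content features from the configured speech encoder, computed on 16 kHz
    # audio and cached as <wav>.soft.pt.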
    soft_path = filename + ".soft.pt"
    if not os.path.exists(soft_path):
        wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
        wav16k = torch.from_numpy(wav16k).to(device)
        c = hmodel.encoder(wav16k)
        torch.save(c.cpu(), soft_path)
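
    # F0 and voiced/unvoiced flags, cached together as <wav>.f0.npy.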
    f0_path = filename + ".f0.npy"
    if not os.path.exists(f0_path):
        f0_predictor = utils.get_f0_predictor(
            f0p, sampling_rate=sampling_rate, hop_length=hop_length, device=None, threshold=0.05
        )
        f0, uv = f0_predictor.compute_f0_uv(wav)
        np.save(f0_path, np.asanyarray((f0, uv), dtype=object))
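
    # Linear spectrogram for the synthesizer, cached as <wav>.spec.pt.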
    spec_path = filename.replace(".wav", ".spec.pt")
    if not os.path.exists(spec_path):
        # Process spectrogram.
        # The following code can't be replaced by torch.FloatTensor(wav),
        # because load_wav_to_torch returns a tensor that needs to be normalized.
        if sr != hps.data.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(sr, hps.data.sampling_rate)
            )

        # audio_norm = audio / hps.data.max_wav_value

        spec = spectrogram_torch(
            audio_norm,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        )
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_path)
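
    # Frame-level volume envelope, used for volume embedding and by the
    # diffusion model; cached as <wav>.vol.npy.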
    if diff or hps.model.vol_embedding:
        volume_path = filename + ".vol.npy"
        volume_extractor = utils.Volume_Extractor(hop_length)
        if not os.path.exists(volume_path):
            volume = volume_extractor.extract(audio_norm)
            np.save(volume_path, volume.to('cpu').numpy())
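
    # Diffusion-only features: the mel spectrogram plus a randomly volume- and
    # key-shifted augmented copy for data augmentation.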
    if diff:
        mel_path = filename + ".mel.npy"
        if not os.path.exists(mel_path) and mel_extractor is not None:
            mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate)
            mel = mel_t.squeeze().to('cpu').numpy()
            np.save(mel_path, mel)
        aug_mel_path = filename + ".aug_mel.npy"
        aug_vol_path = filename + ".aug_vol.npy"
        # Random gain (bounded so the peak stays at or below 1) and key shift.
        max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5
        max_shift = min(1, np.log10(1 / max_amp))
        log10_vol_shift = random.uniform(-1, max_shift)
        keyshift = random.uniform(-5, 5)
        if mel_extractor is not None:
            aug_mel_t = mel_extractor.extract(audio_norm * (10 ** log10_vol_shift), sampling_rate, keyshift=keyshift)
            aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
            # Saving aug_mel only makes sense when a mel extractor produced it.
            if not os.path.exists(aug_mel_path):
                np.save(aug_mel_path, np.asanyarray((aug_mel, keyshift), dtype=object))
        aug_vol = volume_extractor.extract(audio_norm * (10 ** log10_vol_shift))
        if not os.path.exists(aug_vol_path):
            np.save(aug_vol_path, aug_vol.to('cpu').numpy())
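

# Worker entry point: each worker process loads its own copy of the speech
# encoder once, then runs process_one over its chunk of files.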
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
    print("Loading speech encoder for content...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hmodel = utils.get_speech_encoder(speech_encoder, device=device)
    print("Loaded speech encoder.")

    for filename in tqdm(file_chunk):
        process_one(filename, hmodel, f0p, diff, mel_extractor)
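

# Split the file list into one contiguous chunk per worker and submit each
# chunk to the process pool; task.result() re-raises any worker exception.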
def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        tasks = []
        for i in range(num_processes):
            start = int(i * len(filenames) / num_processes)
            end = int((i + 1) * len(filenames) / num_processes)
            file_chunk = filenames[start:end]
            tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
        for task in tqdm(tasks):
            task.result()
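

# CLI entry point: parse arguments, optionally load the vocoder used as the
# mel extractor, collect the dataset .wav files, and fan out to workers.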
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in_dir", type=str, default="dataset/44k", help="path to input dir"
    )
    parser.add_argument(
        '--use_diff', action='store_true', help='Whether to use the diffusion model'
    )
    parser.add_argument(
        '--f0_predictor', type=str, default="dio",
        help='Select F0 predictor from crepe, pm, dio, harvest; default dio (note: crepe applies a mean filter to the original F0)'
    )
    parser.add_argument(
        '--num_processes', type=int, default=1,
        help='Number of worker processes; ideally set to the number of CPU cores (0 uses all available cores)'
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args = parser.parse_args()
    f0p = args.f0_predictor
    # Echo the effective settings.
    print(speech_encoder)
    print(f0p)
    print(args.use_diff)
    if args.use_diff:
        print("use_diff")
        print("Loading Mel Extractor...")
        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
        print("Loaded Mel Extractor.")
    else:
        mel_extractor = None

    filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True)  # append [:10] to test on a small subset
    shuffle(filenames)
    # Workers must use "spawn" so CUDA initializes correctly in each process.
    multiprocessing.set_start_method("spawn", force=True)

    num_processes = args.num_processes
    if num_processes == 0:
        num_processes = os.cpu_count()
    parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)
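
    # Example invocation (script name assumed; adjust to this file's name):
    #   python preprocess_hubert_f0.py --in_dir dataset/44k --use_diff --num_processes 4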