# 248 lines, 8.9 KiB, Python
import hashlib
|
||
import json
|
||
import os
|
||
import time
|
||
import traceback
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import parselmouth
|
||
import resampy
|
||
import torch
|
||
import torchcrepe
|
||
|
||
import utils
|
||
from modules.vocoders.nsf_hifigan import nsf_hifigan
|
||
from utils.hparams import hparams
|
||
from utils.pitch_utils import f0_to_coarse
|
||
|
||
warnings.filterwarnings("ignore")
|
||
|
||
|
||
class BinarizationError(Exception):
    """Raised when an item cannot be binarized (e.g. its ground-truth f0 is empty)."""
|
||
|
||
|
||
def get_md5(content):
    """Return the hexadecimal MD5 digest of *content* (a bytes-like object)."""
    return hashlib.md5(content).hexdigest()
|
||
|
||
|
||
def read_temp(file_name):
    """Load the JSON cache stored at *file_name*, creating it if missing.

    Once the file grows past 50 MiB, cache entries older than 14 days are
    dropped (in memory only; the caller persists via ``write_temp``).  On any
    read/parse error the cache is rebuilt from scratch.

    :param file_name: path to the JSON temp/cache file
    :return: the cache dict ({} right after creating a fresh file)
    """
    if not os.path.exists(file_name):
        with open(file_name, "w") as f:
            f.write(json.dumps({"info": "temp_dict"}))
        return {}
    try:
        with open(file_name, "r") as f:
            data_dict = json.loads(f.read())
        if os.path.getsize(file_name) > 50 * 1024 * 1024:
            f_name = os.path.basename(file_name)
            print(f"clean {f_name}")
            now = int(time.time())
            for wav_hash in list(data_dict.keys()):
                entry = data_dict[wav_hash]
                # Only timed cache records are eligible for expiry.  The
                # {"info": "temp_dict"} metadata entry is a plain string;
                # previously indexing it with ["time"] raised here, which
                # tripped the except below and wiped the whole cache
                # exactly when it was big enough to need cleaning.
                if not isinstance(entry, dict) or "time" not in entry:
                    continue
                if now - int(entry["time"]) > 14 * 24 * 3600:
                    del data_dict[wav_hash]
    except Exception as e:
        print(e)
        print(f"{file_name} error,auto rebuild file")
        data_dict = {"info": "temp_dict"}
    return data_dict
|
||
|
||
|
||
def write_temp(file_name, data):
    """Serialize *data* as JSON and overwrite *file_name* with it."""
    serialized = json.dumps(data)
    with open(file_name, "w") as f:
        f.write(serialized)
|
||
|
||
|
||
f0_dict = read_temp("./infer_tools/f0_temp.json")
|
||
|
||
|
||
def get_pitch_parselmouth(wav_data, mel, hparams):
    """
    Extract f0 with Praat's autocorrelation method and align it to the mel frames.

    :param wav_data: [T]
    :param mel: [T, 80]
    :param hparams: dict providing 'hop_size', 'audio_sample_rate', 'f0_min', 'f0_max'
    :return: (f0 array of length len(mel), coarse pitch codes)
    """
    sample_rate = hparams['audio_sample_rate']
    hop_size = hparams['hop_size']

    sound = parselmouth.Sound(wav_data, sample_rate)
    pitch = sound.to_pitch_ac(
        time_step=hop_size / sample_rate,
        voicing_threshold=0.6,
        pitch_floor=hparams['f0_min'],
        pitch_ceiling=hparams['f0_max'],
    )
    f0 = pitch.selected_array['frequency']

    # Praat yields fewer frames than the mel; pad (roughly symmetrically)
    # so that len(f0) == len(mel).
    lpad = (int(len(wav_data) // hop_size) - len(f0) + 1) // 2
    rpad = len(mel) - len(f0) - lpad
    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')

    return f0, f0_to_coarse(f0, hparams)
|
||
|
||
|
||
def get_pitch_crepe(wav_data, mel, hparams, threshold=0.05):
    """
    Extract f0 with torchcrepe and align it to the mel frames.

    :param wav_data: [T], at hparams['audio_sample_rate']
    :param mel: [T, 80] (only its length is used)
    :param hparams: dict providing 'audio_sample_rate', 'hop_size', 'f0_min', 'f0_max'
    :param threshold: periodicity threshold below which frames are treated as unvoiced
    :return: (f0 array of length len(mel), coarse pitch codes)
    """
    # Fall back to CPU when CUDA is unavailable; the previous hard-coded
    # "cuda" crashed on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # crepe only supports 16 kHz input, so resample first.
    wav16k = resampy.resample(wav_data, hparams['audio_sample_rate'], 16000)
    wav16k_torch = torch.FloatTensor(wav16k).unsqueeze(0).to(device)

    # Frequency range.
    f0_min = hparams['f0_min']
    f0_max = hparams['f0_max']

    # After resampling, analyze with hop_size=80, i.e. one f0 frame per 5 ms.
    f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, f0_min, f0_max, pad=True, model='full', batch_size=1024,
                                device=device, return_periodicity=True)

    # Filter, gate silence, and threshold out unvoiced frames
    # (per the torchcrepe README).
    pd = torchcrepe.filter.median(pd, 3)
    pd = torchcrepe.threshold.Silence(-60.)(pd, wav16k_torch, 16000, 80)
    f0 = torchcrepe.threshold.At(threshold)(f0, pd)
    f0 = torchcrepe.filter.mean(f0, 3)

    # Replace NaN (unvoiced) frames with 0 Hz.
    f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)

    # Drop zero-frequency frames and linearly interpolate the voiced contour
    # onto the mel frame time grid.
    nzindex = torch.nonzero(f0[0]).squeeze()
    f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy()
    time_org = 0.005 * nzindex.cpu().numpy()
    time_frame = np.arange(len(mel)) * hparams['hop_size'] / hparams['audio_sample_rate']
    if f0.shape[0] == 0:
        f0 = torch.FloatTensor(time_frame.shape[0]).fill_(0)
        print('f0 all zero!')
    else:
        f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
    pitch_coarse = f0_to_coarse(f0, hparams)
    return f0, pitch_coarse
|
||
|
||
|
||
class File2Batch:
    '''
    pipeline: file -> temporary_dict -> processed_input -> batch
    '''

    @staticmethod
    def file2temporary_dict(raw_data_dir, ds_id):
        '''
        read from file, store data in temporary dicts

        :param raw_data_dir: directory scanned recursively for .wav/.ogg files
        :param ds_id: speaker/dataset id stored with every item
        :return: {item_name: {'wav_fn': path, 'spk_id': ds_id}}
        '''
        raw_data_dir = Path(raw_data_dir)
        utterance_labels = []
        # f-string prefixes removed: these patterns contain no placeholders.
        utterance_labels.extend(list(raw_data_dir.rglob("*.wav")))
        utterance_labels.extend(list(raw_data_dir.rglob("*.ogg")))

        all_temp_dict = {}
        for utterance_label in utterance_labels:
            item_name = str(utterance_label)
            temp_dict = {'wav_fn': str(utterance_label), 'spk_id': ds_id}
            all_temp_dict[item_name] = temp_dict
        return all_temp_dict

    @staticmethod
    def temporary_dict2processed_input(item_name, temp_dict, encoder, infer=False, **kwargs):
        '''
        process data in temporary_dicts

        :param item_name: unique item key (the wav path)
        :param temp_dict: {'wav_fn': ..., 'spk_id': ...}
        :param encoder: hubert encoder; must expose .encode(wav_fn)
        :param infer: True at inference time (enables the crepe f0 disk cache)
        :param kwargs: must carry 'use_crepe' when infer is True
        :return: processed_input dict, or None if the item was skipped
        '''

        def get_pitch(wav, mel):
            # Get ground-truth f0, either with crepe (cached on disk by the
            # wav's md5 at inference time) or with parselmouth.
            global f0_dict
            use_crepe = hparams['use_crepe'] if not infer else kwargs['use_crepe']
            if use_crepe:
                md5 = get_md5(wav)
                if infer and md5 in f0_dict.keys():
                    print("load temp crepe f0")
                    gt_f0 = np.array(f0_dict[md5]["f0"])
                    coarse_f0 = np.array(f0_dict[md5]["coarse"])
                else:
                    # Free cached GPU memory before the crepe forward pass.
                    torch.cuda.is_available() and torch.cuda.empty_cache()
                    gt_f0, coarse_f0 = get_pitch_crepe(wav, mel, hparams, threshold=0.05)
                    if infer:
                        f0_dict[md5] = {"f0": gt_f0.tolist(), "coarse": coarse_f0.tolist(), "time": int(time.time())}
                        write_temp("./infer_tools/f0_temp.json", f0_dict)
            else:
                gt_f0, coarse_f0 = get_pitch_parselmouth(wav, mel, hparams)
            if sum(gt_f0) == 0:
                raise BinarizationError("Empty **gt** f0")
            processed_input['f0'] = gt_f0
            processed_input['pitch'] = coarse_f0

        def get_align(mel, phone_encoded):
            # Build mel2ph: for each mel frame, the 1-based index of the
            # hubert frame it belongs to (uniform stretch alignment).
            mel2ph = np.zeros([mel.shape[0]], int)
            start_frame = 0
            ph_durs = mel.shape[0] / phone_encoded.shape[0]
            for i_ph in range(phone_encoded.shape[0]):
                end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
                mel2ph[start_frame:end_frame + 1] = i_ph + 1
                start_frame = end_frame + 1

            processed_input['mel2ph'] = mel2ph

        wav, mel = nsf_hifigan.wav2spec(temp_dict['wav_fn'])
        processed_input = {
            'item_name': item_name, 'mel': mel,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]
        }
        processed_input = {**temp_dict, **processed_input,
                           'spec_min': np.min(mel, axis=0),
                           'spec_max': np.max(mel, axis=0)}  # merge two dicts
        try:
            get_pitch(wav, mel)
            try:
                hubert_encoded = processed_input['hubert'] = encoder.encode(temp_dict['wav_fn'])
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit are no longer swallowed here.
                traceback.print_exc()
                raise Exception("hubert encode error")
            get_align(mel, hubert_encoded)
        except Exception as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {temp_dict['wav_fn']}")
            return None
        if hparams['use_energy_embed']:
            max_frames = hparams['max_frames']
            spec = torch.Tensor(processed_input['mel'])[:max_frames]
            # Frame energy: L2 norm of the linear-amplitude spectrum
            # (mel is log-amplitude here, hence the exp()).
            processed_input['energy'] = (spec.exp() ** 2).sum(-1).sqrt()
        return processed_input

    @staticmethod
    def processed_input2batch(samples):
        '''
        Collate a list of processed_input dicts into one padded batch.

        Args:
            samples: one batch of processed_input
        NOTE:
            the batch size is controlled by hparams['max_sentences']
        '''
        if len(samples) == 0:
            return {}
        # Renamed from `id` to avoid shadowing the builtin.
        item_ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        hubert = utils.collate_2d([s['hubert'] for s in samples], 0.0)
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples])
        uv = utils.collate_1d([s['uv'] for s in samples])
        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
            if samples[0]['mel2ph'] is not None else None
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        batch = {
            'id': item_ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'hubert': hubert,
            'mels': mels,
            'mel_lengths': mel_lengths,
            'mel2ph': mel2ph,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        }
        if hparams['use_energy_embed']:
            batch['energy'] = utils.collate_1d([s['energy'] for s in samples], 0.0)
        if hparams['use_spk_id']:
            spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
            batch['spk_ids'] = spk_ids
        return batch
|