117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
import json
|
||
import os
|
||
import shutil
|
||
from functools import reduce
|
||
from pathlib import Path
|
||
|
||
import matplotlib
|
||
import matplotlib.pyplot as plt
|
||
import yaml
|
||
from pylab import xticks, np
|
||
from tqdm import tqdm
|
||
|
||
from modules.vocoders.nsf_hifigan import NsfHifiGAN
|
||
from preprocessing.process_pipeline import get_pitch_parselmouth, get_pitch_crepe
|
||
from utils.hparams import set_hparams, hparams
|
||
|
||
# Pitch-class names indexed by (MIDI pitch % 12); index 0 is C.
head_list = "C C# D D# E F F# G G# A A# B".split()
|
||
|
||
|
||
def compare_pitch(f0_static_dict, pitch_time_temp, trans_key=0):
    """Score the overlap between a pitch histogram and a reference histogram.

    Every pitch of ``pitch_time_temp`` is shifted by ``trans_key`` and looked
    up in ``f0_static_dict`` under the string key ``str(pitch + trans_key)``.
    Each match contributes ``value * reference_value``; pitches without a
    match contribute nothing.

    Args:
        f0_static_dict: Reference histogram keyed by stringified pitch.
        pitch_time_temp: Histogram to score, keyed by numeric pitch.
        trans_key: Offset added to each pitch before the lookup.

    Returns:
        The sum of the products over all matching pitches (0 when none match).
    """
    # Pair each value with the string key it will be looked up under.
    shifted = ((str(pitch + trans_key), seconds) for pitch, seconds in pitch_time_temp.items())
    return sum(seconds * f0_static_dict[key] for key, seconds in shifted if key in f0_static_dict)
|
||
|
||
|
||
def f0_to_pitch(ff):
    """Map a frequency in Hz to the nearest MIDI note number.

    Uses ``69 + 12 * log2(ff / 440)`` (A4 = 440 Hz = MIDI 69), rounded to
    zero decimal places; the result stays a float (e.g. ``69.0``).
    """
    semitones_from_a4 = 12 * np.log2(ff / 440)
    return round(69 + semitones_from_a4, 0)
|
||
|
||
|
||
def pitch_to_name(pitch):
    """Return the scientific note name (e.g. ``"C4"``) of a MIDI pitch number.

    MIDI 60 maps to ``"C4"`` and MIDI 69 to ``"A4"``. ``divmod`` floors both
    the octave and the pitch class, so they stay consistent for negative
    pitches too. The previous ``int(pitch / 12)`` truncated toward zero while
    ``pitch % 12`` floored, which produced e.g. ``"B-1"`` instead of ``"B-2"``
    for -1. Non-negative pitches give the same result as before.

    Args:
        pitch: MIDI pitch number (int or integral float such as ``60.0``).

    Returns:
        Pitch-class name from ``head_list`` followed by the octave number.
    """
    octave, pitch_class = divmod(pitch, 12)
    return f"{head_list[int(pitch_class)]}{int(octave) - 1}"
|
||
|
||
|
||
def get_f0(audio_path, crepe=False):
    """Load an audio file and extract its f0 curve.

    The waveform and mel spectrogram come from ``NsfHifiGAN.wav2spec``. They
    are passed, together with the module-level ``hparams``, to
    ``get_pitch_crepe`` when ``crepe`` is truthy, otherwise to
    ``get_pitch_parselmouth``. Only the f0 part of the extractor's result is
    returned; the coarse pitch is discarded.
    """
    wav, mel = NsfHifiGAN.wav2spec(audio_path)
    extractor = get_pitch_crepe if crepe else get_pitch_parselmouth
    f0, _ = extractor(wav, mel, hparams)
    return f0
|
||
|
||
|
||
def merge_f0_dict(dict_list):
    """Sum several ``{pitch: count}`` dicts key by key into a new dict.

    Compared with the previous ``reduce``-based version, this:

    * returns ``{}`` for an empty iterable instead of raising ``TypeError``;
    * always returns a fresh dict, even for a single input, so callers never
      alias (and later mutate) one of their inputs;
    * orders keys by first appearance rather than by set iteration order.

    Args:
        dict_list: Any iterable of mappings (e.g. a list or ``dict.values()``).

    Returns:
        A new dict mapping each key to the sum of its values across inputs.
    """
    merged = {}
    for counts in dict_list:
        for key, value in counts.items():
            merged[key] = merged.get(key, 0) + value
    return merged
|
||
|
||
|
||
def collect_f0(f0):
    """Count how many f0 values fall on each MIDI pitch.

    Only strictly positive entries are considered (selected with the boolean
    mask ``f0 > 0``; presumably f0 is a numpy array with 0 marking unvoiced
    frames, to be confirmed). Each value is converted with ``f0_to_pitch``.

    Returns:
        Dict mapping pitch to occurrence count, in first-seen order.
    """
    counts = {}
    for freq in f0[f0 > 0]:
        pitch = f0_to_pitch(freq)
        counts[pitch] = counts.get(pitch, 0) + 1
    return counts
|
||
|
||
|
||
def static_f0_time(f0):
    """Build a pitch -> duration histogram from one f0 curve or a dict of them.

    When ``f0`` is a dict, each value is counted with ``collect_f0`` and the
    per-entry counts are summed with ``merge_f0_dict``. Otherwise ``f0`` itself
    is counted. Each count is converted with
    ``count * hparams['hop_size'] / hparams['audio_sample_rate']`` and rounded
    to 2 decimals; this is seconds, assuming one f0 value per hop-sized frame.

    Returns:
        Dict of pitch -> rounded duration, with keys in ascending pitch order.
    """
    if not isinstance(f0, dict):
        pitch_num = collect_f0(f0)
    else:
        pitch_num = merge_f0_dict([collect_f0(curve) for curve in f0.values()])
    return {
        pitch: round(pitch_num[pitch] * hparams['hop_size'] / hparams['audio_sample_rate'], 2)
        for pitch in sorted(pitch_num)
    }
|
||
|
||
|
||
def get_end_file(dir_path, end):
    """Recursively list files under ``dir_path`` whose names end with ``end``.

    Files and directories whose names start with ``'.'`` are skipped, and
    hidden directories are not descended into. Returned paths are built with
    ``os.path.join`` and have every backslash replaced by ``'/'``.
    """
    matches = []
    for root, dirs, files in os.walk(dir_path):
        # Prune hidden directories in place so os.walk does not descend into them.
        dirs[:] = [name for name in dirs if not name.startswith('.')]
        matches.extend(
            os.path.join(root, name).replace("\\", "/")
            for name in files
            if not name.startswith('.') and name.endswith(end)
        )
    return matches
|
||
|
||
|
||
if __name__ == "__main__":
    # Add `f0_static` pitch-range statistics to the training config so the
    # model can use automatic transposition.
    config_path = "../training/config_nsf.yaml"
    # Rebinds the module-level `hparams` that get_f0 / static_f0_time read.
    hparams = set_hparams(config=config_path, exp_name='', infer=True, reset=True, hparams_str='', print_hparams=False)
    f0_dict = {}
    # Collect every wav file under the batch folder (hidden files/dirs skipped).
    wav_paths = get_end_file("../batch", "wav")
    if not wav_paths:
        # Without audio the statistics would be empty; stop before touching the config.
        raise SystemExit("No wav files found under ../batch")
    # Extract f0 with parselmouth.
    with tqdm(total=len(wav_paths)) as p_bar:
        p_bar.set_description('Processing')
        for wav_path in wav_paths:
            f0_dict[wav_path] = get_f0(wav_path, crepe=False)
            p_bar.update(1)
    pitch_time = static_f0_time(f0_dict)
    # Pitch-only copy for plotting: the saved histogram also carries a string
    # "total_time" entry that cannot be placed on the numeric pitch axis.
    pitch_only = dict(pitch_time)
    total_time = round(sum(pitch_time.values()), 2)
    pitch_time["total_time"] = total_time
    print(f"total time: {total_time}s")
    # Back up the config next to the original. Path.with_name keeps this
    # portable; the previous hard-coded "\\" separator produced a wrongly
    # named file outside the config directory on POSIX systems.
    config_file = Path(config_path)
    backup_path = config_file.with_name(f"back_{config_file.name}")
    shutil.copy(config_path, backup_path)
    with open(config_path, encoding='utf-8') as f:
        _hparams = yaml.safe_load(f)
    _hparams['f0_static'] = json.dumps(pitch_time)
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(_hparams, f)
    # Report the actual backup file name (previously hard-coded as back_config.yaml).
    print(f"原config文件已在原目录建立备份:{backup_path.name}")
    print("音域统计已保存至config文件,此模型可使用自动变调功能")
    matplotlib.use('TkAgg')
    plt.title("数据集音域统计", fontproperties='SimHei')
    plt.xlabel("音高", fontproperties='SimHei')
    plt.ylabel("时长(s)", fontproperties='SimHei')
    # One tick per integer MIDI pitch 36..95 (C2..B6), aligned with its label.
    # The old np.linspace(36, 96, 60) spacing was 60/59, so ticks drifted off the pitches.
    tick_pitches = list(range(36, 96))
    xticks(tick_pitches, [pitch_to_name(i) for i in tick_pitches])
    plt.plot(list(pitch_only.keys()), list(pitch_only.values()), color='dodgerblue')
    plt.show()
|