Merge pull request #333 from svc-develop-team/add-fancy-logger

Add new logger / 添加一个花里胡哨的牛逼logger (add a fancy logger)
This commit is contained in:
YuriHead 2023-07-22 22:56:19 +08:00 committed by GitHub
commit 4961bf9657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 15 deletions

View File

@@ -5,6 +5,7 @@ import re
import wave import wave
from random import shuffle from random import shuffle
from loguru import logger
from tqdm import tqdm from tqdm import tqdm
import diffusion.logger.utils as du import diffusion.logger.utils as du
@@ -46,9 +47,9 @@ if __name__ == "__main__":
if not file.endswith("wav"): if not file.endswith("wav"):
continue continue
if not pattern.match(file): if not pattern.match(file):
print(f"warning文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)") logger.warning(f"文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
if get_wav_duration(file) < 0.3: if get_wav_duration(file) < 0.3:
print("skip too short audio:", file) logger.info("Skip too short audio:" + file)
continue continue
new_wavs.append(file) new_wavs.append(file)
wavs = new_wavs wavs = new_wavs
@@ -59,13 +60,13 @@ if __name__ == "__main__":
shuffle(train) shuffle(train)
shuffle(val) shuffle(val)
print("Writing", args.train_list) logger.info("Writing" + args.train_list)
with open(args.train_list, "w") as f: with open(args.train_list, "w") as f:
for fname in tqdm(train): for fname in tqdm(train):
wavpath = fname wavpath = fname
f.write(wavpath + "\n") f.write(wavpath + "\n")
print("Writing", args.val_list) logger.info("Writing" + args.val_list)
with open(args.val_list, "w") as f: with open(args.val_list, "w") as f:
for fname in tqdm(val): for fname in tqdm(val):
wavpath = fname wavpath = fname
@@ -97,8 +98,8 @@ if __name__ == "__main__":
if args.vol_aug: if args.vol_aug:
config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True
print("Writing configs/config.json") logger.info("Writing to configs/config.json")
with open("configs/config.json", "w") as f: with open("configs/config.json", "w") as f:
json.dump(config_template, f, indent=2) json.dump(config_template, f, indent=2)
print("Writing configs/diffusion.yaml") logger.info("Writing to configs/diffusion.yaml")
du.save_config("configs/diffusion.yaml",d_config_template) du.save_config("configs/diffusion.yaml",d_config_template)

View File

@@ -5,6 +5,7 @@ import random
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from glob import glob from glob import glob
from random import shuffle from random import shuffle
from loguru import logger
import librosa import librosa
import numpy as np import numpy as np
@@ -28,7 +29,6 @@ speech_encoder = hps["model"]["speech_encoder"]
def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None): def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None):
# print(filename)
wav, sr = librosa.load(filename, sr=sampling_rate) wav, sr = librosa.load(filename, sr=sampling_rate)
audio_norm = torch.FloatTensor(wav) audio_norm = torch.FloatTensor(wav)
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
@@ -104,15 +104,15 @@ def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None):
np.save(aug_vol_path,aug_vol.to('cpu').numpy()) np.save(aug_vol_path,aug_vol.to('cpu').numpy())
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None): def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
print("Loading speech encoder for content...") logger.info("Loading speech encoder for content...")
rank = mp.current_process()._identity rank = mp.current_process()._identity
rank = rank[0] if len(rank) > 0 else 0 rank = rank[0] if len(rank) > 0 else 0
if torch.cuda.is_available(): if torch.cuda.is_available():
gpu_id = rank % torch.cuda.device_count() gpu_id = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{gpu_id}") device = torch.device(f"cuda:{gpu_id}")
print(f"Rank {rank} uses device {device}") logger.info(f"Rank {rank} uses device {device}")
hmodel = utils.get_speech_encoder(speech_encoder, device=device) hmodel = utils.get_speech_encoder(speech_encoder, device=device)
print("Loaded speech encoder.") logger.info(f"Loaded speech encoder for rank {rank}")
for filename in tqdm(file_chunk): for filename in tqdm(file_chunk):
process_one(filename, hmodel, f0p, rank, diff, mel_extractor) process_one(filename, hmodel, f0p, rank, diff, mel_extractor)
@@ -144,7 +144,9 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
f0p = args.f0_predictor f0p = args.f0_predictor
print(speech_encoder) print(speech_encoder)
print(f0p) logger.info("Using " + speech_encoder + " SpeechEncoder")
logger.info("Using " + f0p + "f0 extractor")
logger.info("Using diff Mode:")
print(args.use_diff) print(args.use_diff)
if args.use_diff: if args.use_diff:
print("use_diff") print("use_diff")

View File

@@ -25,3 +25,4 @@ langdetect
pyyaml pyyaml
pynvml pynvml
faiss-cpu faiss-cpu
loguru

View File

@@ -29,3 +29,4 @@ langdetect
pyyaml pyyaml
pynvml pynvml
faiss-cpu faiss-cpu
loguru

View File

@@ -8,7 +8,7 @@ from diffusion.logger import utils
from diffusion.solver import train from diffusion.solver import train
from diffusion.unit2mel import Unit2Mel from diffusion.unit2mel import Unit2Mel
from diffusion.vocoder import Vocoder from diffusion.vocoder import Vocoder
from loguru import logger
def parse_args(args=None, namespace=None): def parse_args(args=None, namespace=None):
"""Parse command-line arguments.""" """Parse command-line arguments."""
@@ -28,8 +28,8 @@ if __name__ == '__main__':
# load config # load config
args = utils.load_config(cmd.config) args = utils.load_config(cmd.config)
print(' > config:', cmd.config) logger.info(' > config:'+ cmd.config)
print(' > exp:', args.env.expdir) logger.info(' > exp:'+ args.env.expdir)
# load vocoder # load vocoder
vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device) vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
@@ -47,7 +47,7 @@ if __name__ == '__main__':
args.model.k_step_max args.model.k_step_max
) )
print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}') logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
# load parameters # load parameters
optimizer = torch.optim.AdamW(model.parameters()) optimizer = torch.optim.AdamW(model.parameters())