Merge pull request #333 from svc-develop-team/add-fancy-logger

Add new logger / add a flashy, fancy logger
Commit 4961bf9657 by YuriHead, 2023-07-22 22:56:19 +08:00 (committed by GitHub)
5 changed files with 20 additions and 15 deletions
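The change itself is mechanical: each print() call becomes a call on loguru's ready-made logger object, which needs no handler or formatter setup. A minimal sketch of the pattern, using placeholder file names rather than code from this commit:

# Illustrative sketch of the print -> loguru pattern applied in this commit.
# loguru exposes a pre-configured `logger`; no logging.basicConfig() is needed.
from loguru import logger

wav_files = ["a.wav", "b.txt"]                           # placeholder data, not from the repo
for file in wav_files:
    if not file.endswith("wav"):
        logger.warning(f"Skipping non-wav file: {file}")  # was: print(...)
        continue
    logger.info("Processing " + file)                     # was: print("Processing", file)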

View File

@@ -5,6 +5,7 @@ import re
import wave
from random import shuffle
from loguru import logger
from tqdm import tqdm
import diffusion.logger.utils as du
@@ -46,9 +47,9 @@ if __name__ == "__main__":
if not file.endswith("wav"):
continue
if not pattern.match(file):
print(f"warning文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
logger.warning(f"文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
if get_wav_duration(file) < 0.3:
print("skip too short audio:", file)
logger.info("Skip too short audio:" + file)
continue
new_wavs.append(file)
wavs = new_wavs
@@ -59,13 +60,13 @@ if __name__ == "__main__":
shuffle(train)
shuffle(val)
print("Writing", args.train_list)
logger.info("Writing" + args.train_list)
with open(args.train_list, "w") as f:
for fname in tqdm(train):
wavpath = fname
f.write(wavpath + "\n")
print("Writing", args.val_list)
logger.info("Writing" + args.val_list)
with open(args.val_list, "w") as f:
for fname in tqdm(val):
wavpath = fname
@@ -97,8 +98,8 @@ if __name__ == "__main__":
if args.vol_aug:
config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True
print("Writing configs/config.json")
logger.info("Writing to configs/config.json")
with open("configs/config.json", "w") as f:
json.dump(config_template, f, indent=2)
print("Writing configs/diffusion.yaml")
logger.info("Writing to configs/diffusion.yaml")
du.save_config("configs/diffusion.yaml",d_config_template)

View File

@@ -5,6 +5,7 @@ import random
from concurrent.futures import ProcessPoolExecutor
from glob import glob
from random import shuffle
from loguru import logger
import librosa
import numpy as np
@@ -28,7 +29,6 @@ speech_encoder = hps["model"]["speech_encoder"]
def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None):
# print(filename)
wav, sr = librosa.load(filename, sr=sampling_rate)
audio_norm = torch.FloatTensor(wav)
audio_norm = audio_norm.unsqueeze(0)
@@ -104,15 +104,15 @@ def process_one(filename, hmodel,f0p,rank,diff=False,mel_extractor=None):
np.save(aug_vol_path,aug_vol.to('cpu').numpy())
def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
print("Loading speech encoder for content...")
logger.info("Loading speech encoder for content...")
rank = mp.current_process()._identity
rank = rank[0] if len(rank) > 0 else 0
if torch.cuda.is_available():
gpu_id = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{gpu_id}")
print(f"Rank {rank} uses device {device}")
logger.info(f"Rank {rank} uses device {device}")
hmodel = utils.get_speech_encoder(speech_encoder, device=device)
print("Loaded speech encoder.")
logger.info(f"Loaded speech encoder for rank {rank}")
for filename in tqdm(file_chunk):
process_one(filename, hmodel, f0p, rank, diff, mel_extractor)
@@ -144,7 +144,9 @@ if __name__ == "__main__":
args = parser.parse_args()
f0p = args.f0_predictor
print(speech_encoder)
print(f0p)
logger.info("Using " + speech_encoder + " SpeechEncoder")
logger.info("Using " + f0p + "f0 extractor")
logger.info("Using diff Mode:")
print(args.use_diff)
if args.use_diff:
print("use_diff")

View File

@@ -25,3 +25,4 @@ langdetect
pyyaml
pynvml
faiss-cpu
loguru

View File

@@ -29,3 +29,4 @@ langdetect
pyyaml
pynvml
faiss-cpu
loguru

View File

@@ -8,7 +8,7 @@ from diffusion.logger import utils
from diffusion.solver import train
from diffusion.unit2mel import Unit2Mel
from diffusion.vocoder import Vocoder
from loguru import logger
def parse_args(args=None, namespace=None):
"""Parse command-line arguments."""
@@ -28,8 +28,8 @@ if __name__ == '__main__':
# load config
args = utils.load_config(cmd.config)
print(' > config:', cmd.config)
print(' > exp:', args.env.expdir)
logger.info(' > config:'+ cmd.config)
logger.info(' > exp:'+ args.env.expdir)
# load vocoder
vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
@@ -47,7 +47,7 @@ if __name__ == '__main__':
args.model.k_step_max
)
print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
# load parameters
optimizer = torch.optim.AdamW(model.parameters())