diff --git a/train_diff.py b/train_diff.py index adf5fb3..4cdc0eb 100644 --- a/train_diff.py +++ b/train_diff.py @@ -8,7 +8,7 @@ from diffusion.logger import utils from diffusion.solver import train from diffusion.unit2mel import Unit2Mel from diffusion.vocoder import Vocoder - +from loguru import logger def parse_args(args=None, namespace=None): """Parse command-line arguments.""" @@ -28,8 +28,8 @@ if __name__ == '__main__': # load config args = utils.load_config(cmd.config) - print(' > config:', cmd.config) - print(' > exp:', args.env.expdir) + logger.info(' > config:'+ cmd.config) + logger.info(' > exp:'+ args.env.expdir) # load vocoder vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device) @@ -47,7 +47,7 @@ if __name__ == '__main__': args.model.k_step_max ) - print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}') + logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}') # load parameters optimizer = torch.optim.AdamW(model.parameters())