Update train_diff.py

This commit is contained in:
Stardust·减 2023-07-22 22:04:44 +08:00 committed by GitHub
parent ff07b3d9e6
commit 12a3ba587e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from diffusion.logger import utils
from diffusion.solver import train from diffusion.solver import train
from diffusion.unit2mel import Unit2Mel from diffusion.unit2mel import Unit2Mel
from diffusion.vocoder import Vocoder from diffusion.vocoder import Vocoder
from loguru import logger
def parse_args(args=None, namespace=None): def parse_args(args=None, namespace=None):
"""Parse command-line arguments.""" """Parse command-line arguments."""
@ -28,8 +28,8 @@ if __name__ == '__main__':
# load config # load config
args = utils.load_config(cmd.config) args = utils.load_config(cmd.config)
print(' > config:', cmd.config) logger.info(' > config:'+ cmd.config)
print(' > exp:', args.env.expdir) logger.info(' > exp:'+ args.env.expdir)
# load vocoder # load vocoder
vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device) vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
@ -47,7 +47,7 @@ if __name__ == '__main__':
args.model.k_step_max args.model.k_step_max
) )
print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}') logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
# load parameters # load parameters
optimizer = torch.optim.AdamW(model.parameters()) optimizer = torch.optim.AdamW(model.parameters())