import argparse

import torch
from torch.optim import lr_scheduler

from diffusion.data_loaders import get_data_loaders
from diffusion.logger import utils
from diffusion.solver import train
from diffusion.unit2mel import Unit2Mel
from diffusion.vocoder import Vocoder


def parse_args(args=None, namespace=None):
    """Read the command-line options of the training script.

    The only option is the mandatory ``-c`` / ``--config`` path.

    Args:
        args: Forwarded unchanged to ``ArgumentParser.parse_args``.
        namespace: Forwarded unchanged to ``ArgumentParser.parse_args``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args``; the config
        path is available as its ``config`` attribute.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="path to the config file",
    )
    return cli.parse_args(args=args, namespace=namespace)


if __name__ == '__main__':
    # Command line -> config object.
    cmd = parse_args()
    args = utils.load_config(cmd.config)
    print(' > config:', cmd.config)
    print(' > exp:', args.env.expdir)

    # The vocoder is built first because the model takes its `dimension`.
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)

    # Model hyper-parameters all come from the loaded config.
    model_cfg = args.model
    model = Unit2Mel(
        args.data.encoder_out_channels,
        model_cfg.n_spk,
        model_cfg.use_pitch_aug,
        vocoder.dimension,
        model_cfg.n_layers,
        model_cfg.n_chans,
        model_cfg.n_hidden,
        model_cfg.timesteps,
        model_cfg.k_step_max,
    )
    print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')

    # utils.load_model hands back the starting global step together with the
    # model and optimizer (presumably restored from a checkpoint under expdir
    # when one exists -- confirm in diffusion.logger.utils).
    optimizer = torch.optim.AdamW(model.parameters())
    initial_global_step, model, optimizer = utils.load_model(
        args.env.expdir, model, optimizer, device=args.device
    )

    # Re-derive the learning rate from the starting step: gamma is applied
    # once per whole `decay_step` contained in (initial_global_step - 2),
    # clamped so it is never applied a negative number of times.
    # NOTE(review): the reason for the -2 offset is not visible here; the same
    # value is handed to StepLR as `last_epoch` below -- confirm against
    # diffusion.solver.train before changing it.
    train_cfg = args.train
    resume_epoch = initial_global_step - 2
    decays_done = max(resume_epoch // train_cfg.decay_step, 0)
    resumed_lr = train_cfg.lr * (train_cfg.gamma ** decays_done)
    for group in optimizer.param_groups:
        # 'initial_lr' must be present because StepLR is created with a
        # `last_epoch` other than -1.
        group.update(
            initial_lr=train_cfg.lr,
            lr=resumed_lr,
            weight_decay=train_cfg.weight_decay,
        )
    scheduler = lr_scheduler.StepLR(
        optimizer,
        step_size=train_cfg.decay_step,
        gamma=train_cfg.gamma,
        last_epoch=resume_epoch,
    )

    # Select the GPU (only for the exact device string 'cuda') and move the
    # model plus every tensor held in the optimizer state onto the device.
    device = args.device
    if device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)
    model.to(device)
    for slot in optimizer.state.values():
        slot.update(
            {key: val.to(device) for key, val in slot.items() if torch.is_tensor(val)}
        )

    # Training and validation data loaders.
    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    # Hand everything over to the training loop.
    train(
        args,
        initial_global_step,
        model,
        optimizer,
        scheduler,
        vocoder,
        loader_train,
        loader_valid,
    )