From 8cff6c5cb0f12b4fec46f0da2bd5f9c1406ff74c Mon Sep 17 00:00:00 2001
From: Jared <78630856+Jared-02@users.noreply.github.com>
Date: Tue, 11 Apr 2023 11:49:41 +0800
Subject: [PATCH] Revise the warm-up setup process

---
 train.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 9f6e743..dda30d9 100644
--- a/train.py
+++ b/train.py
@@ -114,20 +114,30 @@ def run(rank, n_gpus, hps):
     epoch_str = 1
     global_step = 0
 
+  warmup_epoch = hps.train.warmup_epochs
   scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
   scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
 
   scaler = GradScaler(enabled=hps.train.fp16_run)
 
   for epoch in range(epoch_str, hps.train.epochs + 1):
+    # update learning rate
+    if epoch > 1:
+      scheduler_g.step()
+      scheduler_d.step()
+    # set up warm-up learning rate
+    if epoch <= warmup_epoch:
+      for param_group in optim_g.param_groups:
+        param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
+      for param_group in optim_d.param_groups:
+        param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
+    # training
     if rank == 0:
       train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
     else:
       train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
-    scheduler_g.step()
-    scheduler_d.step()
 
 
 def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
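
Note on the change: the scheduler_g.step() / scheduler_d.step() calls are moved to the top of the loop behind the `if epoch > 1` guard, which preserves the original decay timing while letting the warm-up override take effect before each epoch is trained. Below is a minimal standalone sketch (not part of the patch) of the schedule this produces for the generator side; the values for learning_rate, lr_decay and warmup_epochs are assumed for illustration only. In the patch they come from hps.train.*, so a warmup_epochs entry is expected in the config's train section.

import torch

# Assumed example values; in the patch these are read from hps.train.*
learning_rate, lr_decay, warmup_epochs = 2e-4, 0.999875, 5
epoch_str = 1

# A dummy parameter so the optimizer/scheduler pair can be constructed
param = torch.nn.Parameter(torch.zeros(1))
optim_g = torch.optim.AdamW([param], lr=learning_rate)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
    optim_g, gamma=lr_decay, last_epoch=epoch_str - 2)

for epoch in range(epoch_str, 11):
  # step before the epoch (skipping the first), mirroring the `if epoch > 1` guard
  if epoch > 1:
    scheduler_g.step()
  # linear warm-up overrides the decayed LR for the first `warmup_epochs` epochs
  if epoch <= warmup_epochs:
    for param_group in optim_g.param_groups:
      param_group['lr'] = learning_rate / warmup_epochs * epoch
  print(f"epoch {epoch}: lr = {optim_g.param_groups[0]['lr']:.6e}")

Because ExponentialLR.step() scales the optimizer's current lr by gamma, the exponential decay continues from the fully warmed-up rate once epoch exceeds warmup_epochs.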