diff --git a/train.py b/train.py index 16c71c4..e0be902 100644 --- a/train.py +++ b/train.py @@ -123,10 +123,6 @@ def run(rank, n_gpus, hps): scaler = GradScaler(enabled=hps.train.fp16_run) for epoch in range(epoch_str, hps.train.epochs + 1): - # update learning rate - if epoch > 1: - scheduler_g.step() - scheduler_d.step() # set up warm-up learning rate if epoch <= warmup_epoch: for param_group in optim_g.param_groups: @@ -140,6 +136,10 @@ def run(rank, n_gpus, hps): else: train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) + # update learning rate + if epoch > 1: + scheduler_g.step() + scheduler_d.step() def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):