This commit is contained in:
ylzz1997 2023-05-22 23:54:59 +08:00
parent 28dd4fa032
commit a67bc3a84b
1 changed files with 4 additions and 4 deletions

View File

@@ -123,10 +123,6 @@ def run(rank, n_gpus, hps):
scaler = GradScaler(enabled=hps.train.fp16_run)
for epoch in range(epoch_str, hps.train.epochs + 1):
# update learning rate
if epoch > 1:
scheduler_g.step()
scheduler_d.step()
# set up warm-up learning rate
if epoch <= warmup_epoch:
for param_group in optim_g.param_groups:
@@ -140,6 +136,10 @@ def run(rank, n_gpus, hps):
else:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
[train_loader, None], None, None)
# update learning rate
if epoch > 1:
scheduler_g.step()
scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):