Merge pull request #139 from Jared-02/4.0
Revise the warm-up setup process
This commit is contained in:
commit
ddf41b73d6
14
train.py
14
train.py
|
@ -114,20 +114,30 @@ def run(rank, n_gpus, hps):
|
||||||
epoch_str = 1
|
epoch_str = 1
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
|
warmup_epoch = hps.train.warmup_epochs
|
||||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=hps.train.fp16_run)
|
scaler = GradScaler(enabled=hps.train.fp16_run)
|
||||||
|
|
||||||
for epoch in range(epoch_str, hps.train.epochs + 1):
|
for epoch in range(epoch_str, hps.train.epochs + 1):
|
||||||
|
# update learning rate
|
||||||
|
if epoch > 1:
|
||||||
|
scheduler_g.step()
|
||||||
|
scheduler_d.step()
|
||||||
|
# set up warm-up learning rate
|
||||||
|
if epoch <= warmup_epoch:
|
||||||
|
for param_group in optim_g.param_groups:
|
||||||
|
param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
|
||||||
|
for param_group in optim_d.param_groups:
|
||||||
|
param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
|
||||||
|
# training
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
||||||
[train_loader, eval_loader], logger, [writer, writer_eval])
|
[train_loader, eval_loader], logger, [writer, writer_eval])
|
||||||
else:
|
else:
|
||||||
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
||||||
[train_loader, None], None, None)
|
[train_loader, None], None, None)
|
||||||
scheduler_g.step()
|
|
||||||
scheduler_d.step()
|
|
||||||
|
|
||||||
|
|
||||||
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
||||||
|
|
Loading…
Reference in New Issue