Update solver.py
Fixed the incorrect calculation of the epoch number when diffusion training was resumed from an existing model checkpoint.
This commit is contained in:
parent
64591bd664
commit
c63fd1f40c
|
@@ -101,6 +101,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
|
|||
|
||||
# run
|
||||
num_batches = len(loader_train)
|
||||
+    start_epoch = initial_global_step // num_batches
|
||||
model.train()
|
||||
saver.log_info('======= start training =======')
|
||||
scaler = GradScaler()
|
||||
|
@@ -113,7 +114,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
|
|||
else:
|
||||
raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
|
||||
saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
|
||||
-    for epoch in range(args.train.epochs):
|
||||
+    for epoch in range(start_epoch, args.train.epochs):
|
||||
for batch_idx, data in enumerate(loader_train):
|
||||
saver.global_step_increment()
|
||||
optimizer.zero_grad()
|
||||
|
|
Loading…
Reference in New Issue