Update solver.py

Fixed the incorrect calculation of the epoch number when diffusion training was resumed from an existing model checkpoint.
This commit is contained in:
mlbv 2023-08-04 06:06:54 +08:00 committed by GitHub
parent 64591bd664
commit c63fd1f40c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -101,6 +101,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
# run
num_batches = len(loader_train)
start_epoch = initial_global_step // num_batches
model.train()
saver.log_info('======= start training =======')
scaler = GradScaler()
@ -113,7 +114,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
else:
raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
for epoch in range(args.train.epochs):
for epoch in range(start_epoch, args.train.epochs):
for batch_idx, data in enumerate(loader_train):
saver.global_step_increment()
optimizer.zero_grad()