Merge pull request #2 from mlbv/4.1-Stable

Update solver.py
This commit is contained in:
CN_ChiTu 2023-08-05 05:14:16 +08:00 committed by GitHub
commit b7d6905b80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -101,6 +101,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
# run
num_batches = len(loader_train)
start_epoch = initial_global_step // num_batches
model.train()
saver.log_info('======= start training =======')
scaler = GradScaler()
@ -113,7 +114,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
else:
raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
for epoch in range(args.train.epochs):
for epoch in range(start_epoch, args.train.epochs):
for batch_idx, data in enumerate(loader_train):
saver.global_step_increment()
optimizer.zero_grad()