diff --git a/train.py b/train.py index b9132a2..7aec545 100644 --- a/train.py +++ b/train.py @@ -99,7 +99,7 @@ def run(rank, n_gpus, hps): name=utils.latest_checkpoint_path(hps.model_dir, "D_*.pth") global_step=int(name[name.rfind("_")+1:name.rfind(".")])+1 #global_step = (epoch_str - 1) * len(train_loader) - except AssertionError: + except Exception: print("load old checkpoint failed...") epoch_str = 1 global_step = 0