From f808d8e60b83cb54a34070eadd98a8928129b0a5 Mon Sep 17 00:00:00 2001 From: YuriHead Date: Tue, 11 Jul 2023 23:02:44 +0800 Subject: [PATCH] Debug BF16 and RMVPE --- train.py | 2 +- utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 7aec545..8487f17 100644 --- a/train.py +++ b/train.py @@ -140,7 +140,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade if writers is not None: writer, writer_eval = writers - half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16 + half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16 # train_loader.batch_sampler.set_epoch(epoch) global global_step diff --git a/utils.py b/utils.py index 4db63e7..df1b51e 100644 --- a/utils.py +++ b/utils.py @@ -99,7 +99,7 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs): f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate) elif f0_predictor == "rmvpe": from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor - f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float16 ,device=kargs["device"],threshold=kargs["threshold"]) + f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"]) else: raise Exception("Unknown f0 predictor") return f0_predictor_object