Debug BF16 and RMVPE

This commit is contained in:
YuriHead 2023-07-11 23:02:44 +08:00
parent 36787124f4
commit f808d8e60b
2 changed files with 2 additions and 2 deletions

View File

@ -140,7 +140,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if writers is not None:
writer, writer_eval = writers
half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16
half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16
# train_loader.batch_sampler.set_epoch(epoch)
global global_step

View File

@ -99,7 +99,7 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
elif f0_predictor == "rmvpe":
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float16 ,device=kargs["device"],threshold=kargs["threshold"])
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
else:
raise Exception("Unknown f0 predictor")
return f0_predictor_object