From 55dd869858bc520c7d125fa71814a360f9ec7667 Mon Sep 17 00:00:00 2001 From: quicksand Date: Fri, 14 Jul 2023 19:11:14 +0800 Subject: [PATCH] fix(rmvpe): pass `device` when loading torch model (#301) --- modules/F0Predictor/rmvpe/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/F0Predictor/rmvpe/inference.py b/modules/F0Predictor/rmvpe/inference.py index 6beac87..40b6e94 100644 --- a/modules/F0Predictor/rmvpe/inference.py +++ b/modules/F0Predictor/rmvpe/inference.py @@ -16,7 +16,7 @@ class RMVPE: else: self.device = device model = E2E0(4, 1, (2, 2)) - ckpt = torch.load(model_path) + ckpt = torch.load(model_path, map_location=torch.device(self.device)) model.load_state_dict(ckpt['model']) model = model.to(dtype).to(self.device) model.eval() @@ -54,4 +54,4 @@ class RMVPE: mel = mel_extractor(audio_res, center=True).to(self.dtype) hidden = self.mel2hidden(mel) f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi) - return f0 \ No newline at end of file + return f0