fix(rmvpe): pass `device` when loading torch model (#301)
This commit is contained in:
parent
e7b478596a
commit
55dd869858
|
@ -16,7 +16,7 @@ class RMVPE:
|
|||
else:
|
||||
self.device = device
|
||||
model = E2E0(4, 1, (2, 2))
|
||||
ckpt = torch.load(model_path)
|
||||
ckpt = torch.load(model_path, map_location=torch.device(self.device))
|
||||
model.load_state_dict(ckpt['model'])
|
||||
model = model.to(dtype).to(self.device)
|
||||
model.eval()
|
||||
|
@ -54,4 +54,4 @@ class RMVPE:
|
|||
mel = mel_extractor(audio_res, center=True).to(self.dtype)
|
||||
hidden = self.mel2hidden(mel)
|
||||
f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi)
|
||||
return f0
|
||||
return f0
|
||||
|
|
Loading…
Reference in New Issue