fix(rmvpe): pass `device` when loading torch model (#301)

This commit is contained in:
quicksand 2023-07-14 19:11:14 +08:00 committed by GitHub
parent e7b478596a
commit 55dd869858
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -16,7 +16,7 @@ class RMVPE:
else:
self.device = device
model = E2E0(4, 1, (2, 2))
ckpt = torch.load(model_path)
ckpt = torch.load(model_path, map_location=torch.device(self.device))
model.load_state_dict(ckpt['model'])
model = model.to(dtype).to(self.device)
model.eval()