diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py index 57f37d5..06f10a4 100644 --- a/vencoder/HubertSoft_Onnx.py +++ b/vencoder/HubertSoft_Onnx.py @@ -22,7 +22,7 @@ class HubertSoft_Onnx(SpeechEncoder): feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) - feats = feats.unsqueeze(0).detach().numpy() + feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file