diff --git a/vencoder/ContentVec768L9_Onnx.py b/vencoder/ContentVec768L9_Onnx.py index 6cc24e0..7cdac4c 100644 --- a/vencoder/ContentVec768L9_Onnx.py +++ b/vencoder/ContentVec768L9_Onnx.py @@ -22,7 +22,7 @@ class ContentVec768L9_Onnx(SpeechEncoder): feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) - feats = feats.unsqueeze(0).detach().numpy() + feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file