# diff-svc/modules/hubert/hubert_onnx.py (20 lines, 633 B, Python)

import time
import torch
import torchaudio
def get_onnx_units(hbt_soft, raw_wav_path):
    """Extract HuBERT units from an audio file using an ONNX inference session.

    The audio is loaded with ``torchaudio.load``, down-mixed to mono when it
    has more than one channel, resampled to 16 kHz, given a leading batch
    dimension, and fed to ``hbt_soft`` under the input name ``"wav"``. The
    wall-clock duration of the session call is printed.

    Args:
        hbt_soft: Object exposing ``run(output_names=..., input_feed=...)``;
            presumably an ``onnxruntime.InferenceSession`` for the exported
            soft-HuBERT model (TODO confirm against the caller).
        raw_wav_path: Anything ``torchaudio.load`` accepts (e.g. a file path).

    Returns:
        The first element of ``hbt_soft.run(...)``, i.e. the ``"units"``
        output. Its exact shape/dtype depends on the exported model
        (TODO confirm).
    """
    source, sr = torchaudio.load(raw_wav_path)
    # torchaudio.load returns (channels, frames): the channel count is dim 0,
    # not dim 1 (which is the number of frames).
    if source.dim() == 2 and source.shape[0] >= 2:
        source = source.mean(dim=0, keepdim=True)
    # Down-mix first so only a single channel has to be resampled; both steps
    # are linear, so the order only affects floating-point rounding.
    source = torchaudio.functional.resample(source, sr, 16000)
    # Add a leading batch dimension before handing the array to the model.
    source = source.unsqueeze(0)

    # Run inference with ONNX Runtime, timing only the session call.
    # perf_counter is monotonic and intended for measuring durations.
    start = time.perf_counter()
    units = hbt_soft.run(output_names=["units"],
                         input_feed={"wav": source.numpy()})[0]
    use_time = time.perf_counter() - start
    print("hubert_onnx_session.run time:{}".format(use_time))
    return units