Update onnx_export.py

This commit is contained in:
Ναρουσέ·μ·γιουμεμί·Χινακάννα 2023-03-15 15:45:18 +08:00 committed by GitHub
parent 885a1a7cb6
commit 5e3918982c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 22 deletions

View File

@ -1,21 +1,7 @@
import torch
from torchaudio.models.wav2vec2.utils import import_fairseq_model
from fairseq import checkpoint_utils
from onnxexport.model_onnx import SynthesizerTrn
import utils
def get_hubert_model():
vec_path = "hubert/checkpoint_best_legacy_500.pt"
print("load model(s) from {}".format(vec_path))
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[vec_path],
suffix="",
)
model = models[0]
model.eval()
return model
def main(NetExport):
path = "SoVits4.0V2"
if NetExport:
@ -27,6 +13,7 @@ def main(NetExport):
_ = SVCVITS.eval().to(device)
for i in SVCVITS.parameters():
i.requires_grad = False
test_hidden_unit = torch.rand(1, 10, 256) # rand
test_pitch = torch.rand(1, 10) # rand
test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
@ -37,14 +24,6 @@ def main(NetExport):
test_sid = torch.LongTensor([0])
input_names = ["c", "f0", "mel2ph", "t_window", "noise", "sid"]
output_names = ["audio", ]
SVCVITS.eval()
SVCVITS(test_hidden_unit.to(device),
test_pitch.to(device),
test_mel2ph.to(device),
test_uv.to(device),
test_noise.to(device),
test_sid.to(device))
torch.onnx.export(SVCVITS,
(