From 55dd086fc66162e0853ddc0cccff39236324abd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Wed, 15 Mar 2023 15:46:11 +0800 Subject: [PATCH] Update onnx_export.py --- onnx_export.py | 49 ++++--------------------------------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/onnx_export.py b/onnx_export.py index 53b278d..7914d12 100644 --- a/onnx_export.py +++ b/onnx_export.py @@ -1,51 +1,9 @@ import torch -from torchaudio.models.wav2vec2.utils import import_fairseq_model -from fairseq import checkpoint_utils from onnxexport.model_onnx import SynthesizerTrn import utils -def get_hubert_model(): - vec_path = "hubert/checkpoint_best_legacy_500.pt" - print("load model(s) from {}".format(vec_path)) - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [vec_path], - suffix="", - ) - model = models[0] - model.eval() - return model - - -def main(HubertExport, NetExport): +def main(NetExport): path = "SoVits4.0" - - '''if HubertExport: - device = torch.device("cpu") - vec_path = "hubert/checkpoint_best_legacy_500.pt" - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [vec_path], - suffix="", - ) - original = models[0] - original.eval() - model = original - test_input = torch.rand(1, 1, 16000) - model(test_input) - torch.onnx.export(model, - test_input, - "hubert4.0.onnx", - export_params=True, - opset_version=16, - do_constant_folding=True, - input_names=['source'], - output_names=['embed'], - dynamic_axes={ - 'source': - { - 2: "sample_length" - }, - } - )''' if NetExport: device = torch.device("cpu") hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") @@ -57,6 +15,7 @@ def main(HubertExport, NetExport): _ = SVCVITS.eval().to(device) for i in SVCVITS.parameters(): i.requires_grad = False + test_hidden_unit = torch.rand(1, 10, 256) test_pitch = torch.rand(1, 10) test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0) @@ -65,7 +24,7 @@ def main(HubertExport, NetExport): test_sid = torch.LongTensor([0]) input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"] output_names = ["audio", ] - SVCVITS.eval() + torch.onnx.export(SVCVITS, ( test_hidden_unit.to(device), @@ -91,4 +50,4 @@ def main(HubertExport, NetExport): if __name__ == '__main__': - main(False, True) + main(True)