diff --git a/onnx_export.py b/onnx_export.py index 53b278d..7914d12 100644 --- a/onnx_export.py +++ b/onnx_export.py @@ -1,51 +1,9 @@ import torch -from torchaudio.models.wav2vec2.utils import import_fairseq_model -from fairseq import checkpoint_utils from onnxexport.model_onnx import SynthesizerTrn import utils -def get_hubert_model(): - vec_path = "hubert/checkpoint_best_legacy_500.pt" - print("load model(s) from {}".format(vec_path)) - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [vec_path], - suffix="", - ) - model = models[0] - model.eval() - return model - - -def main(HubertExport, NetExport): +def main(NetExport): path = "SoVits4.0" - - '''if HubertExport: - device = torch.device("cpu") - vec_path = "hubert/checkpoint_best_legacy_500.pt" - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [vec_path], - suffix="", - ) - original = models[0] - original.eval() - model = original - test_input = torch.rand(1, 1, 16000) - model(test_input) - torch.onnx.export(model, - test_input, - "hubert4.0.onnx", - export_params=True, - opset_version=16, - do_constant_folding=True, - input_names=['source'], - output_names=['embed'], - dynamic_axes={ - 'source': - { - 2: "sample_length" - }, - } - )''' if NetExport: device = torch.device("cpu") hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") @@ -57,6 +15,7 @@ def main(HubertExport, NetExport): _ = SVCVITS.eval().to(device) for i in SVCVITS.parameters(): i.requires_grad = False + test_hidden_unit = torch.rand(1, 10, 256) test_pitch = torch.rand(1, 10) test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0) @@ -65,7 +24,7 @@ def main(HubertExport, NetExport): test_sid = torch.LongTensor([0]) input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"] output_names = ["audio", ] - SVCVITS.eval() + torch.onnx.export(SVCVITS, ( test_hidden_unit.to(device), @@ -91,4 +50,4 @@ def main(HubertExport, NetExport): if __name__ == '__main__': - main(False, True) + main(True)