diff --git a/onnx_export.py b/onnx_export.py index 5a1d335..643b063 100644 --- a/onnx_export.py +++ b/onnx_export.py @@ -16,36 +16,8 @@ def get_hubert_model(): return model -def main(HubertExport, NetExport): +def main(NetExport): path = "SoVits4.0V2" - - '''if HubertExport: - device = torch.device("cpu") - vec_path = "hubert/checkpoint_best_legacy_500.pt" - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [vec_path], - suffix="", - ) - original = models[0] - original.eval() - model = original - test_input = torch.rand(1, 1, 16000) - model(test_input) - torch.onnx.export(model, - test_input, - "hubert4.0.onnx", - export_params=True, - opset_version=16, - do_constant_folding=True, - input_names=['source'], - output_names=['embed'], - dynamic_axes={ - 'source': - { - 2: "sample_length" - }, - } - )''' if NetExport: device = torch.device("cpu") hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") @@ -99,4 +71,4 @@ def main(HubertExport, NetExport): if __name__ == '__main__': - main(False, True) + main(True)