Update onnx_export.py

This commit is contained in:
Ναρουσέ·μ·γιουμεμί·Χινακάννα 2023-03-15 15:44:26 +08:00 committed by GitHub
parent 1cc9160a1a
commit 885a1a7cb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 30 deletions

View File

@ -16,36 +16,8 @@ def get_hubert_model():
return model
def main(HubertExport, NetExport):
def main(NetExport):
path = "SoVits4.0V2"
'''if HubertExport:
device = torch.device("cpu")
vec_path = "hubert/checkpoint_best_legacy_500.pt"
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
[vec_path],
suffix="",
)
original = models[0]
original.eval()
model = original
test_input = torch.rand(1, 1, 16000)
model(test_input)
torch.onnx.export(model,
test_input,
"hubert4.0.onnx",
export_params=True,
opset_version=16,
do_constant_folding=True,
input_names=['source'],
output_names=['embed'],
dynamic_axes={
'source':
{
2: "sample_length"
},
}
)'''
if NetExport:
device = torch.device("cpu")
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
@ -99,4 +71,4 @@ def main(HubertExport, NetExport):
if __name__ == '__main__':
main(False, True)
main(True)