Update onnx_export.py
This commit is contained in:
parent
1cc9160a1a
commit
885a1a7cb6
|
@ -16,36 +16,8 @@ def get_hubert_model():
|
|||
return model
|
||||
|
||||
|
||||
def main(HubertExport, NetExport):
|
||||
def main(NetExport):
|
||||
path = "SoVits4.0V2"
|
||||
|
||||
'''if HubertExport:
|
||||
device = torch.device("cpu")
|
||||
vec_path = "hubert/checkpoint_best_legacy_500.pt"
|
||||
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[vec_path],
|
||||
suffix="",
|
||||
)
|
||||
original = models[0]
|
||||
original.eval()
|
||||
model = original
|
||||
test_input = torch.rand(1, 1, 16000)
|
||||
model(test_input)
|
||||
torch.onnx.export(model,
|
||||
test_input,
|
||||
"hubert4.0.onnx",
|
||||
export_params=True,
|
||||
opset_version=16,
|
||||
do_constant_folding=True,
|
||||
input_names=['source'],
|
||||
output_names=['embed'],
|
||||
dynamic_axes={
|
||||
'source':
|
||||
{
|
||||
2: "sample_length"
|
||||
},
|
||||
}
|
||||
)'''
|
||||
if NetExport:
|
||||
device = torch.device("cpu")
|
||||
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
|
||||
|
@ -99,4 +71,4 @@ def main(HubertExport, NetExport):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(False, True)
|
||||
main(True)
|
||||
|
|
Loading…
Reference in New Issue