Compare commits
5 Commits
1cc9160a1a
...
88b5bd2e54
Author | SHA1 | Date |
---|---|---|
Ναρουσέ·μ·γιουμεμί·Χινακάννα | 88b5bd2e54 | |
Ναρουσέ·μ·γιουμεμί·Χινακάννα | 3160e7d846 | |
Ναρουσέ·μ·γιουμεμί·Χινακάννα | aeff64eb81 | |
Ναρουσέ·μ·γιουμεμί·Χινακάννα | 5e3918982c | |
Ναρουσέ·μ·γιουμεμί·Χινακάννα | 885a1a7cb6 |
|
@ -1,51 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
from torchaudio.models.wav2vec2.utils import import_fairseq_model
|
|
||||||
from fairseq import checkpoint_utils
|
|
||||||
from onnxexport.model_onnx import SynthesizerTrn
|
from onnxexport.model_onnx import SynthesizerTrn
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
def get_hubert_model():
|
def main(NetExport):
|
||||||
vec_path = "hubert/checkpoint_best_legacy_500.pt"
|
|
||||||
print("load model(s) from {}".format(vec_path))
|
|
||||||
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
|
||||||
[vec_path],
|
|
||||||
suffix="",
|
|
||||||
)
|
|
||||||
model = models[0]
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def main(HubertExport, NetExport):
|
|
||||||
path = "SoVits4.0V2"
|
path = "SoVits4.0V2"
|
||||||
|
|
||||||
'''if HubertExport:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
vec_path = "hubert/checkpoint_best_legacy_500.pt"
|
|
||||||
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
|
||||||
[vec_path],
|
|
||||||
suffix="",
|
|
||||||
)
|
|
||||||
original = models[0]
|
|
||||||
original.eval()
|
|
||||||
model = original
|
|
||||||
test_input = torch.rand(1, 1, 16000)
|
|
||||||
model(test_input)
|
|
||||||
torch.onnx.export(model,
|
|
||||||
test_input,
|
|
||||||
"hubert4.0.onnx",
|
|
||||||
export_params=True,
|
|
||||||
opset_version=16,
|
|
||||||
do_constant_folding=True,
|
|
||||||
input_names=['source'],
|
|
||||||
output_names=['embed'],
|
|
||||||
dynamic_axes={
|
|
||||||
'source':
|
|
||||||
{
|
|
||||||
2: "sample_length"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)'''
|
|
||||||
if NetExport:
|
if NetExport:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
|
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
|
||||||
|
@ -55,6 +13,7 @@ def main(HubertExport, NetExport):
|
||||||
_ = SVCVITS.eval().to(device)
|
_ = SVCVITS.eval().to(device)
|
||||||
for i in SVCVITS.parameters():
|
for i in SVCVITS.parameters():
|
||||||
i.requires_grad = False
|
i.requires_grad = False
|
||||||
|
|
||||||
test_hidden_unit = torch.rand(1, 10, 256) # rand
|
test_hidden_unit = torch.rand(1, 10, 256) # rand
|
||||||
test_pitch = torch.rand(1, 10) # rand
|
test_pitch = torch.rand(1, 10) # rand
|
||||||
test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
|
test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
|
||||||
|
@ -65,14 +24,6 @@ def main(HubertExport, NetExport):
|
||||||
test_sid = torch.LongTensor([0])
|
test_sid = torch.LongTensor([0])
|
||||||
input_names = ["c", "f0", "mel2ph", "t_window", "noise", "sid"]
|
input_names = ["c", "f0", "mel2ph", "t_window", "noise", "sid"]
|
||||||
output_names = ["audio", ]
|
output_names = ["audio", ]
|
||||||
SVCVITS.eval()
|
|
||||||
|
|
||||||
SVCVITS(test_hidden_unit.to(device),
|
|
||||||
test_pitch.to(device),
|
|
||||||
test_mel2ph.to(device),
|
|
||||||
test_uv.to(device),
|
|
||||||
test_noise.to(device),
|
|
||||||
test_sid.to(device))
|
|
||||||
|
|
||||||
torch.onnx.export(SVCVITS,
|
torch.onnx.export(SVCVITS,
|
||||||
(
|
(
|
||||||
|
@ -99,4 +50,4 @@ def main(HubertExport, NetExport):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main(False, True)
|
main(True)
|
||||||
|
|
Loading…
Reference in New Issue