89 lines
3.3 KiB
Python
89 lines
3.3 KiB
Python
|
import os
|
|||
|
import re
|
|||
|
import shutil
|
|||
|
|
|||
|
import torch
|
|||
|
|
|||
|
|
|||
|
def get_model_folder(path):
|
|||
|
model_lists = os.listdir(path)
|
|||
|
res_list = []
|
|||
|
filter_list = ["hubert", "xiaoma_pe", "hifigan", "checkpoints", ".yaml", ".zip"]
|
|||
|
for path in model_lists:
|
|||
|
if not any(word if word in path else False for word in filter_list):
|
|||
|
res_list.append(path)
|
|||
|
return res_list
|
|||
|
|
|||
|
|
|||
|
def scan(path):
|
|||
|
model_str = ""
|
|||
|
path_lists = get_model_folder(path)
|
|||
|
for i in range(0, len(path_lists)):
|
|||
|
if re.search(u'[\u4e00-\u9fa5]', path_lists[i]):
|
|||
|
print(f'{path_lists[i]}:中文路径!此项跳过')
|
|||
|
continue
|
|||
|
model_str += f"{i}:{path_lists[i]} "
|
|||
|
if (i + 1) % 5 == 0:
|
|||
|
print(f"{model_str}")
|
|||
|
model_str = ""
|
|||
|
if len(path_lists) % 5 != 0:
|
|||
|
print(model_str)
|
|||
|
return path_lists
|
|||
|
|
|||
|
|
|||
|
def simplify_pth(model_name, proj_name, output_path):
|
|||
|
model_path = f'./checkpoints/{proj_name}'
|
|||
|
checkpoint_dict = torch.load(f'{model_path}/{model_name}')
|
|||
|
torch.save({'epoch': checkpoint_dict['epoch'],
|
|||
|
'state_dict': checkpoint_dict['state_dict'],
|
|||
|
'global_step': None,
|
|||
|
'checkpoint_callback_best': None,
|
|||
|
'optimizer_states': None,
|
|||
|
'lr_schedulers': None
|
|||
|
}, output_path)
|
|||
|
|
|||
|
|
|||
|
def mkdir(paths: list):
|
|||
|
for path in paths:
|
|||
|
if not os.path.exists(path):
|
|||
|
os.mkdir(path)
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
if os.path.exists("./checkpoints"):
|
|||
|
path_list = scan("./checkpoints")
|
|||
|
else:
|
|||
|
print("请检查checkpoints文件夹是否存在")
|
|||
|
exit()
|
|||
|
a = input("\r\n请输入序号并回车:")
|
|||
|
project_name = path_list[int(a)]
|
|||
|
path_list = scan(f"./checkpoints/{path_list[int(a)]}")
|
|||
|
b = input("\r\n请输入序号并回车:")
|
|||
|
pth_name = path_list[int(b)]
|
|||
|
|
|||
|
print("\r\n选择:\r\n"
|
|||
|
"0.存储精简模型到对应模型目录(本地精简模型时推荐使用这个)\r\n"
|
|||
|
"1.存储精简模型和config.yaml到程序根目录(新建文件夹,九天毕昇上导出精简模型推荐使用这个)\r\n"
|
|||
|
"2.复制完整模型和config.yaml到程序根目录(新建文件夹,九天毕昇上导出完整模型推荐使用这个)\r\n"
|
|||
|
"输入其他退出")
|
|||
|
f = int(input("\r\n请输入序号并回车:"))
|
|||
|
if f == 0:
|
|||
|
print(f"已保存精简模型至对应模型目录")
|
|||
|
shutil.copyfile(f'./checkpoints/{project_name}/config.yaml', f"./{project_name}/config.yaml")
|
|||
|
output = f"./checkpoints/{project_name}/clean_{pth_name}"
|
|||
|
simplify_pth(pth_name, project_name, output)
|
|||
|
elif f == 1:
|
|||
|
print(f"已保存精简模型至: 根目录下新建文件夹/{project_name}")
|
|||
|
mkdir([f"./{project_name}"])
|
|||
|
shutil.copyfile(f'./checkpoints/{project_name}/config.yaml', f"./{project_name}/config.yaml")
|
|||
|
output = f"./{project_name}/clean_{pth_name}"
|
|||
|
simplify_pth(pth_name, project_name, output)
|
|||
|
elif f == 2:
|
|||
|
print(f"已保存完整模型至: 根目录下新建文件夹/{project_name}")
|
|||
|
mkdir([f"./{project_name}"])
|
|||
|
shutil.copyfile(f'./checkpoints/{project_name}/config.yaml', f"./{project_name}/config.yaml")
|
|||
|
shutil.copyfile(f'./checkpoints/{project_name}/{pth_name}', f"./{project_name}/{pth_name}")
|
|||
|
else:
|
|||
|
print("输入错误,程序退出")
|
|||
|
exit()
|