Merge remote-tracking branch 'origin/4.0' into 4.0

# Conflicts:
#	webUI.py
This commit is contained in:
limbang 2023-04-13 00:14:24 +08:00
commit efd3311da2
No known key found for this signature in database
GPG Key ID: 7BA8ED77415F0201
2 changed files with 32 additions and 5 deletions

View File

@ -6,6 +6,7 @@ import os
import time import time
from pathlib import Path from pathlib import Path
from inference import slicer from inference import slicer
import gc
import librosa import librosa
import numpy as np import numpy as np
@ -221,6 +222,16 @@ class Svc(object):
# 清理显存 # 清理显存
torch.cuda.empty_cache() torch.cuda.empty_cache()
def unload_model(self):
    """Drop the loaded networks held by this instance.

    Each network is moved to the CPU, its attribute is deleted, and a
    garbage-collection pass is run. Attributes that are already missing
    are skipped, so the method can be called more than once, or on an
    instance that never finished loading, without raising
    ``AttributeError``.

    This method does not call ``torch.cuda.empty_cache()``.
    """
    # Unload the main model. The move to CPU happens before the reference
    # is dropped -- presumably so the weights leave the GPU first.
    if hasattr(self, "net_g_ms"):
        self.net_g_ms = self.net_g_ms.to("cpu")
        del self.net_g_ms

    # The enhancer wrapper is optional and holds its own inner network.
    if hasattr(self, "enhancer"):
        if hasattr(self.enhancer, "enhancer"):
            self.enhancer.enhancer = self.enhancer.enhancer.to("cpu")
            del self.enhancer.enhancer
        del self.enhancer

    gc.collect()
def slice_inference(self, def slice_inference(self,
raw_audio_path, raw_audio_path,
spk, spk,

View File

@ -17,6 +17,7 @@ from scipy.io import wavfile
import librosa import librosa
import torch import torch
import time import time
import traceback
logging.getLogger('numba').setLevel(logging.WARNING) logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('markdown_it').setLevel(logging.WARNING) logging.getLogger('markdown_it').setLevel(logging.WARNING)
@ -28,15 +29,16 @@ model = None
spk = None spk = None
debug = False debug = False
cuda = [] cuda = {}
if torch.cuda.is_available(): if torch.cuda.is_available():
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
device_name = torch.cuda.get_device_properties(i).name device_name = torch.cuda.get_device_properties(i).name
cuda.append(f"CUDA:{i} {device_name}") cuda[f"CUDA:{i} {device_name}"] = f"cuda:{i}"
def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance): def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance):
global model global model
try: try:
device = cuda[device] if "CUDA" in device else device
model = Svc(model_path.name, config_path.name, device=device if device!="Auto" else None, cluster_model_path = cluster_model_path.name if cluster_model_path != None else "",nsf_hifigan_enhance=enhance) model = Svc(model_path.name, config_path.name, device=device if device!="Auto" else None, cluster_model_path = cluster_model_path.name if cluster_model_path != None else "",nsf_hifigan_enhance=enhance)
spks = list(model.spk2id.keys()) spks = list(model.spk2id.keys())
device_name = torch.cuda.get_device_properties(model.dev).name if "cuda" in str(model.dev) else str(model.dev) device_name = torch.cuda.get_device_properties(model.dev).name if "cuda" in str(model.dev) else str(model.dev)
@ -50,6 +52,7 @@ def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance):
msg += i + " " msg += i + " "
return sid.update(choices = spks,value=spks[0]), msg return sid.update(choices = spks,value=spks[0]), msg
except Exception as e: except Exception as e:
if debug: traceback.print_exc()
raise gr.Error(e) raise gr.Error(e)
@ -58,6 +61,7 @@ def modelUnload():
if model is None: if model is None:
return sid.update(choices = [],value=""),"没有模型需要卸载!" return sid.update(choices = [],value=""),"没有模型需要卸载!"
else: else:
model.unload_model()
model = None model = None
torch.cuda.empty_cache() torch.cuda.empty_cache()
return sid.update(choices = [],value=""),"模型卸载完毕!" return sid.update(choices = [],value=""),"模型卸载完毕!"
@ -88,8 +92,10 @@ def vc_fn(sid, input_audio, vc_transform, auto_f0,cluster_ratio, slice_db, noise
soundfile.write(output_file, _audio, model.target_sample, format="wav") soundfile.write(output_file, _audio, model.target_sample, format="wav")
return f"推理成功音频文件保存为results/{filename}", (model.target_sample, _audio) return f"推理成功音频文件保存为results/{filename}", (model.target_sample, _audio)
except Exception as e: except Exception as e:
if debug: traceback.print_exc()
raise gr.Error(e) raise gr.Error(e)
except Exception as e: except Exception as e:
if debug: traceback.print_exc()
raise gr.Error(e) raise gr.Error(e)
@ -141,6 +147,9 @@ def vc_fn2(sid, input_audio, vc_transform, auto_f0,cluster_ratio, slice_db, nois
os.remove(save_path2) os.remove(save_path2)
return a,b return a,b
def debug_change(value=None):
    """Update the module-level ``debug`` flag from the debug checkbox.

    Args:
        value: Current checkbox state, as delivered by the event listener
            when the checkbox is listed among its inputs (for example
            ``debug_button.change(debug_change, [debug_button], [])``).
            When omitted, the flag is read from ``debug_button.value``.

    Note:
        ``debug_button.value`` is the value the component was constructed
        with, not the live state in the browser, so the no-argument form
        cannot observe the user toggling the box. It is kept only so
        listeners registered without inputs keep working.
    """
    global debug
    if value is None:
        # Legacy path: no input was wired to the listener.
        value = debug_button.value
    debug = bool(value)
with gr.Blocks( with gr.Blocks(
theme=gr.themes.Base( theme=gr.themes.Base(
@ -162,7 +171,7 @@ with gr.Blocks(
model_path = gr.File(label="选择模型文件") model_path = gr.File(label="选择模型文件")
config_path = gr.File(label="选择配置文件") config_path = gr.File(label="选择配置文件")
cluster_model_path = gr.File(label="选择聚类模型文件(没有可以不选)") cluster_model_path = gr.File(label="选择聚类模型文件(没有可以不选)")
device = gr.Dropdown(label="推理设备默认为自动选择CPU和GPU", choices=["Auto",*cuda,"CPU"], value="Auto") device = gr.Dropdown(label="推理设备默认为自动选择CPU和GPU", choices=["Auto",*cuda.keys(),"CPU"], value="Auto")
enhance = gr.Checkbox(label="是否使用NSF_HIFIGAN增强,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭", value=False) enhance = gr.Checkbox(label="是否使用NSF_HIFIGAN增强,该选项对部分训练集少的模型有一定的音质增强效果,但是对训练好的模型有反面效果,默认关闭", value=False)
with gr.Column(): with gr.Column():
gr.Markdown(value=""" gr.Markdown(value="""
@ -205,8 +214,15 @@ with gr.Blocks(
vc_output1 = gr.Textbox(label="Output Message") vc_output1 = gr.Textbox(label="Output Message")
with gr.Column(): with gr.Column():
vc_output2 = gr.Audio(label="Output Audio", interactive=False) vc_output2 = gr.Audio(label="Output Audio", interactive=False)
with gr.Row(variant="panel"):
with gr.Column():
gr.Markdown(value="""
<font size=2> WebUI设置</font>
""")
debug_button = gr.Checkbox(label="Debug模式如果向社区反馈BUG需要打开打开后控制台可以显示具体错误提示", value=debug)
vc_submit.click(vc_fn, [sid, vc_input3, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,F0_mean_pooling,enhancer_adaptive_key], [vc_output1, vc_output2]) vc_submit.click(vc_fn, [sid, vc_input3, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,F0_mean_pooling,enhancer_adaptive_key], [vc_output1, vc_output2])
vc_submit2.click(vc_fn2, [sid, vc_input3, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,text2tts,tts_rate,F0_mean_pooling,enhancer_adaptive_key], [vc_output1, vc_output2])
debug_button.change(debug_change,[],[])
vc_submit2.click(vc_fn2, [sid, vc_input3, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,text2tts,tts_rate,tts_voice,F0_mean_pooling,enhancer_adaptive_key], [vc_output1, vc_output2]) vc_submit2.click(vc_fn2, [sid, vc_input3, vc_transform,auto_f0,cluster_ratio, slice_db, noise_scale,pad_seconds,cl_num,lg_num,lgr_num,text2tts,tts_rate,tts_voice,F0_mean_pooling,enhancer_adaptive_key], [vc_output1, vc_output2])
model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance],[sid,sid_output]) model_load_button.click(modelAnalysis,[model_path,config_path,cluster_model_path,device,enhance],[sid,sid_output])
model_unload_button.click(modelUnload,[],[sid,sid_output]) model_unload_button.click(modelUnload,[],[sid,sid_output])