chore: code cleanup as suggested by ruff

magic-akari 2023-06-26 15:04:45 +08:00
parent a5f0e911ed
commit 50a089813a
No known key found for this signature in database
GPG Key ID: EC005B1159285BDD
6 changed files with 31 additions and 25 deletions
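Most of the hunks below split a one-line compound statement (if cond: body) onto two lines. That pattern matches pycodestyle's E701 ("multiple statements on one line (colon)"), one of the checks ruff reports; the exact rule selection is an assumption here, since the project's ruff configuration is not shown. A minimal sketch of the pattern, with illustrative names rather than code from this repository:

    config_path = None  # illustrative value

    # before: condition and body share one line, reported as E701
    if config_path is None: config_file = "config.yaml"

    # after: the body moves to its own indented line
    if config_path is None:
        config_file = "config.yaml"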


@@ -24,9 +24,11 @@ def load_model_vocoder(
        device='cpu',
        config_path = None
        ):
    if config_path is None: config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
    else: config_file = config_path
    if config_path is None:
        config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
    else:
        config_file = config_path
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)


@@ -179,7 +179,8 @@ class Svc(object):
        else:
            self.feature_retrieval=False
        if self.shallow_diffusion : self.nsf_hifigan_enhance = False
        if self.shallow_diffusion :
            self.nsf_hifigan_enhance = False
        if self.nsf_hifigan_enhance:
            from modules.enhancer import Enhancer
            self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model',device=self.dev)
@@ -442,7 +443,8 @@ class Svc(object):
        datas = [data]
        for k,dat in enumerate(datas):
            per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
            if clip_seconds!=0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
            if clip_seconds!=0:
                print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
            # padd
            pad_len = int(audio_sr * pad_seconds)
            dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])


@@ -132,8 +132,10 @@ def main():
        key = "auto" if auto_predict_f0 else f"{tran}key"
        cluster_name = "" if cluster_infer_ratio == 0 else f"_{cluster_infer_ratio}"
        isdiffusion = "sovits"
        if shallow_diffusion : isdiffusion = "sovdiff"
        if only_diffusion : isdiffusion = "diff"
        if shallow_diffusion :
            isdiffusion = "sovdiff"
        if only_diffusion :
            isdiffusion = "diff"
        if use_spk_mix:
            spk = "spk_mix"
        res_path = f'results/{clean_name}_{key}_{spk}{cluster_name}_{isdiffusion}_{f0p}.{wav_format}'


@@ -134,12 +134,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    return acts
def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape
def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x
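The deletion above appears to drop a second, identical definition of convert_pad_shape; a duplicated definition like that is what pyflakes (and ruff) report as F811, "redefinition of unused name". A minimal illustration with a hypothetical helper, not code from this repository:

    def scale(x):   # this first definition is never used...
        return x * 2

    def scale(x):   # ...because this redefinition shadows it, which triggers F811
        return x * 3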


@@ -1,11 +1,7 @@
import logging
import multiprocessing
import time
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('numba').setLevel(logging.WARNING)
import os
import time
import torch
import torch.distributed as dist
@@ -26,6 +22,9 @@ from models import (
from modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from modules.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('numba').setLevel(logging.WARNING)
torch.backends.cudnn.benchmark = True
global_step = 0
start_time = time.time()
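In the two hunks above, the logging.getLogger calls move from between the imports to after the import block, and one of the duplicated import time lines goes away. Keeping executable statements after all module-level imports is likely aimed at pycodestyle's E402 ("module level import not at top of file"), which ruff also reports; that reading is an inference, not something stated in the commit. A small sketch of the shape E402 flags, with illustrative modules:

    import logging

    logging.getLogger('matplotlib').setLevel(logging.WARNING)  # executable statement at module level...

    import os  # ...so this later import is reported as E402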


@@ -45,7 +45,8 @@ def upload_mix_append_file(files,sfiles):
        p = {file:100 for file in file_paths}
        return file_paths,mix_model_output1.update(value=json.dumps(p,indent=2))
    except Exception as e:
        if debug: traceback.print_exc()
        if debug:
            traceback.print_exc()
        raise gr.Error(e)
def mix_submit_click(js,mode):
@@ -59,16 +60,19 @@ def mix_submit_click(js,mode):
        path = mix_model(model_path,mix_rate,mode)
        return f"成功,文件被保存在了{path}"
    except Exception as e:
        if debug: traceback.print_exc()
        if debug:
            traceback.print_exc()
        raise gr.Error(e)
def updata_mix_info(files):
    try:
        if files is None : return mix_model_output1.update(value="")
        if files is None :
            return mix_model_output1.update(value="")
        p = {file.name:100 for file in files}
        return mix_model_output1.update(value=json.dumps(p,indent=2))
    except Exception as e:
        if debug: traceback.print_exc()
        if debug:
            traceback.print_exc()
        raise gr.Error(e)
def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_model_path,diff_config_path,only_diffusion,use_spk_mix):
@@ -108,7 +112,8 @@ def modelAnalysis(model_path,config_path,cluster_model_path,device,enhance,diff_
            msg += i + " "
        return sid.update(choices = spks,value=spks[0]), msg
    except Exception as e:
        if debug: traceback.print_exc()
        if debug:
            traceback.print_exc()
        raise gr.Error(e)
@@ -168,7 +173,8 @@ def vc_fn(sid, input_audio, vc_transform, auto_f0,cluster_ratio, slice_db, noise
        soundfile.write(output_file, _audio, model.target_sample, format="wav")
        return "Success", output_file
    except Exception as e:
        if debug: traceback.print_exc()
        if debug:
            traceback.print_exc()
        raise gr.Error(e)
def tts_func(_text,_rate,_voice):
@@ -176,7 +182,8 @@ def tts_func(_text,_rate,_voice):
    # voice = "zh-CN-XiaoyiNeural"#女性,较高音
    # voice = "zh-CN-YunxiNeural"#男性
    voice = "zh-CN-YunxiNeural"#男性
    if ( _voice == "" ) : voice = "zh-CN-XiaoyiNeural"
    if ( _voice == "" ) :
        voice = "zh-CN-XiaoyiNeural"
    output_file = _text[0:10]+".wav"
    # communicate = edge_tts.Communicate(_text, voice)
    # await communicate.save(output_file)