Debug config serialization

This commit is contained in:
ylzz1997 2023-06-08 02:25:32 +08:00
parent b4c8eab5af
commit 02ad34dfd7
2 changed files with 14 additions and 5 deletions

View File

@ -136,7 +136,7 @@ class Svc(object):
self.dev = torch.device(device)
self.net_g_ms = None
if not self.only_diffusion:
self.hps_ms = utils.get_hparams_from_file(config_path)
self.hps_ms = utils.get_hparams_from_file(config_path,True)
self.target_sample = self.hps_ms.data.sampling_rate
self.hop_size = self.hps_ms.data.hop_length
self.spk2id = self.hps_ms.spk

View File

@ -337,11 +337,11 @@ def get_hparams_from_dir(model_dir):
return hparams
def get_hparams_from_file(config_path):
def get_hparams_from_file(config_path, infer_mode = False):
with open(config_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams =HParams(**config)
hparams =HParams(**config) if not infer_mode else InferHParams(**config)
return hparams
@ -512,9 +512,18 @@ class HParams():
def get(self,index):
return self.__dict__.get(index)
def __getattr__(self,index):
return self.get(index)
class InferHParams(HParams):
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = InferHParams(**v)
self[k] = v
def __getattr__(self,index):
return self.get(index)
class Volume_Extractor:
def __init__(self, hop_size = 512):
self.hop_size = hop_size