Debug config serialization
This commit is contained in:
parent
b4c8eab5af
commit
02ad34dfd7
|
@ -136,7 +136,7 @@ class Svc(object):
|
|||
self.dev = torch.device(device)
|
||||
self.net_g_ms = None
|
||||
if not self.only_diffusion:
|
||||
self.hps_ms = utils.get_hparams_from_file(config_path)
|
||||
self.hps_ms = utils.get_hparams_from_file(config_path,True)
|
||||
self.target_sample = self.hps_ms.data.sampling_rate
|
||||
self.hop_size = self.hps_ms.data.hop_length
|
||||
self.spk2id = self.hps_ms.spk
|
||||
|
|
17
utils.py
17
utils.py
|
@ -337,11 +337,11 @@ def get_hparams_from_dir(model_dir):
|
|||
return hparams
|
||||
|
||||
|
||||
def get_hparams_from_file(config_path):
|
||||
def get_hparams_from_file(config_path, infer_mode = False):
|
||||
with open(config_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
hparams =HParams(**config)
|
||||
hparams =HParams(**config) if not infer_mode else InferHParams(**config)
|
||||
return hparams
|
||||
|
||||
|
||||
|
@ -512,9 +512,18 @@ class HParams():
|
|||
def get(self,index):
|
||||
return self.__dict__.get(index)
|
||||
|
||||
def __getattr__(self,index):
|
||||
return self.get(index)
|
||||
|
||||
class InferHParams(HParams):
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if type(v) == dict:
|
||||
v = InferHParams(**v)
|
||||
self[k] = v
|
||||
|
||||
def __getattr__(self,index):
|
||||
return self.get(index)
|
||||
|
||||
|
||||
class Volume_Extractor:
|
||||
def __init__(self, hop_size = 512):
|
||||
self.hop_size = hop_size
|
||||
|
|
Loading…
Reference in New Issue