Updata medium whisper
This commit is contained in:
parent
3e36e2b389
commit
77f707a834
|
@ -87,8 +87,8 @@ if __name__ == "__main__":
|
|||
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 256
|
||||
d_config_template["data"]["encoder_out_channels"] = 256
|
||||
elif args.speech_encoder == "whisper-ppg" :
|
||||
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 512
|
||||
d_config_template["data"]["encoder_out_channels"] = 512
|
||||
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024
|
||||
d_config_template["data"]["encoder_out_channels"] = 1024
|
||||
|
||||
print("Writing configs/config.json")
|
||||
with open("configs/config.json", "w") as f:
|
||||
|
|
|
@ -6,7 +6,7 @@ from vencoder.whisper.audio import pad_or_trim, log_mel_spectrogram
|
|||
|
||||
|
||||
class WhisperPPG(SpeechEncoder):
|
||||
def __init__(self,vec_path = "pretrain/base.pt",device=None):
|
||||
def __init__(self,vec_path = "pretrain/medium.pt",device=None):
|
||||
if device is None:
|
||||
self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue