New Tiny Model Updata
This commit is contained in:
parent
a8acfa01da
commit
63f889572c
|
@ -18,7 +18,7 @@ def copyStateDict(state_dict):
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def removeOptimizer(config: str, input_model: str, output_model: str):
|
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
|
||||||
hps = utils.get_hparams_from_file(config)
|
hps = utils.get_hparams_from_file(config)
|
||||||
|
|
||||||
net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
|
net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
|
||||||
|
@ -35,8 +35,8 @@ def removeOptimizer(config: str, input_model: str, output_model: str):
|
||||||
keys = []
|
keys = []
|
||||||
for k, v in new_dict_g['model'].items():
|
for k, v in new_dict_g['model'].items():
|
||||||
keys.append(k)
|
keys.append(k)
|
||||||
|
|
||||||
new_dict_g = {k: new_dict_g['model'][k] for k in keys}
|
new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys}
|
||||||
|
|
||||||
torch.save(
|
torch.save(
|
||||||
{
|
{
|
||||||
|
@ -56,7 +56,8 @@ if __name__ == "__main__":
|
||||||
default='configs/config.json')
|
default='configs/config.json')
|
||||||
parser.add_argument("-i", "--input", type=str)
|
parser.add_argument("-i", "--input", type=str)
|
||||||
parser.add_argument("-o", "--output", type=str, default=None)
|
parser.add_argument("-o", "--output", type=str, default=None)
|
||||||
|
parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
output = args.output
|
output = args.output
|
||||||
|
@ -64,6 +65,7 @@ if __name__ == "__main__":
|
||||||
if output is None:
|
if output is None:
|
||||||
import os.path
|
import os.path
|
||||||
filename, ext = os.path.splitext(args.input)
|
filename, ext = os.path.splitext(args.input)
|
||||||
output = filename + "_release" + ext
|
half = "_half" if args.half else ""
|
||||||
|
output = filename + "_release" + half + ext
|
||||||
|
|
||||||
removeOptimizer(args.config, args.input, output)
|
removeOptimizer(args.config, args.input, args.half, output)
|
|
@ -64,6 +64,7 @@
|
||||||
"speaker_embedding":false,
|
"speaker_embedding":false,
|
||||||
"vol_embedding":false,
|
"vol_embedding":false,
|
||||||
"use_depthwise_conv":false,
|
"use_depthwise_conv":false,
|
||||||
|
"flow_share_parameter": false,
|
||||||
"use_automatic_f0_prediction": true
|
"use_automatic_f0_prediction": true
|
||||||
},
|
},
|
||||||
"spk": {
|
"spk": {
|
||||||
|
|
12
models.py
12
models.py
|
@ -20,7 +20,9 @@ class ResidualCouplingBlock(nn.Module):
|
||||||
dilation_rate,
|
dilation_rate,
|
||||||
n_layers,
|
n_layers,
|
||||||
n_flows=4,
|
n_flows=4,
|
||||||
gin_channels=0):
|
gin_channels=0,
|
||||||
|
share_parameter=False
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.hidden_channels = hidden_channels
|
self.hidden_channels = hidden_channels
|
||||||
|
@ -31,10 +33,13 @@ class ResidualCouplingBlock(nn.Module):
|
||||||
self.gin_channels = gin_channels
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
self.flows = nn.ModuleList()
|
self.flows = nn.ModuleList()
|
||||||
|
|
||||||
|
self.wn = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=gin_channels) if share_parameter else None
|
||||||
|
|
||||||
for i in range(n_flows):
|
for i in range(n_flows):
|
||||||
self.flows.append(
|
self.flows.append(
|
||||||
modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
|
modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
|
||||||
gin_channels=gin_channels, mean_only=True))
|
gin_channels=gin_channels, mean_only=True, wn_sharing_parameter=self.wn))
|
||||||
self.flows.append(modules.Flip())
|
self.flows.append(modules.Flip())
|
||||||
|
|
||||||
def forward(self, x, x_mask, g=None, reverse=False):
|
def forward(self, x, x_mask, g=None, reverse=False):
|
||||||
|
@ -320,6 +325,7 @@ class SynthesizerTrn(nn.Module):
|
||||||
vocoder_name = "nsf-hifigan",
|
vocoder_name = "nsf-hifigan",
|
||||||
use_depthwise_conv = False,
|
use_depthwise_conv = False,
|
||||||
use_automatic_f0_prediction = True,
|
use_automatic_f0_prediction = True,
|
||||||
|
flow_share_parameter = False,
|
||||||
n_flow_layer = 4,
|
n_flow_layer = 4,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
|
@ -386,7 +392,7 @@ class SynthesizerTrn(nn.Module):
|
||||||
self.dec = Generator(h=hps)
|
self.dec = Generator(h=hps)
|
||||||
|
|
||||||
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
||||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
|
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
|
||||||
if self.use_automatic_f0_prediction:
|
if self.use_automatic_f0_prediction:
|
||||||
self.f0_decoder = F0Decoder(
|
self.f0_decoder = F0Decoder(
|
||||||
1,
|
1,
|
||||||
|
|
|
@ -263,7 +263,9 @@ class ResidualCouplingLayer(nn.Module):
|
||||||
n_layers,
|
n_layers,
|
||||||
p_dropout=0,
|
p_dropout=0,
|
||||||
gin_channels=0,
|
gin_channels=0,
|
||||||
mean_only=False):
|
mean_only=False,
|
||||||
|
wn_sharing_parameter=None
|
||||||
|
):
|
||||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
@ -275,7 +277,7 @@ class ResidualCouplingLayer(nn.Module):
|
||||||
self.mean_only = mean_only
|
self.mean_only = mean_only
|
||||||
|
|
||||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||||
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
|
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
|
||||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||||
self.post.weight.data.zero_()
|
self.post.weight.data.zero_()
|
||||||
self.post.bias.data.zero_()
|
self.post.bias.data.zero_()
|
||||||
|
|
|
@ -85,7 +85,7 @@ if __name__ == "__main__":
|
||||||
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768
|
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768
|
||||||
d_config_template["data"]["encoder_out_channels"] = 768
|
d_config_template["data"]["encoder_out_channels"] = 768
|
||||||
elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft':
|
elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft':
|
||||||
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 256
|
config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256
|
||||||
d_config_template["data"]["encoder_out_channels"] = 256
|
d_config_template["data"]["encoder_out_channels"] = 256
|
||||||
elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge':
|
elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge':
|
||||||
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024
|
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024
|
||||||
|
|
Loading…
Reference in New Issue