New Tiny Model Updata

This commit is contained in:
YuriHead 2023-07-01 04:27:18 +08:00
parent a8acfa01da
commit 63f889572c
5 changed files with 23 additions and 12 deletions

View File

@ -18,7 +18,7 @@ def copyStateDict(state_dict):
return new_state_dict
def removeOptimizer(config: str, input_model: str, output_model: str):
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
hps = utils.get_hparams_from_file(config)
net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
@ -35,8 +35,8 @@ def removeOptimizer(config: str, input_model: str, output_model: str):
keys = []
for k, v in new_dict_g['model'].items():
keys.append(k)
new_dict_g = {k: new_dict_g['model'][k] for k in keys}
new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys}
torch.save(
{
@ -56,7 +56,8 @@ if __name__ == "__main__":
default='configs/config.json')
parser.add_argument("-i", "--input", type=str)
parser.add_argument("-o", "--output", type=str, default=None)
parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')
args = parser.parse_args()
output = args.output
@ -64,6 +65,7 @@ if __name__ == "__main__":
if output is None:
import os.path
filename, ext = os.path.splitext(args.input)
output = filename + "_release" + ext
half = "_half" if args.half else ""
output = filename + "_release" + half + ext
removeOptimizer(args.config, args.input, output)
removeOptimizer(args.config, args.input, args.half, output)

View File

@ -64,6 +64,7 @@
"speaker_embedding":false,
"vol_embedding":false,
"use_depthwise_conv":false,
"flow_share_parameter": false,
"use_automatic_f0_prediction": true
},
"spk": {

View File

@ -20,7 +20,9 @@ class ResidualCouplingBlock(nn.Module):
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
gin_channels=0,
share_parameter=False
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
@ -31,10 +33,13 @@ class ResidualCouplingBlock(nn.Module):
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
self.wn = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=gin_channels) if share_parameter else None
for i in range(n_flows):
self.flows.append(
modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
gin_channels=gin_channels, mean_only=True))
gin_channels=gin_channels, mean_only=True, wn_sharing_parameter=self.wn))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
@ -320,6 +325,7 @@ class SynthesizerTrn(nn.Module):
vocoder_name = "nsf-hifigan",
use_depthwise_conv = False,
use_automatic_f0_prediction = True,
flow_share_parameter = False,
n_flow_layer = 4,
**kwargs):
@ -386,7 +392,7 @@ class SynthesizerTrn(nn.Module):
self.dec = Generator(h=hps)
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
if self.use_automatic_f0_prediction:
self.f0_decoder = F0Decoder(
1,

View File

@ -263,7 +263,9 @@ class ResidualCouplingLayer(nn.Module):
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False):
mean_only=False,
wn_sharing_parameter=None
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
@ -275,7 +277,7 @@ class ResidualCouplingLayer(nn.Module):
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()

View File

@ -85,7 +85,7 @@ if __name__ == "__main__":
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768
d_config_template["data"]["encoder_out_channels"] = 768
elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft':
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 256
config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256
d_config_template["data"]["encoder_out_channels"] = 256
elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge':
config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024