Add backward compatibility for the old model's unit interpolation mode

This commit is contained in:
ylzz1997 2023-06-05 13:15:44 +08:00
parent 1f65db3d03
commit 730ff1b995
6 changed files with 37 additions and 20 deletions

View File

@ -35,7 +35,8 @@
"win_length": 2048,
"n_mel_channels": 80,
"mel_fmin": 0.0,
"mel_fmax": 22050
"mel_fmax": 22050,
"unit_interpolate_mode":"nearest"
},
"model": {
"inter_channels": 192,

View File

@ -11,6 +11,7 @@ data:
validation_files: "filelists/val.txt"
extensions: # List of extension included in the data collection
- wav
unit_interpolate_mode: "nearest"
model:
type: 'Diffusion'
n_layers: 20

View File

@ -31,6 +31,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
self.filter_length = hparams.data.filter_length
self.hop_length = hparams.data.hop_length
self.win_length = hparams.data.win_length
self.unit_interpolate_mode = hparams.data.unit_interpolate_mode
self.sampling_rate = hparams.data.sampling_rate
self.use_sr = hparams.train.use_sr
self.spec_len = hparams.train.max_speclen
@ -73,7 +74,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
uv = torch.FloatTensor(np.array(uv,dtype=float))
c = torch.load(filename+ ".soft.pt")
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0], mode=self.unit_interpolate_mode)
if self.vol_emb:
volume_path = filename + ".vol.npy"
volume = np.load(volume_path)

View File

@ -63,6 +63,7 @@ def get_data_loaders(args, whole_audio=False):
spk=args.spk,
device=args.train.cache_device,
fp16=args.train.cache_fp16,
unit_interpolate_mode = args.data.unit_interpolate_mode,
use_aug=True)
loader_train = torch.utils.data.DataLoader(
data_train ,
@ -81,6 +82,7 @@ def get_data_loaders(args, whole_audio=False):
whole_audio=True,
spk=args.spk,
extensions=args.data.extensions,
unit_interpolate_mode = args.data.unit_interpolate_mode,
n_spk=args.model.n_spk)
loader_valid = torch.utils.data.DataLoader(
data_valid,
@ -107,6 +109,7 @@ class AudioDataset(Dataset):
device='cpu',
fp16=False,
use_aug=False,
unit_interpolate_mode = 'left'
):
super().__init__()
@ -118,6 +121,7 @@ class AudioDataset(Dataset):
self.use_aug = use_aug
self.data_buffer={}
self.pitch_aug_dict = {}
self.unit_interpolate_mode = unit_interpolate_mode
# np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item()
if load_all_data:
print('Load all the data filelists:', filelists)
@ -171,7 +175,7 @@ class AudioDataset(Dataset):
path_units = name_ext + ".soft.pt"
units = torch.load(path_units).to(device)
units = units[0]
units = repeat_expand_2d(units,f0.size(0)).transpose(0,1)
units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1)
if fp16:
mel = mel.half()
@ -263,7 +267,7 @@ class AudioDataset(Dataset):
path_units = name_ext + ".soft.pt"
units = torch.load(path_units)
units = units[0]
units = repeat_expand_2d(units,f0.size(0)).transpose(0,1)
units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1)
units = units[start_frame : start_frame + units_frame_len]

View File

@ -140,6 +140,10 @@ class Svc(object):
self.target_sample = self.hps_ms.data.sampling_rate
self.hop_size = self.hps_ms.data.hop_length
self.spk2id = self.hps_ms.spk
try:
self.unit_interpolate_mode = self.hps_ms.unit_interpolate_mode
except Exception as e:
self.unit_interpolate_mode = 'left'
try:
self.vol_embedding = self.hps_ms.model.vol_embedding
except Exception as e:
@ -158,6 +162,7 @@ class Svc(object):
self.hop_size = self.diffusion_args.data.block_size
self.spk2id = self.diffusion_args.spk
self.speech_encoder = self.diffusion_args.data.encoder
self.unit_interpolate_mode = self.diffusion_args.data.unit_interpolate_mode if self.diffusion_args.data.unit_interpolate_mode!=None else 'left'
if spk_mix_enable:
self.diffusion_model.init_spkmix(len(self.spk2id))
else:
@ -220,7 +225,7 @@ class Svc(object):
wav16k = librosa.resample(wav, orig_sr=self.target_sample, target_sr=16000)
wav16k = torch.from_numpy(wav16k).to(self.dev)
c = self.hubert_model.encoder(wav16k)
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
if cluster_infer_ratio !=0:
if self.feature_retrieval:
@ -299,7 +304,7 @@ class Svc(object):
audio16k = librosa.resample(audio.detach().cpu().numpy(), orig_sr=self.target_sample, target_sr=16000)
audio16k = torch.from_numpy(audio16k).to(self.dev)
c = self.hubert_model.encoder(audio16k)
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
f0 = f0[:,:,None]
c = c.transpose(-1,-2)
audio_mel = self.diffusion_model(

View File

@ -377,26 +377,31 @@ def get_logger(model_dir, filename="train.log"):
return logger
# def repeat_expand_2d(content, target_len):
# # content : [h, t]
def repeat_expand_2d(content, target_len, mode = 'left'):
    """Expand a 2-D feature map [h, t] along its time axis to target_len frames.

    mode 'left' keeps the legacy step-repeat behavior (old models); any other
    mode string is forwarded to the interpolation-based implementation
    (e.g. 'nearest', 'linear').
    """
    # content : [h, t]
    if mode == 'left':
        return repeat_expand_2d_left(content, target_len)
    return repeat_expand_2d_other(content, target_len, mode)
# src_len = content.shape[-1]
# target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
# temp = torch.arange(src_len+1) * target_len / src_len
# current_pos = 0
# for i in range(target_len):
# if i < temp[current_pos+1]:
# target[:, i] = content[:, current_pos]
# else:
# current_pos += 1
# target[:, i] = content[:, current_pos]
# return target
def repeat_expand_2d_left(content, target_len):
    """Stretch content [h, t] to [h, target_len] by step-repeating columns.

    Each source column is held until the next column's boundary in target
    index space is crossed ("left"/zero-order hold) — no value interpolation.
    Returns a float32 tensor on content's device.
    """
    # content : [h, t]
    src_len = content.shape[-1]
    # boundaries[p] is where source column p starts in the target index space;
    # column p covers target indices in [boundaries[p], boundaries[p+1]).
    boundaries = torch.arange(src_len + 1) * target_len / src_len
    expanded = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
    col = 0
    for out_idx in range(target_len):
        # Move to the next source column once its boundary has been passed.
        if out_idx >= boundaries[col + 1]:
            col += 1
        expanded[:, out_idx] = content[:, col]
    return expanded
# mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area'
def repeat_expand_2d(content, target_len, mode = 'nearest'):
def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
# content : [h, t]
content = content[None,:,:]
target = F.interpolate(content,size=target_len,mode=mode)[0]