Make the unit interpolation mode compatible with old models
This commit is contained in:
parent
1f65db3d03
commit
730ff1b995
|
@ -35,7 +35,8 @@
|
|||
"win_length": 2048,
|
||||
"n_mel_channels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": 22050
|
||||
"mel_fmax": 22050,
|
||||
"unit_interpolate_mode":"nearest"
|
||||
},
|
||||
"model": {
|
||||
"inter_channels": 192,
|
||||
|
|
|
@ -11,6 +11,7 @@ data:
|
|||
validation_files: "filelists/val.txt"
|
||||
extensions: # List of extension included in the data collection
|
||||
- wav
|
||||
unit_interpolate_mode: "nearest"
|
||||
model:
|
||||
type: 'Diffusion'
|
||||
n_layers: 20
|
||||
|
|
|
@ -31,6 +31,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
self.filter_length = hparams.data.filter_length
|
||||
self.hop_length = hparams.data.hop_length
|
||||
self.win_length = hparams.data.win_length
|
||||
self.unit_interpolate_mode = hparams.data.unit_interpolate_mode
|
||||
self.sampling_rate = hparams.data.sampling_rate
|
||||
self.use_sr = hparams.train.use_sr
|
||||
self.spec_len = hparams.train.max_speclen
|
||||
|
@ -73,7 +74,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
uv = torch.FloatTensor(np.array(uv,dtype=float))
|
||||
|
||||
c = torch.load(filename+ ".soft.pt")
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0], mode=self.unit_interpolate_mode)
|
||||
if self.vol_emb:
|
||||
volume_path = filename + ".vol.npy"
|
||||
volume = np.load(volume_path)
|
||||
|
|
|
@ -63,6 +63,7 @@ def get_data_loaders(args, whole_audio=False):
|
|||
spk=args.spk,
|
||||
device=args.train.cache_device,
|
||||
fp16=args.train.cache_fp16,
|
||||
unit_interpolate_mode = args.data.unit_interpolate_mode,
|
||||
use_aug=True)
|
||||
loader_train = torch.utils.data.DataLoader(
|
||||
data_train ,
|
||||
|
@ -81,6 +82,7 @@ def get_data_loaders(args, whole_audio=False):
|
|||
whole_audio=True,
|
||||
spk=args.spk,
|
||||
extensions=args.data.extensions,
|
||||
unit_interpolate_mode = args.data.unit_interpolate_mode,
|
||||
n_spk=args.model.n_spk)
|
||||
loader_valid = torch.utils.data.DataLoader(
|
||||
data_valid,
|
||||
|
@ -107,6 +109,7 @@ class AudioDataset(Dataset):
|
|||
device='cpu',
|
||||
fp16=False,
|
||||
use_aug=False,
|
||||
unit_interpolate_mode = 'left'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -118,6 +121,7 @@ class AudioDataset(Dataset):
|
|||
self.use_aug = use_aug
|
||||
self.data_buffer={}
|
||||
self.pitch_aug_dict = {}
|
||||
self.unit_interpolate_mode = unit_interpolate_mode
|
||||
# np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item()
|
||||
if load_all_data:
|
||||
print('Load all the data filelists:', filelists)
|
||||
|
@ -171,7 +175,7 @@ class AudioDataset(Dataset):
|
|||
path_units = name_ext + ".soft.pt"
|
||||
units = torch.load(path_units).to(device)
|
||||
units = units[0]
|
||||
units = repeat_expand_2d(units,f0.size(0)).transpose(0,1)
|
||||
units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1)
|
||||
|
||||
if fp16:
|
||||
mel = mel.half()
|
||||
|
@ -263,7 +267,7 @@ class AudioDataset(Dataset):
|
|||
path_units = name_ext + ".soft.pt"
|
||||
units = torch.load(path_units)
|
||||
units = units[0]
|
||||
units = repeat_expand_2d(units,f0.size(0)).transpose(0,1)
|
||||
units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1)
|
||||
|
||||
units = units[start_frame : start_frame + units_frame_len]
|
||||
|
||||
|
|
|
@ -140,6 +140,10 @@ class Svc(object):
|
|||
self.target_sample = self.hps_ms.data.sampling_rate
|
||||
self.hop_size = self.hps_ms.data.hop_length
|
||||
self.spk2id = self.hps_ms.spk
|
||||
try:
|
||||
self.unit_interpolate_mode = self.hps_ms.unit_interpolate_mode
|
||||
except Exception as e:
|
||||
self.unit_interpolate_mode = 'left'
|
||||
try:
|
||||
self.vol_embedding = self.hps_ms.model.vol_embedding
|
||||
except Exception as e:
|
||||
|
@ -158,6 +162,7 @@ class Svc(object):
|
|||
self.hop_size = self.diffusion_args.data.block_size
|
||||
self.spk2id = self.diffusion_args.spk
|
||||
self.speech_encoder = self.diffusion_args.data.encoder
|
||||
self.unit_interpolate_mode = self.diffusion_args.data.unit_interpolate_mode if self.diffusion_args.data.unit_interpolate_mode!=None else 'left'
|
||||
if spk_mix_enable:
|
||||
self.diffusion_model.init_spkmix(len(self.spk2id))
|
||||
else:
|
||||
|
@ -220,7 +225,7 @@ class Svc(object):
|
|||
wav16k = librosa.resample(wav, orig_sr=self.target_sample, target_sr=16000)
|
||||
wav16k = torch.from_numpy(wav16k).to(self.dev)
|
||||
c = self.hubert_model.encoder(wav16k)
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
|
||||
|
||||
if cluster_infer_ratio !=0:
|
||||
if self.feature_retrieval:
|
||||
|
@ -299,7 +304,7 @@ class Svc(object):
|
|||
audio16k = librosa.resample(audio.detach().cpu().numpy(), orig_sr=self.target_sample, target_sr=16000)
|
||||
audio16k = torch.from_numpy(audio16k).to(self.dev)
|
||||
c = self.hubert_model.encoder(audio16k)
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
|
||||
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
|
||||
f0 = f0[:,:,None]
|
||||
c = c.transpose(-1,-2)
|
||||
audio_mel = self.diffusion_model(
|
||||
|
|
33
utils.py
33
utils.py
|
@ -377,26 +377,31 @@ def get_logger(model_dir, filename="train.log"):
|
|||
return logger
|
||||
|
||||
|
||||
# def repeat_expand_2d(content, target_len):
|
||||
# # content : [h, t]
|
||||
def repeat_expand_2d(content, target_len, mode = 'left'):
    """Resize ``content`` along its last axis to ``target_len`` frames.

    Args:
        content: 2-D tensor laid out as [h, t].
        target_len: number of frames wanted on the last axis.
        mode: ``'left'`` selects the repeat-based expansion in
            ``repeat_expand_2d_left``; any other value is forwarded
            unchanged to ``repeat_expand_2d_other``.

    Returns:
        Whatever the selected helper returns for these arguments.
    """
    # 'left' is the default so callers that never pass a mode keep the
    # repeat-based path.
    if mode == 'left':
        return repeat_expand_2d_left(content, target_len)
    return repeat_expand_2d_other(content, target_len, mode)
|
||||
|
||||
# src_len = content.shape[-1]
|
||||
# target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
|
||||
# temp = torch.arange(src_len+1) * target_len / src_len
|
||||
# current_pos = 0
|
||||
# for i in range(target_len):
|
||||
# if i < temp[current_pos+1]:
|
||||
# target[:, i] = content[:, current_pos]
|
||||
# else:
|
||||
# current_pos += 1
|
||||
# target[:, i] = content[:, current_pos]
|
||||
|
||||
# return target
|
||||
|
||||
def repeat_expand_2d_left(content, target_len):
    """Expand ``content`` along its last axis to ``target_len`` frames by repetition.

    Each output frame ``i`` copies one source column. The source column index
    starts at 0 and advances by at most one per output frame, moving on as
    soon as ``i`` reaches the next boundary ``(k + 1) * target_len / src_len``.

    Args:
        content: 2-D tensor laid out as [h, t].
        target_len: number of frames in the output.

    Returns:
        A ``torch.float`` tensor of shape ``[h, target_len]`` on the same
        device as ``content``.
    """
    src_len = content.shape[-1]
    # Allocate directly on the source device rather than on CPU followed by a copy.
    target = torch.zeros(
        [content.shape[0], target_len], dtype=torch.float, device=content.device
    )
    # boundaries[k] is the output position where source column k begins.
    # Converting once to a Python list avoids a tensor index + tensor
    # comparison on every loop iteration.
    boundaries = (torch.arange(src_len + 1) * target_len / src_len).tolist()
    current_pos = 0
    for i in range(target_len):
        # Advance at most one source column per output frame.
        if i >= boundaries[current_pos + 1]:
            current_pos += 1
        target[:, i] = content[:, current_pos]

    return target
|
||||
|
||||
|
||||
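# mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area' (F.interpolate modes; only 'nearest', 'linear' and 'area' accept the 3-D [1, h, t] input built below)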
|
||||
def repeat_expand_2d(content, target_len, mode = 'nearest'):
|
||||
def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
|
||||
# content : [h, t]
|
||||
content = content[None,:,:]
|
||||
target = F.interpolate(content,size=target_len,mode=mode)[0]
|
||||
|
|
Loading…
Reference in New Issue