fix: apply random slicing on every access when the dataset is fully loaded into memory

This commit is contained in:
MistEO 2023-04-04 20:03:22 +08:00
parent da42bad844
commit 02de07a8f1
3 changed files with 11 additions and 4 deletions

View File

@ -21,7 +21,8 @@
"use_sr": true,
"max_speclen": 512,
"port": "8001",
"keep_ckpts": 3
"keep_ckpts": 3,
"all_in_mem": false
},
"data": {
"training_files": "filelists/train.txt",

View File

@ -79,6 +79,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
audio_norm = audio_norm[:, :lmin * self.hop_length]
return c, f0, spec, audio_norm, spk, uv
def random_slice(self, c, f0, spec, audio_norm, spk, uv):
# if spec.shape[1] < 30:
# print("skip too short audio:", filename)
# return None
@ -92,9 +96,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
def __getitem__(self, index):
    """Return one randomly sliced training sample for ``index``.

    The un-sliced sample is read from ``self.cache`` when
    ``self.all_in_mem`` is true; otherwise it is loaded with
    ``self.get_audio`` from the path at ``self.audiopaths[index][0]``.
    On both paths the sample is unpacked into ``self.random_slice`` and
    the result of that call is returned.
    """
    if self.all_in_mem:
        sample = self.cache[index]
    else:
        sample = self.get_audio(self.audiopaths[index][0])
    # random_slice takes the sample fields (c, f0, spec, audio_norm, spk, uv)
    # as separate positional arguments, so the tuple must be star-unpacked
    # on the disk path as well as on the cached path; passing the tuple
    # itself raises TypeError for the missing arguments.
    return self.random_slice(*sample)
def __len__(self):
    """Return the number of entries in ``self.audiopaths``."""
    paths = self.audiopaths
    return len(paths)

View File

@ -67,9 +67,11 @@ def run(rank, n_gpus, hps):
torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank)
collate_fn = TextAudioCollate()
all_in_mem = False # If you have enough memory, turn on this option to avoid disk IO and speed up training.
all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training.
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem)
num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
if all_in_mem:
num_workers = 0
train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
batch_size=hps.train.batch_size, collate_fn=collate_fn)
if rank == 0: