fix: slice while fully loaded into memory
This commit is contained in:
parent
da42bad844
commit
02de07a8f1
|
@ -21,7 +21,8 @@
|
|||
"use_sr": true,
|
||||
"max_speclen": 512,
|
||||
"port": "8001",
|
||||
"keep_ckpts": 3
|
||||
"keep_ckpts": 3,
|
||||
"all_in_mem": false
|
||||
},
|
||||
"data": {
|
||||
"training_files": "filelists/train.txt",
|
||||
|
|
|
@ -79,6 +79,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
|
||||
spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
|
||||
audio_norm = audio_norm[:, :lmin * self.hop_length]
|
||||
|
||||
return c, f0, spec, audio_norm, spk, uv
|
||||
|
||||
def random_slice(self, c, f0, spec, audio_norm, spk, uv):
|
||||
# if spec.shape[1] < 30:
|
||||
# print("skip too short audio:", filename)
|
||||
# return None
|
||||
|
@ -92,9 +96,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
if self.all_in_mem:
|
||||
return self.cache[index]
|
||||
return self.random_slice(*self.cache[index])
|
||||
else:
|
||||
return self.get_audio(self.audiopaths[index][0])
|
||||
return self.random_slice(self.get_audio(self.audiopaths[index][0]))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths)
|
||||
|
|
4
train.py
4
train.py
|
@ -67,9 +67,11 @@ def run(rank, n_gpus, hps):
|
|||
torch.manual_seed(hps.train.seed)
|
||||
torch.cuda.set_device(rank)
|
||||
collate_fn = TextAudioCollate()
|
||||
all_in_mem = False # If you have enough memory, turn on this option to avoid disk IO and speed up training.
|
||||
all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training.
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem)
|
||||
num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
|
||||
if all_in_mem:
|
||||
num_workers = 0
|
||||
train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
|
||||
batch_size=hps.train.batch_size, collate_fn=collate_fn)
|
||||
if rank == 0:
|
||||
|
|
Loading…
Reference in New Issue