Merge pull request #117 from MistEO/4.0
feat: Added load all to memory option for training
This commit is contained in:
commit
7e24b9ad45
|
@ -21,7 +21,8 @@
|
|||
"use_sr": true,
|
||||
"max_speclen": 512,
|
||||
"port": "8001",
|
||||
"keep_ckpts": 3
|
||||
"keep_ckpts": 3,
|
||||
"all_in_mem": false
|
||||
},
|
||||
"data": {
|
||||
"training_files": "filelists/train.txt",
|
||||
|
|
|
@ -23,7 +23,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, audiopaths, hparams):
|
||||
def __init__(self, audiopaths, hparams, all_in_mem: bool = False):
|
||||
self.audiopaths = load_filepaths_and_text(audiopaths)
|
||||
self.max_wav_value = hparams.data.max_wav_value
|
||||
self.sampling_rate = hparams.data.sampling_rate
|
||||
|
@ -37,6 +37,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
|
||||
random.seed(1234)
|
||||
random.shuffle(self.audiopaths)
|
||||
|
||||
self.all_in_mem = all_in_mem
|
||||
if self.all_in_mem:
|
||||
self.cache = [self.get_audio(p[0]) for p in self.audiopaths]
|
||||
|
||||
def get_audio(self, filename):
|
||||
filename = filename.replace("\\", "/")
|
||||
|
@ -75,6 +79,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
|
||||
spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
|
||||
audio_norm = audio_norm[:, :lmin * self.hop_length]
|
||||
|
||||
return c, f0, spec, audio_norm, spk, uv
|
||||
|
||||
def random_slice(self, c, f0, spec, audio_norm, spk, uv):
|
||||
# if spec.shape[1] < 30:
|
||||
# print("skip too short audio:", filename)
|
||||
# return None
|
||||
|
@ -87,7 +95,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||
return c, f0, spec, audio_norm, spk, uv
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.get_audio(self.audiopaths[index][0])
|
||||
if self.all_in_mem:
|
||||
return self.random_slice(*self.cache[index])
|
||||
else:
|
||||
return self.random_slice(*self.get_audio(self.audiopaths[index][0]))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths)
|
||||
|
|
7
train.py
7
train.py
|
@ -67,12 +67,15 @@ def run(rank, n_gpus, hps):
|
|||
torch.manual_seed(hps.train.seed)
|
||||
torch.cuda.set_device(rank)
|
||||
collate_fn = TextAudioCollate()
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
|
||||
all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training.
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem)
|
||||
num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
|
||||
if all_in_mem:
|
||||
num_workers = 0
|
||||
train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
|
||||
batch_size=hps.train.batch_size, collate_fn=collate_fn)
|
||||
if rank == 0:
|
||||
eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
|
||||
eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps, all_in_mem=all_in_mem)
|
||||
eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False,
|
||||
batch_size=1, pin_memory=False,
|
||||
drop_last=False, collate_fn=collate_fn)
|
||||
|
|
Loading…
Reference in New Issue