Merge pull request #117 from MistEO/4.0

feat: Added a load-all-into-memory option for training
謬紗特 2023-04-05 15:04:16 +08:00 committed by GitHub
commit 7e24b9ad45
3 changed files with 20 additions and 5 deletions


@@ -21,7 +21,8 @@
     "use_sr": true,
     "max_speclen": 512,
     "port": "8001",
-    "keep_ckpts": 3
+    "keep_ckpts": 3,
+    "all_in_mem": false
   },
   "data": {
     "training_files": "filelists/train.txt",


@@ -23,7 +23,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         3) computes spectrograms from audio files.
     """
 
-    def __init__(self, audiopaths, hparams):
+    def __init__(self, audiopaths, hparams, all_in_mem: bool = False):
         self.audiopaths = load_filepaths_and_text(audiopaths)
         self.max_wav_value = hparams.data.max_wav_value
         self.sampling_rate = hparams.data.sampling_rate
@@ -37,6 +37,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         random.seed(1234)
         random.shuffle(self.audiopaths)
 
+        self.all_in_mem = all_in_mem
+        if self.all_in_mem:
+            self.cache = [self.get_audio(p[0]) for p in self.audiopaths]
+
     def get_audio(self, filename):
         filename = filename.replace("\\", "/")
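
Note that the list comprehension above calls get_audio for every entry at construction time, so the full (c, f0, spec, audio_norm, spk, uv) tuple for every file ends up resident in RAM. A rough, hypothetical helper (not part of this PR) for checking whether that will fit, by measuring a few samples and extrapolating:

    import torch

    def estimate_cache_bytes(dataset, sample_count=50):
        # Hypothetical helper, not repository code: measure a handful of items
        # returned by get_audio() and extrapolate to the full file list.
        n = min(sample_count, len(dataset.audiopaths))
        measured = 0
        for entry in dataset.audiopaths[:n]:
            items = dataset.get_audio(entry[0])
            measured += sum(t.element_size() * t.nelement()
                            for t in items if torch.is_tensor(t))
        return measured / max(n, 1) * len(dataset.audiopaths)

    # Usage (assuming a TextAudioSpeakerLoader built with all_in_mem=False):
    # print(f"estimated cache size: {estimate_cache_bytes(train_dataset) / 1e9:.1f} GB")
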
@@ -75,6 +79,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
         spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
         audio_norm = audio_norm[:, :lmin * self.hop_length]
+
+        return c, f0, spec, audio_norm, spk, uv
+
+    def random_slice(self, c, f0, spec, audio_norm, spk, uv):
         # if spec.shape[1] < 30:
         #     print("skip too short audio:", filename)
         #     return None
@@ -87,7 +95,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         return c, f0, spec, audio_norm, spk, uv
 
     def __getitem__(self, index):
-        return self.get_audio(self.audiopaths[index][0])
+        if self.all_in_mem:
+            return self.random_slice(*self.cache[index])
+        else:
+            return self.random_slice(*self.get_audio(self.audiopaths[index][0]))
 
     def __len__(self):
         return len(self.audiopaths)
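
A design note on the branch above: the cache holds full-length features, and random_slice is still applied on every access, so (assuming it cuts a random window per call, as its name suggests) each epoch keeps seeing different crops; the cache only removes the disk read. A toy illustration with made-up shapes:

    import torch

    # Made-up shapes, purely illustrative of slicing one cached full-length feature.
    full_spec = torch.randn(80, 1200)   # cached, full-length spectrogram
    crop_len = 300

    for _ in range(3):
        start = torch.randint(0, full_spec.shape[1] - crop_len, (1,)).item()
        crop = full_spec[:, start:start + crop_len]
        print(start, tuple(crop.shape))  # a different window on each access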


@@ -67,12 +67,15 @@ def run(rank, n_gpus, hps):
     torch.manual_seed(hps.train.seed)
     torch.cuda.set_device(rank)
     collate_fn = TextAudioCollate()
-    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
+    all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training.
+    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem)
     num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
+    if all_in_mem:
+        num_workers = 0
     train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
                               batch_size=hps.train.batch_size, collate_fn=collate_fn)
     if rank == 0:
-        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
+        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps, all_in_mem=all_in_mem)
         eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False,
                                  batch_size=1, pin_memory=False,
                                  drop_last=False, collate_fn=collate_fn)
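
The num_workers = 0 override is deliberate: once the dataset is cached in the parent process, DataLoader worker processes add no I/O benefit and can end up carrying (or, under spawn, rebuilding) their own copy of the cached tensors. A small self-contained toy example of the same pattern, not repo code:

    import torch
    from torch.utils.data import Dataset, DataLoader

    class CachedToyDataset(Dataset):
        # Toy stand-in for TextAudioSpeakerLoader(all_in_mem=True): everything
        # is materialized up front and __getitem__ just indexes the cache.
        def __init__(self, n_items=64, feat_dim=256, frames=400):
            self.cache = [torch.randn(feat_dim, frames) for _ in range(n_items)]

        def __len__(self):
            return len(self.cache)

        def __getitem__(self, idx):
            return self.cache[idx]

    ds = CachedToyDataset()
    # num_workers=0 keeps batching in the main process, reading the existing
    # cache directly instead of duplicating it across worker processes.
    loader = DataLoader(ds, batch_size=8, num_workers=0)
    print(next(iter(loader)).shape)  # torch.Size([8, 256, 400])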