140 lines
5.0 KiB
Python
140 lines
5.0 KiB
Python
import glob
|
|
import importlib
|
|
import os
|
|
|
|
import matplotlib
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributions
|
|
import torch.optim
|
|
import torch.optim
|
|
import torch.utils.data
|
|
|
|
from preprocessing.process_pipeline import File2Batch
|
|
from utils.hparams import hparams
|
|
from utils.indexed_datasets import IndexedDataset
|
|
from utils.pitch_utils import norm_interp_f0
|
|
|
|
matplotlib.use('Agg')
|
|
|
|
|
|
class SvcDataset(torch.utils.data.Dataset):
    """Dataset over binarized SVC samples stored as an IndexedDataset.

    Reads the per-split length array (``{prefix}_lengths.npy``) and, when
    present, the global f0 statistics from ``hparams['binary_data_dir']``.
    For the ``test`` split it can instead binarize raw wav/mp3 inputs from
    ``hparams['test_input_dir']``, or restrict itself to a subset of ids.
    """

    def __init__(self, prefix, shuffle=False):
        """
        :param prefix: split name ('train'/'valid'/'test'); selects the
            ``{prefix}_lengths.npy`` and ``{prefix}`` indexed-dataset files.
        :param shuffle: if True, ``ordered_indices`` returns a random
            permutation instead of sequential order.
        """
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        # Opened lazily in _get_item so the dataset stays picklable for
        # DataLoader workers.
        self.indexed_ds = None
        # Always defined (original set it only conditionally and relied on
        # hasattr checks elsewhere); populated only for the 'test' split below.
        self.avail_idxs = None

        # Global pitch statistics computed at binarization time; mirrored into
        # hparams so downstream modules can read them.
        f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
        if os.path.exists(f0_stats_fn):
            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn)
            hparams['f0_mean'] = float(hparams['f0_mean'])
            hparams['f0_std'] = float(hparams['f0_std'])
        else:
            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None

        if prefix == 'test':
            if hparams['test_input_dir'] != '':
                # Replace the binarized data with raw inputs from disk.
                self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            elif hparams['num_test_samples'] > 0:
                # Restrict the test split to the first N samples plus any
                # explicitly requested ids.
                self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
                self.sizes = [self.sizes[i] for i in self.avail_idxs]
|
|
|
|
@property
|
|
def _sizes(self):
|
|
return self.sizes
|
|
|
|
def _get_item(self, index):
|
|
if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
|
|
index = self.avail_idxs[index]
|
|
if self.indexed_ds is None:
|
|
self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
|
|
return self.indexed_ds[index]
|
|
|
|
def __getitem__(self, index):
|
|
item = self._get_item(index)
|
|
max_frames = hparams['max_frames']
|
|
spec = torch.Tensor(item['mel'])[:max_frames]
|
|
# energy = (spec.exp() ** 2).sum(-1).sqrt()
|
|
mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
|
|
f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
|
|
hubert = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
|
|
pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
|
|
sample = {
|
|
"id": index,
|
|
"item_name": item['item_name'],
|
|
"hubert": hubert,
|
|
"mel": spec,
|
|
"pitch": pitch,
|
|
"f0": f0,
|
|
"uv": uv,
|
|
"mel2ph": mel2ph,
|
|
"mel_nonpadding": spec.abs().sum(-1) > 0,
|
|
}
|
|
if hparams['use_energy_embed']:
|
|
sample['energy'] = item['energy']
|
|
if hparams['use_spk_id']:
|
|
sample["spk_id"] = item['spk_id']
|
|
return sample
|
|
|
|
@staticmethod
|
|
def collater(samples):
|
|
return File2Batch.processed_input2batch(samples)
|
|
|
|
@staticmethod
|
|
def load_test_inputs(test_input_dir):
|
|
inp_wav_paths = glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3')
|
|
sizes = []
|
|
items = []
|
|
|
|
binarizer_cls = hparams.get("binarizer_cls", 'basics.base_binarizer.BaseBinarizer')
|
|
pkg = ".".join(binarizer_cls.split(".")[:-1])
|
|
cls_name = binarizer_cls.split(".")[-1]
|
|
binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
|
|
from preprocessing.hubertinfer import HubertEncoder
|
|
for wav_fn in inp_wav_paths:
|
|
item_name = os.path.basename(wav_fn)
|
|
wav_fn = wav_fn
|
|
encoder = HubertEncoder(hparams['hubert_path'])
|
|
item = binarizer_cls.process_item(item_name, {'wav_fn': wav_fn}, encoder)
|
|
print(item)
|
|
items.append(item)
|
|
sizes.append(item['len'])
|
|
return items, sizes
|
|
|
|
def __len__(self):
|
|
return len(self._sizes)
|
|
|
|
def num_tokens(self, index):
|
|
return self.size(index)
|
|
|
|
def size(self, index):
|
|
"""Return an example's size as a float or tuple. This value is used when
|
|
filtering a dataset with ``--max-positions``."""
|
|
size = min(self._sizes[index], hparams['max_frames'])
|
|
return size
|
|
|
|
def ordered_indices(self):
|
|
"""Return an ordered list of indices. Batches will be constructed based
|
|
on this order."""
|
|
if self.shuffle:
|
|
indices = np.random.permutation(len(self))
|
|
if self.sort_by_len:
|
|
indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
|
|
# 先random, 然后稳定排序, 保证排序后同长度的数据顺序是依照random permutation的 (被其随机打乱).
|
|
else:
|
|
indices = np.arange(len(self))
|
|
return indices
|
|
|
|
@property
|
|
def num_workers(self):
|
|
return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
|