201 lines
7.1 KiB
Python
201 lines
7.1 KiB
Python
import time
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import torch
|
|
from torch import autocast
|
|
from torch.cuda.amp import GradScaler
|
|
|
|
from diffusion.logger import utils
|
|
from diffusion.logger.saver import Saver
|
|
|
|
|
|
def test(args, model, vocoder, loader_test, saver):
    """Run the model over ``loader_test``, log the results and return the mean loss.

    For every batch this function synthesises a mel spectrogram with
    ``model`` (``infer=True``), renders it to a waveform with
    ``vocoder.infer``, measures the real-time factor (RTF) of that pipeline,
    accumulates the loss returned by ``model`` (``infer=False``) and passes
    the spectrograms and audio to ``saver``.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        args: Configuration object. Reads ``device``, ``infer.speedup``,
            ``infer.method``, ``data.sampling_rate`` and ``train.batch_size``.
        model: Callable model exposing ``eval()``, ``k_step_max`` and
            ``timesteps``.
        vocoder: Object exposing ``infer(mel, f0)``.
        loader_test: Sized iterable of batch dicts. The keys ``name``,
            ``name_ext``, ``units``, ``f0``, ``volume``, ``spk_id`` and
            ``mel`` are read; every key not starting with ``name`` is moved
            to ``args.device``.
        saver: Logger exposing ``log_spec`` and ``log_audio``.

    Returns:
        The accumulated loss divided by ``args.train.batch_size`` and by the
        number of batches, or NaN when ``loader_test`` is empty.
    """
    print(' [*] testing...')
    model.eval()

    # initialization
    num_batches = len(loader_test)
    if num_batches == 0:
        # Without this guard the final division raises a bare
        # ZeroDivisionError; NaN makes "nothing was evaluated" explicit.
        print(' [!] test loader is empty, nothing to evaluate')
        return float('nan')

    test_loss = 0.
    rtf_all = []

    # run
    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            # The last two components of the first sample name give the
            # file name and the speaker used in log entries.
            name_parts = data['name'][0].split("/")
            fn = name_parts[-1]
            speaker = name_parts[-2]
            print('--------')
            print('{}/{} - {}'.format(bidx, num_batches, fn))

            # unpack data: everything except the name fields goes to the device
            for k in data:
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)
            print('>>', data['name'][0])

            # forward
            st_time = time.perf_counter()
            mel = model(
                data['units'],
                data['f0'],
                data['volume'],
                data['spk_id'],
                # The ground-truth mel is only passed when k_step_max differs
                # from the full number of timesteps.
                gt_spec=None if model.k_step_max == model.timesteps else data['mel'],
                infer=True,
                infer_speedup=args.infer.speedup,
                method=args.infer.method,
                k_step=model.k_step_max
            )
            signal = vocoder.infer(mel, data['f0'])
            if signal.is_cuda:
                # CUDA kernels are launched asynchronously; wait for them so
                # the timer covers the actual computation.
                torch.cuda.synchronize()
            ed_time = time.perf_counter()

            # RTF: processing time divided by the duration of the produced audio
            run_time = ed_time - st_time
            song_time = signal.shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
            rtf_all.append(rtf)

            # loss: evaluated train.batch_size times per batch and averaged
            # below (presumably because each evaluation is stochastic -
            # TODO confirm against the model implementation)
            for _ in range(args.train.batch_size):
                loss = model(
                    data['units'],
                    data['f0'],
                    data['volume'],
                    data['spk_id'],
                    gt_spec=data['mel'],
                    infer=False,
                    k_step=model.k_step_max)
                test_loss += loss.item()

            # log mel
            saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel)

            # log audio: reload the reference recording and match the
            # dtype/device of the predicted signal
            path_audio = data['name_ext'][0]
            audio, _ = librosa.load(path_audio, sr=args.data.sampling_rate)
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)
            audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
            saver.log_audio({f"{speaker}_{fn}_gt.wav": audio, f"{speaker}_{fn}_pred.wav": signal})

    # report
    test_loss /= args.train.batch_size
    test_loss /= num_batches

    # check
    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss
|
|
|
|
|
|
def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
    """Train ``model``, logging periodically and validating/checkpointing at intervals.

    Args:
        args: Configuration object. Reads ``device``, ``env.expdir``,
            ``train.amp_dtype`` (``'fp32'``, ``'fp16'`` or ``'bf16'``),
            ``train.epochs``, ``train.interval_log``, ``train.interval_val``,
            ``train.interval_force_save`` and ``train.save_opt``.
        initial_global_step: Step counter value handed to ``Saver``.
        model: Callable model returning the loss when called with
            ``infer=False``; must expose ``k_step_max`` and ``train()``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        vocoder: Forwarded to ``test`` during validation.
        loader_train: Sized iterable of training batch dicts. The keys
            ``units``, ``f0``, ``volume``, ``spk_id``, ``aug_shift`` and
            ``mel`` are read; every key not starting with ``name`` is moved
            to ``args.device``.
        loader_test: Forwarded to ``test`` during validation.

    Raises:
        ValueError: If ``args.train.amp_dtype`` is not recognised, or if the
            loss becomes NaN.
    """
    # saver
    saver = Saver(args, initial_global_step=initial_global_step)

    # model size
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    # run
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    scaler = GradScaler()

    amp_dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    dtype = amp_dtypes.get(args.train.amp_dtype)
    if dtype is None:
        raise ValueError(f' [x] Unknown amp_dtype: {args.train.amp_dtype}')
    use_amp = dtype != torch.float32
    # torch.autocast expects a bare device type ('cuda', 'cpu'), not an
    # indexed device such as 'cuda:0', so strip any index here.
    autocast_device = torch.device(args.device).type if use_amp else None

    saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # unpack data: everything except the name fields goes to the device
            for k in data:
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)

            # forward
            if not use_amp:
                loss = model(
                    data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                    aug_shift=data['aug_shift'], gt_spec=data['mel'].float(),
                    infer=False, k_step=model.k_step_max)
            else:
                with autocast(device_type=autocast_device, dtype=dtype):
                    loss = model(
                        data['units'], data['f0'], data['volume'], data['spk_id'],
                        aug_shift=data['aug_shift'], gt_spec=data['mel'],
                        infer=False, k_step=model.k_step_max)

            # handle nan loss: abort before it reaches the optimizer
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')

            # backpropagate
            if not use_amp:
                loss.backward()
                optimizer.step()
            else:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            scheduler.step()

            # log loss
            if saver.global_step % args.train.interval_log == 0:
                current_lr = optimizer.param_groups[0]['lr']
                loss_value = loss.item()
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log / saver.get_interval_time(),
                        current_lr,
                        loss_value,
                        saver.get_total_time(),
                        saver.global_step
                    )
                )

                saver.log_value({
                    'train/loss': loss_value
                })

                saver.log_value({
                    'train/lr': current_lr
                })

            # validation
            if saver.global_step % args.train.interval_val == 0:
                optimizer_save = optimizer if args.train.save_opt else None

                # save latest
                saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
                # Drop the checkpoint from the previous validation step unless
                # that step is a multiple of interval_force_save.
                last_val_step = saver.global_step - args.train.interval_val
                if last_val_step % args.train.interval_force_save != 0:
                    saver.delete_model(postfix=f'{last_val_step}')

                # run testing set
                test_loss = test(args, model, vocoder, loader_test, saver)

                # log loss
                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )

                saver.log_value({
                    'validation/loss': test_loss
                })

                # test() leaves the model in eval mode; restore training mode
                model.train()
|
|
|
|
|