diff --git a/train.py b/train.py
index e242431..a6e219f 100644
--- a/train.py
+++ b/train.py
@@ -281,12 +281,13 @@ def evaluate(hps, generator, eval_loader, writer_eval):
     audio_dict = {}
     with torch.no_grad():
         for batch_idx, items in enumerate(eval_loader):
-            c, f0, spec, y, spk, _, uv = items
+            c, f0, spec, y, spk, _, uv, volume = items
             g = spk[:1].cuda(0)
             spec, y = spec[:1].cuda(0), y[:1].cuda(0)
             c = c[:1].cuda(0)
             f0 = f0[:1].cuda(0)
             uv= uv[:1].cuda(0)
+            volume = volume[:1].cuda(0)
             mel = spec_to_mel_torch(
                 spec,
                 hps.data.filter_length,
@@ -294,7 +295,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                 hps.data.sampling_rate,
                 hps.data.mel_fmin,
                 hps.data.mel_fmax)
-            y_hat,_ = generator.module.infer(c, f0, uv, g=g)
+            y_hat, _ = generator.module.infer(c, f0, uv, g=g, vol=volume)
             y_hat_mel = mel_spectrogram_torch(
                 y_hat.squeeze(1).float(),