commit 76e06158fd
parent ddc594a8e1
Author: ylzz1997
Date:   2023-05-16 13:17:51 +08:00

12 changed files with 435 additions and 32 deletions

configs/diffusion.yaml Normal file
@@ -0,0 +1,48 @@
data:
  sampling_rate: 44100
  block_size: 512  # Must equal hop_length
  duration: 2  # Audio duration (seconds) during training; must be shorter than the shortest audio clip
  encoder: 'vec768l12'  # 'hubertsoft', 'vec256l9', 'vec768l12'
  cnhubertsoft_gate: 10
  encoder_sample_rate: 16000
  encoder_hop_size: 320
  encoder_out_channels: 768  # 256 if using 'hubertsoft'
  train_path: dataset/44k  # Create a folder named "audio" under this path and put the audio clips in it
  filelists_path: filelists/  # Filelists path
  extensions:  # List of extensions included in the data collection
    - wav
model:
  type: 'Diffusion'
  n_layers: 20
  n_chans: 512
  n_hidden: 256
  use_pitch_aug: true
  n_spk: 1  # Maximum number of different speakers
device: cuda
vocoder:
  type: 'nsf-hifigan'
  ckpt: 'pretrain/nsf_hifigan/model'
infer:
  speedup: 10
  method: 'dpm-solver'  # 'pndm' or 'dpm-solver'
env:
  expdir: exp/diffusion-test
  gpu_id: 0
train:
  num_workers: 2  # If your CPU and GPU are both strong, setting this to 0 may be faster!
  amp_dtype: fp32  # fp32, fp16 or bf16 (fp16 or bf16 may be faster if supported by your GPU)
  batch_size: 48
  cache_all_data: true  # Saves RAM/VRAM if false, but may be slower
  cache_device: 'cpu'  # Set to 'cuda' to cache the data in VRAM; fastest with a strong GPU
  cache_fp16: true
  epochs: 100000
  interval_log: 10
  interval_val: 2000
  interval_force_save: 10000
  lr: 0.0002
  decay_step: 100000
  gamma: 0.5
  weight_decay: 0
  save_opt: false
spk:
  'nyaru': 0
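
This file is read back with the DotDict-based loader added in diffusion/logger/utils.py below; a minimal sketch of how the sections above are accessed (paths are the commit defaults):

import diffusion.logger.utils as du

dconfig = du.load_config("configs/diffusion.yaml")
print(dconfig.data.sampling_rate)  # 44100
print(dconfig.vocoder.ckpt)        # 'pretrain/nsf_hifigan/model'
print(dconfig.train.batch_size)    # 48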

diffusion/data_loaders.py
@@ -51,7 +51,7 @@ def traverse_dir(
 def get_data_loaders(args, whole_audio=False):
     data_train = AudioDataset(
-        args.data.train_path,
+        filelists_path = args.filelists_path,
         waveform_sec=args.data.duration,
         hop_size=args.data.block_size,
         sample_rate=args.data.sampling_rate,
@@ -71,7 +71,7 @@ def get_data_loaders(args, whole_audio=False):
         pin_memory=True if args.train.cache_device=='cpu' else False
     )
     data_valid = AudioDataset(
-        args.data.valid_path,
+        filelists_path = args.filelists_path,
         waveform_sec=args.data.duration,
         hop_size=args.data.block_size,
         sample_rate=args.data.sampling_rate,
@@ -92,7 +92,7 @@ def get_data_loaders(args, whole_audio=False):
 class AudioDataset(Dataset):
     def __init__(
         self,
-        path_root,
+        filelists,
         waveform_sec,
         hop_size,
         sample_rate,
@@ -109,7 +109,7 @@ class AudioDataset(Dataset):
         self.waveform_sec = waveform_sec
         self.sample_rate = sample_rate
         self.hop_size = hop_size
-        self.path_root = path_root
+        self.filelists = filelists
         self.paths = traverse_dir(
             os.path.join(path_root, 'audio'),
             extensions=extensions,

diffusion/logger/saver.py Normal file

@@ -0,0 +1,150 @@
'''
author: wayn391@mastertones
'''
import os
import json
import time
import yaml
import datetime
import torch
import matplotlib.pyplot as plt
from . import utils
from torch.utils.tensorboard import SummaryWriter


class Saver(object):
    def __init__(
            self,
            args,
            initial_global_step=-1):

        self.expdir = args.env.expdir
        self.sample_rate = args.data.sampling_rate

        # cold start
        self.global_step = initial_global_step
        self.init_time = time.time()
        self.last_time = time.time()

        # makedirs
        os.makedirs(self.expdir, exist_ok=True)

        # path
        self.path_log_info = os.path.join(self.expdir, 'log_info.txt')

        # writer
        self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))

        # save config
        path_config = os.path.join(self.expdir, 'config.yaml')
        with open(path_config, "w") as out_config:
            yaml.dump(dict(args), out_config)

    def log_info(self, msg):
        '''log method'''
        if isinstance(msg, dict):
            msg_list = []
            for k, v in msg.items():
                if isinstance(v, int):
                    tmp_str = '{}: {:,}'.format(k, v)
                else:
                    tmp_str = '{}: {}'.format(k, v)
                msg_list.append(tmp_str)
            msg_str = '\n'.join(msg_list)
        else:
            msg_str = msg

        # display
        print(msg_str)

        # save
        with open(self.path_log_info, 'a') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, dict):
        for k, v in dict.items():
            self.writer.add_scalar(k, v, self.global_step)

    def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
        spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        spec = spec_cat[0]
        if isinstance(spec, torch.Tensor):
            spec = spec.cpu().numpy()
        fig = plt.figure(figsize=(12, 9))
        plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
        plt.tight_layout()
        self.writer.add_figure(name, fig, self.global_step)

    def log_audio(self, dict):
        for k, v in dict.items():
            self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    def get_total_time(self, to_str=True):
        total_time = time.time() - self.init_time
        if to_str:
            total_time = str(datetime.timedelta(seconds=total_time))[:-5]
        return total_time

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        # path
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(self.expdir, name + postfix + '.pt')

        # check
        print(' [*] model checkpoint saved: {}'.format(path_pt))

        # save
        if optimizer is not None:
            torch.save({
                'global_step': self.global_step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, path_pt)
        else:
            torch.save({
                'global_step': self.global_step,
                'model': model.state_dict()}, path_pt)

        # to json
        if to_json:
            path_json = os.path.join(self.expdir, name + '.json')
            utils.to_json(path_pt, path_json)

    def delete_model(self, name='model', postfix=''):
        # path
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(self.expdir, name + postfix + '.pt')

        # delete
        if os.path.exists(path_pt):
            os.remove(path_pt)
            print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        self.global_step += 1
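
A rough usage sketch of Saver, assuming args is the DotDict loaded from configs/diffusion.yaml and train_step/loader/model/optimizer are hypothetical placeholders:

saver = Saver(args, initial_global_step=0)
for batch in loader:                      # placeholder data loader
    loss = train_step(batch)              # hypothetical training step
    saver.global_step_increment()
    if saver.global_step % args.train.interval_log == 0:
        saver.log_info({'step': saver.global_step, 'loss': loss})
        saver.log_value({'train/loss': loss})
    if saver.global_step % args.train.interval_val == 0:
        saver.save_model(model, optimizer, postfix=str(saver.global_step))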

diffusion/logger/utils.py Normal file

@@ -0,0 +1,126 @@
import os
import yaml
import json
import pickle
import torch


def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):

    file_list = []
    cnt = 0
    for root, _, files in os.walk(root_dir):
        for file in files:
            if any([file.endswith(f".{ext}") for ext in extensions]):
                # path
                mix_path = os.path.join(root, file)
                pure_path = mix_path[len(root_dir) + 1:] if is_pure else mix_path

                # amount
                if (amount is not None) and (cnt == amount):
                    if is_sort:
                        file_list.sort()
                    return file_list

                # check string
                if (str_include is not None) and (str_include not in pure_path):
                    continue
                if (str_exclude is not None) and (str_exclude in pure_path):
                    continue

                if not is_ext:
                    ext = pure_path.split('.')[-1]
                    pure_path = pure_path[:-(len(ext) + 1)]
                file_list.append(pure_path)
                cnt += 1
    if is_sort:
        file_list.sort()
    return file_list


class DotDict(dict):
    # Note: nested plain dicts are re-wrapped in a fresh DotDict on every
    # attribute read (a shallow copy), so nested attribute *writes* do not
    # propagate back to the original dict.
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def get_network_paras_amount(model_dict):
    info = dict()
    for model_name, model in model_dict.items():
        # all_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        info[model_name] = trainable_params
    return info


def load_config(path_config):
    with open(path_config, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)
    # print(args)
    return args


def save_config(path_config, config):
    config = dict(config)
    with open(path_config, "w") as f:
        yaml.dump(config, f)


def to_json(path_params, path_json):
    params = torch.load(path_params, map_location=torch.device('cpu'))
    raw_state_dict = {}
    for k, v in params.items():
        val = v.flatten().numpy().tolist()
        raw_state_dict[k] = val
    with open(path_json, 'w') as outfile:
        json.dump(raw_state_dict, outfile, indent="\t")


def convert_tensor_to_numpy(tensor, is_squeeze=True):
    if is_squeeze:
        tensor = tensor.squeeze()
    if tensor.requires_grad:
        tensor = tensor.detach()
    if tensor.is_cuda:
        tensor = tensor.cpu()
    return tensor.numpy()


def load_model(
        expdir,
        model,
        optimizer,
        name='model',
        postfix='',
        device='cpu'):
    if postfix == '':
        postfix = '_' + postfix  # match checkpoints named '<name>_<step>.pt'
    path = os.path.join(expdir, name + postfix)
    path_pt = traverse_dir(expdir, ['pt'], is_ext=False)
    global_step = 0
    if len(path_pt) > 0:
        steps = [s[len(path):] for s in path_pt]
        maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
        if maxstep >= 0:
            path_pt = path + str(maxstep) + '.pt'
        else:
            path_pt = path + 'best.pt'
        print(' [*] restoring model from', path_pt)
        ckpt = torch.load(path_pt, map_location=torch.device(device))
        global_step = ckpt['global_step']
        model.load_state_dict(ckpt['model'], strict=False)
        if ckpt.get('optimizer') is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
    return global_step, model, optimizer
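
Two behaviours here are worth a quick sketch: traverse_dir's extension-stripped listing is what load_model's step parsing relies on, and DotDict attribute access returns shallow copies for nested dicts:

# Checkpoint discovery as load_model uses it ('exp/diffusion-test' from the config):
ckpts = traverse_dir('exp/diffusion-test', ['pt'], is_ext=False, is_sort=True)
# e.g. ['exp/diffusion-test/model_2000', 'exp/diffusion-test/model_4000']

# DotDict: nested reads are convenient, nested attribute writes are lost.
cfg = DotDict({'train': {'lr': 2e-4}})
print(cfg.train.lr)     # 0.0002
cfg.train.lr = 1e-4     # mutates a shallow copy only
print(cfg.train.lr)     # still 0.0002; use cfg['train']['lr'] = 1e-4 instead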

diffusion/vocoder.py
@@ -1,6 +1,6 @@
 import torch
-from nsf_hifigan.nvSTFT import STFT
-from nsf_hifigan.models import load_model
+from vdecoder.nsf_hifigan.nvSTFT import STFT
+from vdecoder.nsf_hifigan.models import load_model, load_config
 from torchaudio.transforms import Resample
@@ -31,7 +31,7 @@ class Vocoder:
         key_str = str(sample_rate)
         if key_str not in self.resample_kernel:
             self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
         audio_res = self.resample_kernel[key_str](audio)
         # extract
         mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
@@ -49,8 +49,9 @@ class NsfHifiGAN(torch.nn.Module):
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = device
-        print('| Load HifiGAN: ', model_path)
-        self.model, self.h = load_model(model_path, device=self.device)
+        self.model_path = model_path
+        self.model = None
+        self.h = load_config(model_path)
         self.stft = STFT(
             self.h.sampling_rate,
             self.h.num_mels,
@@ -74,6 +75,9 @@ class NsfHifiGAN(torch.nn.Module):
         return mel

     def forward(self, mel, f0):
+        if self.model is None:
+            print('| Load HifiGAN: ', self.model_path)
+            self.model, self.h = load_model(self.model_path, device=self.device)
         with torch.no_grad():
             c = mel.transpose(1, 2)
             audio = self.model(c, f0)
@@ -81,6 +85,9 @@ class NsfHifiGAN(torch.nn.Module):
 class NsfHifiGANLog10(NsfHifiGAN):
     def forward(self, mel, f0):
+        if self.model is None:
+            print('| Load HifiGAN: ', self.model_path)
+            self.model, self.h = load_model(self.model_path, device=self.device)
         with torch.no_grad():
             c = 0.434294 * mel.transpose(1, 2)
             audio = self.model(c, f0)
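
The effect of this change: constructing the vocoder now only parses config.json, and the HiFi-GAN weights load lazily on the first forward pass, so mel extraction (STFT-based) never touches the generator. A minimal sketch, with call signatures inferred from the calls visible in this commit:

voc = NsfHifiGAN('pretrain/nsf_hifigan/model', device='cpu')  # config only, no weights
mel = voc.extract(audio)   # STFT-based mel extraction; generator still unloaded
wav = voc(mel, f0)         # first forward prints '| Load HifiGAN:' and loads weights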

preprocess_flist_config.py
@@ -7,6 +7,8 @@ from random import shuffle
 import json
 import wave

+import diffusion.logger.utils as du
+
 config_template = json.load(open("configs_template/config_template.json"))

 pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
@@ -68,14 +70,25 @@ if __name__ == "__main__":
             wavpath = fname
             f.write(wavpath + "\n")

+    d_config_template = du.load_config("configs_template/diffusion_template.yaml")
+    d_config_template.model.n_spk = spk_id
+    d_config_template.data.encoder = args.speech_encoder
+    d_config_template.spk = spk_dict
+
     config_template["spk"] = spk_dict
     config_template["model"]["n_speakers"] = spk_id
     config_template["model"]["speech_encoder"] = args.speech_encoder
     if args.speech_encoder == "vec768l12":
         config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768
+        d_config_template.data.encoder_out_channels = 768
     elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft':
         config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 256
+        d_config_template.data.encoder_out_channels = 256

     print("Writing configs/config.json")
     with open("configs/config.json", "w") as f:
         json.dump(config_template, f, indent=2)
+    print("Writing configs/diffusion.yaml")
+    du.save_config("configs/diffusion.yaml", d_config_template)
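
One caveat, following from the DotDict sketch after diffusion/logger/utils.py above: nested attribute writes such as d_config_template.model.n_spk = spk_id land on a shallow copy, so they may never reach the dict that du.save_config serializes (the top-level d_config_template.spk write is fine). A hedged alternative, assuming the same template structure, is plain subscripting:

# Possible equivalents that mutate the original nested dicts directly:
d_config_template["model"]["n_spk"] = spk_id
d_config_template["data"]["encoder"] = args.speech_encoder
d_config_template["data"]["encoder_out_channels"] = 768  # or 256, per encoder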

preprocess_hubert_f0.py
@@ -3,6 +3,7 @@ import multiprocessing
 import os
 import argparse
 from random import shuffle
+import random
 import torch
 from glob import glob
@@ -12,19 +13,28 @@ import json
 import utils
 import logging
 logging.getLogger("numba").setLevel(logging.WARNING)
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
+
+import diffusion.logger.utils as du
+from diffusion.vocoder import Vocoder
 import librosa
 import numpy as np

 hps = utils.get_hparams_from_file("configs/config.json")
+dconfig = du.load_config("configs/diffusion.yaml")
 sampling_rate = hps.data.sampling_rate
 hop_length = hps.data.hop_length
 speech_encoder = hps["model"]["speech_encoder"]

-def process_one(filename, hmodel,f0p):
+def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
     # print(filename)
     wav, sr = librosa.load(filename, sr=sampling_rate)
+    audio_norm = torch.FloatTensor(wav)
+    audio_norm = audio_norm.unsqueeze(0)
     soft_path = filename + ".soft.pt"
     if not os.path.exists(soft_path):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -32,7 +42,7 @@ def process_one(filename, hmodel,f0p):
         wav16k = torch.from_numpy(wav16k).to(device)
         c = hmodel.encoder(wav16k)
         torch.save(c.cpu(), soft_path)

     f0_path = filename + ".f0.npy"
     if not os.path.exists(f0_path):
         f0_predictor = utils.get_f0_predictor(f0p,sampling_rate=sampling_rate, hop_length=hop_length,device=None,threshold=0.05)
@@ -40,24 +50,23 @@ def process_one(filename, hmodel,f0p):
             wav
         )
         np.save(f0_path, np.asanyarray((f0,uv),dtype=object))

     spec_path = filename.replace(".wav", ".spec.pt")
     if not os.path.exists(spec_path):
         # Process spectrogram
         # The following code can't be replaced by torch.FloatTensor(wav)
         # because load_wav_to_torch return a tensor that need to be normalized
-        audio, sr = utils.load_wav_to_torch(filename)
         if sr != hps.data.sampling_rate:
             raise ValueError(
                 "{} SR doesn't match target {} SR".format(
                     sr, hps.data.sampling_rate
                 )
             )
-        audio_norm = audio / hps.data.max_wav_value
-        audio_norm = audio_norm.unsqueeze(0)
+        #audio_norm = audio / hps.data.max_wav_value
         spec = spectrogram_torch(
             audio_norm,
             hps.data.filter_length,
@@ -69,14 +78,40 @@ def process_one(filename, hmodel,f0p):
         spec = torch.squeeze(spec, 0)
         torch.save(spec, spec_path)

+    if diff:
+        volume_path = filename + ".vol.npy"
+        volume_extractor = utils.Volume_Extractor(hop_length)
+        if not os.path.exists(volume_path):
+            volume = volume_extractor.extract(audio_norm)
+            np.save(volume_path, volume.to('cpu').numpy())
+
+        mel_path = filename + ".mel.npy"
+        if not os.path.exists(mel_path) and mel_extractor is not None:
+            mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate)
+            mel = mel_t.squeeze().to('cpu').numpy()
+            np.save(mel_path, mel)
+
+        aug_mel_path = filename + ".aug_mel.npy"
+        aug_vol_path = filename + ".aug_vol.npy"
+        max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5
+        max_shift = min(1, np.log10(1/max_amp))
+        log10_vol_shift = random.uniform(-1, max_shift)
+        keyshift = random.uniform(-5, 5)
+        if mel_extractor is not None:
+            aug_mel_t = mel_extractor.extract(audio_norm * (10 ** log10_vol_shift), sampling_rate, keyshift = keyshift)
+            aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
+            aug_vol = volume_extractor.extract(audio_norm * (10 ** log10_vol_shift))
+            if not os.path.exists(aug_mel_path):
+                np.save(aug_mel_path,np.asanyarray((aug_mel,keyshift),dtype=object))
+            if not os.path.exists(aug_vol_path):
+                np.save(aug_vol_path,aug_vol.to('cpu').numpy())

-def process_batch(filenames,f0p):
+def process_batch(filenames,f0p,diff=False,mel_extractor=None):
     print("Loading hubert for content...")
     device = "cuda" if torch.cuda.is_available() else "cpu"
     hmodel = utils.get_speech_encoder(speech_encoder,device=device)
     print("Loaded hubert.")

     for filename in tqdm(filenames):
-        process_one(filename, hmodel,f0p)
+        process_one(filename, hmodel,f0p,diff,mel_extractor)

 if __name__ == "__main__":
@@ -85,19 +120,27 @@ if __name__ == "__main__":
         "--in_dir", type=str, default="dataset/44k", help="path to input dir"
     )
     parser.add_argument(
-        '--f0_predictor', type=str, default="dio", help='Select F0 predictor, can select crepe,pm,dio,harvest; default dio (note: crepe applies a mean filter to the original F0)'
+        '--use_diff',action='store_true', help='Whether to use the diffusion model'
     )
     parser.add_argument(
-        '--ues_diff',action='store_true', help='Whether to use the diffusion model'
+        '--f0_predictor', type=str, default="dio", help='Select F0 predictor, can select crepe,pm,dio,harvest; default dio (note: crepe applies a mean filter to the original F0)'
     )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     args = parser.parse_args()

     f0p = args.f0_predictor
     print(speech_encoder)
     print(f0p)
+    if args.use_diff:
+        print("use_diff")
+        print("Loading Mel Extractor...")
+        mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = device)
+        print("Loaded Mel Extractor.")
+    else:
+        mel_extractor = None
     filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True)  # [:10]
     shuffle(filenames)
     multiprocessing.set_start_method("spawn", force=True)

     num_processes = 1
     chunk_size = int(math.ceil(len(filenames) / num_processes))
     chunks = [
@@ -105,7 +148,7 @@ if __name__ == "__main__":
     ]
     print([len(c) for c in chunks])
     processes = [
-        multiprocessing.Process(target=process_batch, args=(chunk,f0p)) for chunk in chunks
+        multiprocessing.Process(target=process_batch, args=(chunk,f0p,args.use_diff,mel_extractor)) for chunk in chunks
     ]
     for p in processes:
         p.start()
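
For orientation, a hypothetical check of the sidecar files this script now produces per clip when run with --use_diff (the path and filename below are illustrative only):

import os

wav = "dataset/44k/nyaru/foo.wav"  # illustrative path
for suffix in [".soft.pt", ".f0.npy", ".vol.npy", ".mel.npy", ".aug_mel.npy", ".aug_vol.npy"]:
    print(wav + suffix, "exists" if os.path.exists(wav + suffix) else "MISSING")
print(wav.replace(".wav", ".spec.pt"))  # the spectrogram swaps the extension instead of appending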

train_diff.py
@@ -2,7 +2,7 @@ import os
 import argparse
 import torch
 from torch.optim import lr_scheduler
-from logger import utils
+from diffusion.logger import utils
 from diffusion.data_loaders import get_data_loaders
 from diffusion.solver import train
 from diffusion.unit2mel import Unit2Mel

utils.py
@@ -410,3 +410,15 @@ class HParams():
     def __repr__(self):
         return self.__dict__.__repr__()

+class Volume_Extractor:
+    def __init__(self, hop_size = 512):
+        self.hop_size = hop_size
+
+    def extract(self, audio): # audio: 2d torch tensor (batch, n_samples)
+        n_frames = int(audio.size(-1) // self.hop_size) + 1
+        audio2 = audio ** 2
+        audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
+        volume = torch.FloatTensor([torch.mean(audio2[:, int(n * self.hop_size): int((n + 1) * self.hop_size)]) for n in range(n_frames)])
+        volume = torch.sqrt(volume)
+        return volume
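
Per frame of hop_size samples this computes sqrt(mean(audio²)), i.e. frame RMS, with half a hop of reflect padding on each side. A minimal sketch of calling it (assumes the repo root is on PYTHONPATH):

import torch
from utils import Volume_Extractor  # utils.py at the repo root, as patched above

audio = torch.randn(1, 44100)              # (1, T): one second at 44.1 kHz
volume = Volume_Extractor(512).extract(audio)
print(volume.shape)                        # torch.Size([87]) = 44100 // 512 + 1 frames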

vdecoder/nsf_hifigan/models.py
@@ -13,12 +13,7 @@ LRELU_SLOPE = 0.1

 def load_model(model_path, device='cuda'):
-    config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
-    with open(config_file) as f:
-        data = f.read()
-    json_config = json.loads(data)
-    h = AttrDict(json_config)
+    h = load_config(model_path)
     generator = Generator(h).to(device)
@@ -29,6 +24,15 @@ def load_model(model_path, device='cuda'):
     del cp_dict
     return generator, h

+def load_config(model_path):
+    config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
+    with open(config_file) as f:
+        data = f.read()
+    json_config = json.loads(data)
+    h = AttrDict(json_config)
+    return h
+
 class ResBlock1(torch.nn.Module):
     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
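
This refactor pulls config parsing out of load_model so callers can read vocoder hyperparameters without loading generator weights, which is what NsfHifiGAN.__init__ above now relies on. A small sketch (the checkpoint path is the commit default):

from vdecoder.nsf_hifigan.models import load_config, load_model

h = load_config('pretrain/nsf_hifigan/model')   # reads the sibling config.json
print(h.sampling_rate, h.num_mels)              # fields NsfHifiGAN feeds to its STFT
# generator, h = load_model('pretrain/nsf_hifigan/model', device='cpu')  # heavy: loads weights too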