2023-05-16 05:17:51 +00:00
|
|
|
'''
|
|
|
|
author: wayn391@mastertones
|
|
|
|
'''
|
|
|
|
|
2023-06-26 06:57:53 +00:00
|
|
|
import datetime
|
2023-05-16 05:17:51 +00:00
|
|
|
import os
|
|
|
|
import time
|
2023-06-26 06:57:53 +00:00
|
|
|
|
2023-05-16 05:17:51 +00:00
|
|
|
import matplotlib.pyplot as plt
|
2023-06-26 06:57:53 +00:00
|
|
|
import torch
|
|
|
|
import yaml
|
2023-05-16 05:17:51 +00:00
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
2023-06-26 06:57:53 +00:00
|
|
|
|
2023-05-16 05:17:51 +00:00
|
|
|
class Saver(object):
    '''Experiment logger and checkpoint manager.

    Keeps everything belonging to one run inside ``args.env.expdir``: a
    plain-text log file, a TensorBoard event directory, a dump of the run
    configuration and the model checkpoints. It also tracks the global step
    and wall-clock timing of the run.
    '''

    def __init__(
            self,
            args,
            initial_global_step=-1):
        '''Create the experiment directory and the logging backends.

        Args:
            args: run configuration. ``args.env.expdir`` and
                ``args.data.sampling_rate`` are read, and ``dict(args)`` is
                written to ``config.yaml`` inside the experiment directory.
            initial_global_step: starting value of the step counter.
        '''
        self.expdir = args.env.expdir
        self.sample_rate = args.data.sampling_rate

        # cold start
        self.global_step = initial_global_step
        now = time.time()
        self.init_time = now
        self.last_time = now

        # makedirs (checkpoints are stored directly in this directory too)
        os.makedirs(self.expdir, exist_ok=True)

        # path
        self.path_log_info = os.path.join(self.expdir, 'log_info.txt')

        # writer
        self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))

        # save config
        path_config = os.path.join(self.expdir, 'config.yaml')
        with open(path_config, "w") as out_config:
            yaml.dump(dict(args), out_config)

    def _checkpoint_path(self, name, postfix):
        '''Return ``<expdir>/<name>[_<postfix>].pt``.'''
        if postfix:
            postfix = '_' + postfix
        return os.path.join(self.expdir, name + postfix + '.pt')

    def log_info(self, msg):
        '''Print a message and append it to the text log.

        Args:
            msg: either a dict, rendered as one ``key: value`` line per item
                (integers get thousands separators), or any other object,
                which is converted with ``str()``.
        '''
        if isinstance(msg, dict):
            msg_list = []
            for k, v in msg.items():
                # bool is a subclass of int; keep True/False readable
                # instead of letting '{:,}' turn them into 1/0.
                if isinstance(v, int) and not isinstance(v, bool):
                    msg_list.append('{}: {:,}'.format(k, v))
                else:
                    msg_list.append('{}: {}'.format(k, v))
            msg_str = '\n'.join(msg_list)
        else:
            # str() so that numbers, lists, exceptions, ... can be logged
            # without failing on the string concatenation below.
            msg_str = str(msg)

        # display
        print(msg_str)

        # save
        with open(self.path_log_info, 'a', encoding='utf-8') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, dict):
        '''Write every ``tag: scalar`` pair to TensorBoard at the current step.

        The parameter name shadows the builtin ``dict``; it is kept so that
        existing keyword callers keep working.
        '''
        for k, v in dict.items():
            self.writer.add_scalar(k, v, self.global_step)

    def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
        '''Plot |spec_out - spec|, spec and spec_out side by side.

        The three panels are concatenated along the last axis and only the
        first item along the leading axis is drawn.

        Args:
            name: TensorBoard tag of the figure.
            spec: reference tensor.
                NOTE(review): presumably (batch, frames, bins) - confirm.
            spec_out: predicted tensor, same shape as ``spec``.
            vmin: lower colour limit; also added to the error panel so a
                zero error maps to the bottom of the colour scale.
            vmax: upper colour limit.
        '''
        spec_cat = torch.cat(
            [(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        # detach(): .numpy() raises on tensors that still require grad.
        image = spec_cat[0].detach().cpu().numpy()

        fig = plt.figure(figsize=(12, 9))
        ax = fig.add_subplot(1, 1, 1)
        ax.pcolor(image.T, vmin=vmin, vmax=vmax)
        fig.tight_layout()
        self.writer.add_figure(name, fig, self.global_step)

    def log_audio(self, dict):
        '''Write every ``tag: waveform`` pair to TensorBoard at the current step.

        The parameter name shadows the builtin ``dict``; it is kept so that
        existing keyword callers keep working.
        '''
        for k, v in dict.items():
            self.writer.add_audio(
                k, v,
                global_step=self.global_step,
                sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        '''Return the seconds elapsed since the last interval mark.

        Args:
            update: when true, move the interval mark to now.
        '''
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    def get_total_time(self, to_str=True):
        '''Return the time elapsed since this object was created.

        Args:
            to_str: when true, return ``[D day(s), ]H:MM:SS.f`` with a single
                fractional digit; otherwise return the seconds as a float.
        '''
        total_time = time.time() - self.init_time
        if not to_str:
            return total_time

        text = str(datetime.timedelta(seconds=total_time))
        if '.' in text:
            # 'H:MM:SS.ffffff' -> keep one fractional digit.
            return text[:-5]
        # timedelta omits the fraction when microseconds == 0; slicing
        # blindly would cut into the seconds field.
        return text + '.0'

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        '''Save a checkpoint to ``<expdir>/<name>[_<postfix>].pt``.

        Args:
            model: object exposing ``state_dict()``.
            optimizer: object exposing ``state_dict()``, or ``None`` to store
                the model only.
            name: checkpoint base name.
            postfix: optional suffix, joined to ``name`` with ``_``.
            to_json: currently unused; kept for interface compatibility.
        '''
        # path
        path_pt = self._checkpoint_path(name, postfix)

        # save
        state = {
            'global_step': self.global_step,
            'model': model.state_dict(),
        }
        if optimizer is not None:
            state['optimizer'] = optimizer.state_dict()
        torch.save(state, path_pt)

        # report only once the file has actually been written
        print(' [*] model checkpoint saved: {}'.format(path_pt))

    def delete_model(self, name='model', postfix=''):
        '''Delete ``<expdir>/<name>[_<postfix>].pt`` if it exists.'''
        # path
        path_pt = self._checkpoint_path(name, postfix)

        # delete
        if os.path.exists(path_pt):
            os.remove(path_pt)
            print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        '''Advance the global step counter by one.'''
        self.global_step += 1
|
|
|
|
|
|
|
|
|