# diff-svc/training/base_task.py
# (source listing: 336 lines, 11 KiB, Python)

import logging
import os
import random
import shutil
import sys
import matplotlib
import numpy as np
import torch.distributed as dist
import torch.utils.data
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn
import utils
from utils.hparams import hparams, set_hparams
from utils.pl_utils import LatestModelCheckpoint, BaseTrainer, data_loader, DDP
# Use the non-interactive 'Agg' backend so figures can be rendered without a display.
matplotlib.use('Agg')
# Sharing strategy for torch.multiprocessing; override with the
# TORCH_SHARE_STRATEGY environment variable (defaults to 'file_system').
torch.multiprocessing.set_sharing_strategy(
    os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
# Root logger: INFO and above, written to stdout, each line prefixed by a timestamp.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
    level=logging.INFO,
    stream=sys.stdout,
)
class BaseTask(nn.Module):
    '''
    Base class for training tasks.

    Provided here:
    1. *load_ckpt*:
        load a checkpoint into one of the task's sub-modules;
    2. *training_step*:
        run one training step, then record and log the losses;
    3. *optimizer_step*:
        step the optimizer, clear gradients and advance the LR scheduler;
    4. *start*:
        load training configs, copy the speaker map into the work dir,
        log to tensorboard, start training (or testing);
    5. *configure_ddp* and *init_ddp_connection*:
        set up distributed (DDP) training.

    Subclasses should define:
    1. *build_model*, *build_optimizer*, *build_scheduler*:
        how to build the model, the optimizer and the learning-rate scheduler;
    2. *_training_step*:
        one training step of the model;
    3. *validation_step* and *_validation_end*:
        one validation step, and the aggregation of validation outputs into a
        dict of losses (it must contain a 'total_loss' entry).
    '''

    def __init__(self, *args, **kwargs):
        super(BaseTask, self).__init__(*args, **kwargs)
        # Training state / trainer wiring (presumably maintained by the trainer).
        self.current_epoch = 0
        self.global_step = 0
        self.loaded_optimizer_states_dict = {}
        self.trainer = None
        self.logger = None
        self.on_gpu = False
        self.use_dp = False
        self.use_ddp = False
        self.example_input_array = None

        # Dataset / batching configs. An eval limit of -1 means "same as the
        # training limit"; the resolved value is written back into hparams.
        self.max_tokens = hparams['max_tokens']
        self.max_sentences = hparams['max_sentences']
        self.max_eval_tokens = hparams['max_eval_tokens']
        if self.max_eval_tokens == -1:
            hparams['max_eval_tokens'] = self.max_eval_tokens = self.max_tokens
        self.max_eval_sentences = hparams['max_eval_sentences']
        if self.max_eval_sentences == -1:
            hparams['max_eval_sentences'] = self.max_eval_sentences = self.max_sentences

        self.model = None
        self.training_losses_meter = None

    ###########
    # Training, validation and testing
    ###########
    def build_model(self):
        raise NotImplementedError

    def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True):
        """Load a checkpoint into the attribute ``self.<current_model_name>``.

        :param ckpt_base_dir: checkpoint location, passed through to utils.load_ckpt
        :param current_model_name: attribute name of the target module; defaults to model_name
        :param model_name: fallback attribute name when current_model_name is None
        :param force: passed through to utils.load_ckpt
        :param strict: passed through to utils.load_ckpt
        :raises AttributeError: if the task has no attribute with that name
        """
        # This function is updated on 2021.12.13
        if current_model_name is None:
            current_model_name = model_name
        # getattr (not nn.Module.__getattr__) also finds plain, non-registered attributes.
        utils.load_ckpt(getattr(self, current_model_name), ckpt_base_dir, current_model_name, force, strict)

    def on_epoch_start(self):
        """Reset the per-epoch running averages of the training losses."""
        self.training_losses_meter = {'total_loss': utils.AvgrageMeter()}

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """
        :param sample: one batch from the training dataloader
        :param batch_idx: index of the batch within the epoch
        :param optimizer_idx: index of the optimizer being stepped
        :return: (total_loss: torch.Tensor, loss_log: dict), or None to skip the batch
        """
        raise NotImplementedError

    def training_step(self, sample, batch_idx, optimizer_idx=-1):
        """Run one training step and build the logging payload.

        :return: {'loss': None} when _training_step returns None; otherwise a
            dict with 'loss' (the total loss tensor), 'progress_bar' (scalar
            logs) and 'log' (the same logs with a 'tr/' prefix).
        """
        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
        self.opt_idx = optimizer_idx
        if loss_ret is None:
            return {'loss': None}
        total_loss, log_outputs = loss_ret
        log_outputs = utils.tensors_to_scalars(log_outputs)
        for k, v in log_outputs.items():
            if k not in self.training_losses_meter:
                self.training_losses_meter[k] = utils.AvgrageMeter()
            # NaN losses are logged but kept out of the epoch averages.
            if not np.isnan(v):
                self.training_losses_meter[k].update(v)
        self.training_losses_meter['total_loss'].update(total_loss.item())

        # Best-effort: log the current learning rate. Prefer get_last_lr()
        # (torch schedulers warn on get_lr() outside step() and may return a
        # not-yet-applied value); fall back to get_lr() for schedulers that
        # only provide that. A missing/None scheduler just means no 'lr' entry.
        scheduler = getattr(self, 'scheduler', None)
        lr = None
        for accessor in ('get_last_lr', 'get_lr'):
            try:
                lr = getattr(scheduler, accessor)()
            except Exception:
                continue
            break
        if isinstance(lr, list):
            lr = lr[0] if lr else None
        if lr is not None:
            log_outputs['lr'] = lr

        progress_bar_log = log_outputs
        tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
        return {
            'loss': total_loss,
            'progress_bar': progress_bar_log,
            'log': tb_log
        }

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Apply the optimizer update, clear gradients, then advance the scheduler."""
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            # The schedule is keyed on the number of optimizer updates
            # (global_step divided by the gradient-accumulation factor).
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    def on_epoch_end(self):
        """Print the epoch's averaged training losses (rounded to 4 decimals)."""
        loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
        print(f"\n==============\n "
              f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}"
              f"\n==============\n")

    def validation_step(self, sample, batch_idx):
        """
        :param sample: one batch from the validation dataloader
        :param batch_idx: index of the batch
        :return: output: dict
        """
        raise NotImplementedError

    def _validation_end(self, outputs):
        """
        :param outputs: list of validation_step outputs
        :return: loss_output: dict (must contain 'total_loss')
        """
        raise NotImplementedError

    def validation_end(self, outputs):
        """Aggregate validation outputs, print them and return them for logging.

        :return: {'log': {'val/<name>': value, ...}, 'val_loss': loss_output['total_loss']}
        """
        loss_output = self._validation_end(outputs)
        print(f"\n==============\n "
              f"valid results: {loss_output}"
              f"\n==============\n")
        return {
            'log': {f'val/{k}': v for k, v in loss_output.items()},
            'val_loss': loss_output['total_loss']
        }

    def build_scheduler(self, optimizer):
        raise NotImplementedError

    def build_optimizer(self, model):
        raise NotImplementedError

    def configure_optimizers(self):
        """Build the optimizer for self.model and store its scheduler on self.scheduler."""
        optm = self.build_optimizer(self.model)
        self.scheduler = self.build_scheduler(optm)
        return [optm]

    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        # Testing reuses the validation logic.
        return self.validation_step(sample, batch_idx)

    def test_end(self, outputs):
        return self.validation_end(outputs)

    ###########
    # Running configuration
    ###########
    @classmethod
    def start(cls):
        """Entry point: build the task and a trainer from hparams, then train or test.

        When hparams['infer'] is falsy, spk_map.json is copied from
        hparams['binary_data_dir'] into the work dir (if it is not already
        there) and training starts; otherwise trainer.test() is run.
        """
        set_hparams()
        # Random port for the DDP rendezvous (init_ddp_connection keeps it if set).
        os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        task = cls()
        work_dir = hparams['work_dir']
        trainer = BaseTrainer(checkpoint_callback=LatestModelCheckpoint(
            filepath=work_dir,
            verbose=True,
            monitor='val_loss',
            mode='min',
            num_ckpt_keep=hparams['num_ckpt_keep'],
            save_best=hparams['save_best'],
            # NOTE(review): a period of 100000 presumably disables periodic saving
            # when save_ckpt is off — confirm against LatestModelCheckpoint.
            period=1 if hparams['save_ckpt'] else 100000
        ),
            logger=TensorBoardLogger(
                save_dir=work_dir,
                name='lightning_logs',
                # NOTE(review): 'lastest' looks like a typo for 'latest', but it names
                # the on-disk log directory; kept as-is for compatibility with existing runs.
                version='lastest'
            ),
            gradient_clip_val=hparams['clip_grad_norm'],
            val_check_interval=hparams['val_check_interval'],
            row_log_interval=hparams['log_interval'],
            max_updates=hparams['max_updates'],
            # In validate mode a very large sanity-check budget is requested.
            num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams[
                'validate'] else 10000,
            accumulate_grad_batches=hparams['accumulate_grad_batches'],
            use_amp=hparams['use_amp'])
        if not hparams['infer']:  # train
            # Copy spk_map.json to work dir
            spk_map = os.path.join(work_dir, 'spk_map.json')
            spk_map_orig = os.path.join(hparams['binary_data_dir'], 'spk_map.json')
            if not os.path.exists(spk_map) and os.path.exists(spk_map_orig):
                shutil.copy(spk_map_orig, spk_map)
                print(f"| Copied spk map to {spk_map}.")
            trainer.checkpoint_callback.task = task
            trainer.fit(task)
        else:
            trainer.test(task)

    @staticmethod
    def configure_ddp(model, device_ids):
        """Wrap ``model`` in DDP and silence output on non-zero ranks.

        Unless hparams['debug'] is set, every rank other than 0 has its
        sys.stdout/sys.stderr replaced by os.devnull for the rest of the
        process. The Python and NumPy RNGs are then re-seeded from hparams['seed'].
        """
        model = DDP(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        if dist.get_rank() != 0 and not hparams['debug']:
            sys.stdout = open(os.devnull, "w")
            sys.stderr = open(os.devnull, "w")
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        return model

    def training_end(self, *args, **kwargs):
        # Regular method: the former @staticmethod with a `self` parameter made
        # `task.training_end()` fail; every call that worked before still works.
        return None

    def init_ddp_connection(self, proc_rank, world_size):
        """Initialise the NCCL process group for this rank.

        Keeps an existing MASTER_PORT (start() sets a random one) and only falls
        back to a fixed default when none is set. MASTER_ADDR comes from the
        trainer's root-node resolution.
        """
        set_hparams(print_hparams=False)
        default_port = 12910
        os.environ.setdefault('MASTER_PORT', str(default_port))
        # Figure out the root node address (the trainer may override the default).
        root_node = self.trainer.resolve_root_node_address('127.0.0.2')
        os.environ['MASTER_ADDR'] = root_node
        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)

    @data_loader
    def train_dataloader(self):
        return None

    @data_loader
    def test_dataloader(self):
        return None

    @data_loader
    def val_dataloader(self):
        return None

    # No-op hooks; subclasses override the ones they need.
    def on_load_checkpoint(self, checkpoint):
        pass

    def on_save_checkpoint(self, checkpoint):
        pass

    def on_sanity_check_start(self):
        pass

    def on_train_start(self):
        pass

    def on_train_end(self):
        pass

    def on_batch_start(self, batch):
        pass

    def on_batch_end(self):
        pass

    def on_pre_performance_check(self):
        pass

    def on_post_performance_check(self):
        pass

    def on_before_zero_grad(self, optimizer):
        pass

    def on_after_backward(self):
        pass

    @staticmethod
    def backward(loss, optimizer):
        loss.backward()

    def grad_norm(self, norm_type):
        """Compute gradient norms of the trainable parameters, for logging.

        Parameters that do not require grad, or that have no gradient yet, are skipped.

        :param norm_type: order of the norm (e.g. 2 for the L2 norm)
        :return: dict with 'grad_{norm_type}_norm_{param_name}' -> that parameter's
            gradient norm, plus 'grad_{norm_type}_norm_total' -> the norm of all
            gradients taken together (0.0 if there are none); values rounded to 3 decimals.
        """
        results = {}
        param_norms = []
        for name, p in self.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            # p.grad.norm(p) already is the p-norm; it must not be rooted again.
            param_norm = float(p.grad.detach().norm(norm_type))
            param_norms.append(param_norm)
            results['grad_{}_norm_{}'.format(norm_type, name)] = round(param_norm, 3)
        # The p-norm of all gradients equals the p-norm of the per-parameter
        # p-norms (for p >= 1, including inf).
        if param_norms:
            total_norm = float(torch.tensor(param_norms, dtype=torch.float64).norm(norm_type))
        else:
            total_norm = 0.0
        results['grad_{}_norm_total'.format(norm_type)] = round(total_norm, 3)
        return results