import contextlib
|
|
import copy
|
|
import glob
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
from functools import wraps
|
|
|
|
import matplotlib
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
import torch.optim
|
|
import torch.utils.data
|
|
import tqdm
|
|
from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
from torch.cuda._utils import _get_device_index
|
|
from torch.nn import DataParallel
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
from torch.optim.optimizer import Optimizer
|
|
|
|
matplotlib.use('Agg')
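
# 'Agg' is a non-interactive backend, so any plotting done during training also works on
# headless machines (no display server required).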
|
|
|
|
|
|
def get_a_var(obj): # pragma: no cover
|
|
if isinstance(obj, torch.Tensor):
|
|
return obj
|
|
|
|
if isinstance(obj, list) or isinstance(obj, tuple):
|
|
for result in map(get_a_var, obj):
|
|
if isinstance(result, torch.Tensor):
|
|
return result
|
|
if isinstance(obj, dict):
|
|
for result in map(get_a_var, obj.items()):
|
|
if isinstance(result, torch.Tensor):
|
|
return result
|
|
return None
|
|
|
|
|
|
def data_loader(fn):
    """
    Decorator that lazily evaluates a dataloader method and memoizes the result.

    :param fn: the dataloader method to wrap
    :return: the wrapping accessor
    """

    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)
    def _get_data_loader(self):
|
|
try:
|
|
value = getattr(self, attr_name)
|
|
except AttributeError:
|
|
try:
|
|
value = fn(self) # Lazy evaluation, done only once.
|
|
if (
|
|
value is not None and
|
|
not isinstance(value, list) and
|
|
fn.__name__ in ['test_dataloader', 'val_dataloader']
|
|
):
|
|
value = [value]
|
|
except AttributeError as e:
|
|
# Guard against AttributeError suppression. (Issue #142)
|
|
traceback.print_exc()
|
|
error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
|
|
raise RuntimeError(error) from e
|
|
setattr(self, attr_name, value) # Memoize evaluation.
|
|
return value
|
|
|
|
return _get_data_loader
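
# Illustrative usage sketch (not part of the module): `build_loader` below is a hypothetical
# helper standing in for any callable that returns a DataLoader. The first call evaluates the
# decorated method and caches the result on the instance; later calls reuse it, and val/test
# loaders are wrapped in a list so downstream code can always iterate over them.
#
#   class MyTask:
#       @data_loader
#       def val_dataloader(self):
#           return build_loader()
#
#   task = MyTask()
#   loaders = task.val_dataloader()          # -> [<DataLoader>], built once
#   assert task.val_dataloader() is loaders  # memoized on `task._lazy_val_dataloader`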
|
|
|
|
|
|
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no cover
|
|
r"""Applies each `module` in :attr:`modules` in parallel on arguments
|
|
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
|
|
on each of :attr:`devices`.
|
|
|
|
Args:
|
|
modules (Module): modules to be parallelized
|
|
inputs (tensor): inputs to the modules
|
|
devices (list of int or torch.device): CUDA devices
|
|
|
|
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
|
|
:attr:`devices` (if given) should all have same length. Moreover, each
|
|
element of :attr:`inputs` can either be a single object as the only argument
|
|
to a module, or a collection of positional arguments.
|
|
"""
|
|
assert len(modules) == len(inputs)
|
|
if kwargs_tup is not None:
|
|
assert len(modules) == len(kwargs_tup)
|
|
else:
|
|
kwargs_tup = ({},) * len(modules)
|
|
if devices is not None:
|
|
assert len(modules) == len(devices)
|
|
else:
|
|
devices = [None] * len(modules)
|
|
devices = list(map(lambda x: _get_device_index(x, True), devices))
|
|
lock = threading.Lock()
|
|
results = {}
|
|
grad_enabled = torch.is_grad_enabled()
|
|
|
|
def _worker(i, module, input, kwargs, device=None):
|
|
torch.set_grad_enabled(grad_enabled)
|
|
if device is None:
|
|
device = get_a_var(input).get_device()
|
|
try:
|
|
with torch.cuda.device(device):
|
|
# this also avoids accidental slicing of `input` if it is a Tensor
|
|
if not isinstance(input, (list, tuple)):
|
|
input = (input,)
|
|
|
|
# ---------------
|
|
# CHANGE
|
|
if module.training:
|
|
output = module.training_step(*input, **kwargs)
|
|
|
|
elif module.testing:
|
|
output = module.test_step(*input, **kwargs)
|
|
|
|
else:
|
|
output = module.validation_step(*input, **kwargs)
|
|
# ---------------
|
|
|
|
with lock:
|
|
results[i] = output
|
|
except Exception as e:
|
|
with lock:
|
|
results[i] = e
|
|
|
|
# make sure each module knows what training state it's in...
|
|
# fixes weird bug where copies are out of sync
|
|
root_m = modules[0]
|
|
for m in modules[1:]:
|
|
m.training = root_m.training
|
|
m.testing = root_m.testing
|
|
|
|
if len(modules) > 1:
|
|
threads = [threading.Thread(target=_worker,
|
|
args=(i, module, input, kwargs, device))
|
|
for i, (module, input, kwargs, device) in
|
|
enumerate(zip(modules, inputs, kwargs_tup, devices))]
|
|
|
|
for thread in threads:
|
|
thread.start()
|
|
for thread in threads:
|
|
thread.join()
|
|
else:
|
|
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
|
|
|
|
outputs = []
|
|
for i in range(len(inputs)):
|
|
output = results[i]
|
|
if isinstance(output, Exception):
|
|
raise output
|
|
outputs.append(output)
|
|
return outputs
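
# Hedged sketch of the call contract documented above (replica creation and batches are
# illustrative): `modules`, `inputs`, `kwargs_tup`, and `devices` are parallel sequences, and
# each element of `inputs` is the positional-argument tuple passed to one replica's
# training/validation/test step.
#
#   outputs = parallel_apply(
#       modules=[replica0, replica1],
#       inputs=[(batch0, 0), (batch1, 1)],   # e.g. (batch, batch_idx) per device
#       kwargs_tup=None,
#       devices=[0, 1],
#   )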
|
|
|
|
|
|
def _find_tensors(obj): # pragma: no cover
|
|
r"""
|
|
Recursively find all tensors contained in the specified object.
|
|
"""
|
|
if isinstance(obj, torch.Tensor):
|
|
return [obj]
|
|
if isinstance(obj, (list, tuple)):
|
|
return itertools.chain(*map(_find_tensors, obj))
|
|
if isinstance(obj, dict):
|
|
return itertools.chain(*map(_find_tensors, obj.values()))
|
|
return []
|
|
|
|
|
|
class DDP(DistributedDataParallel):
|
|
"""
|
|
Override the forward call in lightning so it goes to training and validation step respectively
|
|
"""
|
|
|
|
def parallel_apply(self, replicas, inputs, kwargs):
|
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
|
|
|
def forward(self, *inputs, **kwargs): # pragma: no cover
|
|
self._sync_params()
|
|
if self.device_ids:
|
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
|
if len(self.device_ids) == 1:
|
|
# --------------
|
|
# LIGHTNING MOD
|
|
# --------------
|
|
# normal
|
|
# output = self.module(*inputs[0], **kwargs[0])
|
|
# lightning
|
|
if self.module.training:
|
|
output = self.module.training_step(*inputs[0], **kwargs[0])
|
|
elif self.module.testing:
|
|
output = self.module.test_step(*inputs[0], **kwargs[0])
|
|
else:
|
|
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
|
else:
|
|
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
|
|
output = self.gather(outputs, self.output_device)
|
|
else:
|
|
# normal
|
|
output = self.module(*inputs, **kwargs)
|
|
|
|
if torch.is_grad_enabled():
|
|
# We'll return the output object verbatim since it is a freeform
|
|
# object. We need to find any tensors in this object, though,
|
|
# because we need to figure out which parameters were used during
|
|
# this forward pass, to ensure we short circuit reduction for any
|
|
# unused parameters. Only if `find_unused_parameters` is set.
|
|
if self.find_unused_parameters:
|
|
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
|
else:
|
|
self.reducer.prepare_for_backward([])
|
|
return output
|
|
|
|
|
|
class DP(DataParallel):
|
|
"""
|
|
Override the forward call in lightning so it goes to training and validation step respectively
|
|
"""
|
|
|
|
def forward(self, *inputs, **kwargs):
|
|
if not self.device_ids:
|
|
return self.module(*inputs, **kwargs)
|
|
|
|
for t in itertools.chain(self.module.parameters(), self.module.buffers()):
|
|
if t.device != self.src_device_obj:
|
|
raise RuntimeError("module must have its parameters and buffers "
|
|
"on device {} (device_ids[0]) but found one of "
|
|
"them on device: {}".format(self.src_device_obj, t.device))
|
|
|
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
|
if len(self.device_ids) == 1:
|
|
# lightning
|
|
if self.module.training:
|
|
return self.module.training_step(*inputs[0], **kwargs[0])
|
|
elif self.module.testing:
|
|
return self.module.test_step(*inputs[0], **kwargs[0])
|
|
else:
|
|
return self.module.validation_step(*inputs[0], **kwargs[0])
|
|
|
|
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
|
outputs = self.parallel_apply(replicas, inputs, kwargs)
|
|
return self.gather(outputs, self.output_device)
|
|
|
|
def parallel_apply(self, replicas, inputs, kwargs):
|
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
|
|
|
|
|
class GradientAccumulationScheduler:
    # Note: this minimal local scheduler shadows the pytorch_lightning callback imported above.
    def __init__(self, scheduling: dict):
        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 1:
            msg = f"Epochs are indexed from 1, so epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        elif minimal_epoch != 1:  # if the user didn't define an accumulation factor for the first epoch
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())
|
|
|
|
def on_epoch_begin(self, epoch, trainer):
|
|
epoch += 1 # indexing epochs from 1
|
|
for i in reversed(range(len(self.epochs))):
|
|
if epoch >= self.epochs[i]:
|
|
trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
|
|
break
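
# Illustrative sketch of a scheduling dict for the class above: keys are 1-indexed epochs and
# values are the accumulation factor that takes effect from that epoch onward.
#
#   scheduler = GradientAccumulationScheduler({1: 1, 5: 2, 10: 4})
#   # epochs 1-4 accumulate every batch, epochs 5-9 every 2 batches, epoch 10 onward every 4;
#   # on_epoch_begin(epoch, trainer) sets trainer.accumulate_grad_batches accordingly.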
|
|
|
|
|
|
class LatestModelCheckpoint(ModelCheckpoint):
|
|
def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5,
|
|
save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True):
|
|
super(ModelCheckpoint, self).__init__()
|
|
self.monitor = monitor
|
|
self.verbose = verbose
|
|
self.filepath = filepath
|
|
os.makedirs(filepath, exist_ok=True)
|
|
self.num_ckpt_keep = num_ckpt_keep
|
|
self.save_best = save_best
|
|
self.save_weights_only = save_weights_only
|
|
self.period = period
|
|
self.epochs_since_last_check = 0
|
|
self.prefix = prefix
|
|
self.best_k_models = {}
|
|
# {filename: monitor}
|
|
self.kth_best_model = ''
|
|
self.save_top_k = 1
|
|
self.task = None
|
|
if mode == 'min':
|
|
self.monitor_op = np.less
|
|
            self.best = np.inf
|
|
self.mode = 'min'
|
|
elif mode == 'max':
|
|
self.monitor_op = np.greater
|
|
            self.best = -np.inf
|
|
self.mode = 'max'
|
|
else:
|
|
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
|
|
self.monitor_op = np.greater
|
|
                self.best = -np.inf
|
|
self.mode = 'max'
|
|
else:
|
|
self.monitor_op = np.less
|
|
                self.best = np.inf
|
|
self.mode = 'min'
|
|
if os.path.exists(f'{self.filepath}/best_valid.npy'):
|
|
self.best = np.load(f'{self.filepath}/best_valid.npy')[0]
|
|
|
|
def get_all_ckpts(self):
|
|
return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'),
|
|
                      key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
|
|
|
|
def on_epoch_end(self, epoch, logs=None):
|
|
logs = logs or {}
|
|
self.epochs_since_last_check += 1
|
|
best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt'
|
|
if self.epochs_since_last_check >= self.period:
|
|
self.epochs_since_last_check = 0
|
|
filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt'
|
|
if self.verbose > 0:
|
|
logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}')
|
|
self._save_model(filepath)
|
|
for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]:
|
|
# TODO: test filesystem calls
|
|
os.remove(old_ckpt)
|
|
# subprocess.check_call(f'del "{old_ckpt}"', shell=True)
|
|
if self.verbose > 0:
|
|
logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
|
|
current = logs.get(self.monitor)
|
|
if current is not None and self.save_best:
|
|
if self.monitor_op(current, self.best):
|
|
self.best = current
|
|
if self.verbose > 0:
|
|
logging.info(
|
|
f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached'
|
|
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
|
|
f' {best_filepath} as top 1')
|
|
self._save_model(best_filepath)
|
|
np.save(f'{self.filepath}/best_valid.npy', [self.best])
|
|
|
|
def _save_model(self, path):
|
|
return self.save_function(path)
|
|
|
|
|
|
class BaseTrainer:
|
|
def __init__(
|
|
self,
|
|
logger=True,
|
|
checkpoint_callback=True,
|
|
default_save_path=None,
|
|
gradient_clip_val=0,
|
|
process_position=0,
|
|
gpus=-1,
|
|
log_gpu_memory=None,
|
|
show_progress_bar=True,
|
|
track_grad_norm=-1,
|
|
check_val_every_n_epoch=1,
|
|
accumulate_grad_batches=1,
|
|
max_updates=1000,
|
|
min_epochs=1,
|
|
val_check_interval=1.0,
|
|
log_save_interval=100,
|
|
row_log_interval=10,
|
|
print_nan_grads=False,
|
|
weights_summary='full',
|
|
num_sanity_val_steps=5,
|
|
resume_from_checkpoint=None,
|
|
use_amp=False
|
|
):
|
|
self.log_gpu_memory = log_gpu_memory
|
|
self.gradient_clip_val = gradient_clip_val
|
|
self.check_val_every_n_epoch = check_val_every_n_epoch
|
|
self.track_grad_norm = track_grad_norm
|
|
self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
|
|
self.process_position = process_position
|
|
self.weights_summary = weights_summary
|
|
self.max_updates = max_updates
|
|
self.min_epochs = min_epochs
|
|
self.num_sanity_val_steps = num_sanity_val_steps
|
|
self.print_nan_grads = print_nan_grads
|
|
self.resume_from_checkpoint = resume_from_checkpoint
|
|
self.default_save_path = default_save_path
|
|
|
|
        # training bookkeeping
|
|
self.total_batch_idx = 0
|
|
self.running_loss = []
|
|
self.avg_loss = 0
|
|
self.batch_idx = 0
|
|
self.tqdm_metrics = {}
|
|
self.callback_metrics = {}
|
|
self.num_val_batches = 0
|
|
self.num_training_batches = 0
|
|
self.num_test_batches = 0
|
|
self.get_train_dataloader = None
|
|
self.get_test_dataloaders = None
|
|
self.get_val_dataloaders = None
|
|
self.is_iterable_train_dataloader = False
|
|
|
|
# training state
|
|
self.model = None
|
|
self.testing = False
|
|
self.disable_validation = False
|
|
self.lr_schedulers = []
|
|
self.optimizers = None
|
|
self.global_step = 0
|
|
self.current_epoch = 0
|
|
self.total_batches = 0
|
|
|
|
# configure checkpoint callback
|
|
self.checkpoint_callback = checkpoint_callback
|
|
self.checkpoint_callback.save_function = self.save_checkpoint
|
|
self.weights_save_path = self.checkpoint_callback.filepath
|
|
|
|
# accumulated grads
|
|
self.configure_accumulated_gradients(accumulate_grad_batches)
|
|
|
|
# allow int, string and gpu list
|
|
self.data_parallel_device_ids = [
|
|
int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
|
|
if len(self.data_parallel_device_ids) == 0:
|
|
self.root_gpu = None
|
|
self.on_gpu = False
|
|
else:
|
|
self.root_gpu = self.data_parallel_device_ids[0]
|
|
self.on_gpu = True
|
|
|
|
# distributed backend choice
|
|
self.use_ddp = False
|
|
self.use_dp = False
|
|
self.single_gpu = False
|
|
self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp'
|
|
self.set_distributed_mode(self.distributed_backend)
|
|
|
|
self.proc_rank = 0
|
|
self.world_size = 1
|
|
self.node_rank = 0
|
|
|
|
# can't init progress bar here because starting a new process
|
|
# means the progress_bar won't survive pickling
|
|
self.show_progress_bar = show_progress_bar
|
|
|
|
# logging
|
|
self.log_save_interval = log_save_interval
|
|
self.val_check_interval = val_check_interval
|
|
self.logger = logger
|
|
self.logger.rank = 0
|
|
self.row_log_interval = row_log_interval
|
|
self.scaler = None
|
|
self.use_amp = use_amp
|
|
if self.use_amp:
|
|
self.scaler = torch.cuda.amp.GradScaler()
|
|
|
|
@property
|
|
def num_gpus(self):
|
|
gpus = self.data_parallel_device_ids
|
|
if gpus is None:
|
|
return 0
|
|
else:
|
|
return len(gpus)
|
|
|
|
@property
|
|
def data_parallel(self):
|
|
return self.use_dp or self.use_ddp
|
|
|
|
def get_model(self):
|
|
is_dp_module = isinstance(self.model, (DDP, DP))
|
|
model = self.model.module if is_dp_module else self.model
|
|
return model
|
|
|
|
# -----------------------------
|
|
# MODEL TRAINING
|
|
# -----------------------------
|
|
def fit(self, model):
|
|
if self.use_ddp:
|
|
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
|
|
else:
|
|
model.svc_model = model.build_model()
|
|
if not self.testing:
|
|
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
|
|
if self.use_dp:
|
|
model.cuda(self.root_gpu)
|
|
model = DP(model, device_ids=self.data_parallel_device_ids)
|
|
elif self.single_gpu:
|
|
model.cuda(self.root_gpu)
|
|
self.run_pretrain_routine(model)
|
|
return 1
|
|
|
|
def init_optimizers(self, optimizers):
|
|
|
|
# single optimizer
|
|
if isinstance(optimizers, Optimizer):
|
|
return [optimizers], []
|
|
|
|
# two lists
|
|
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
|
|
optimizers, lr_schedulers = optimizers
|
|
return optimizers, lr_schedulers
|
|
|
|
# single list or tuple
|
|
elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
|
|
return optimizers, []
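
    # Hedged sketch of the return shapes accepted from `model.configure_optimizers()`
    # (optimizer/scheduler construction here is illustrative only):
    #
    #   return torch.optim.Adam(self.parameters(), lr=1e-3)          # -> ([adam], [])
    #   return [opt_g, opt_d]                                        # -> ([opt_g, opt_d], [])
    #   return [opt], [torch.optim.lr_scheduler.StepLR(opt, 10)]     # -> ([opt], [StepLR])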
|
|
|
|
def run_pretrain_routine(self, model):
|
|
"""Sanity check a few things before starting actual training.
|
|
|
|
:param model:
|
|
"""
|
|
ref_model = model
|
|
if self.data_parallel:
|
|
ref_model = model.module
|
|
|
|
# give model convenience properties
|
|
ref_model.trainer = self
|
|
|
|
# set local properties on the model
|
|
self.copy_trainer_model_properties(ref_model)
|
|
|
|
# link up experiment object
|
|
if self.logger is not None:
|
|
ref_model.logger = self.logger
|
|
self.logger.save()
|
|
|
|
if self.use_ddp:
|
|
dist.barrier()
|
|
|
|
# set up checkpoint callback
|
|
# self.configure_checkpoint_callback()
|
|
|
|
# transfer data loaders from model
|
|
self.get_dataloaders(ref_model)
|
|
|
|
# track model now.
|
|
# if cluster resets state, the model will update with the saved weights
|
|
self.model = model
|
|
|
|
# restore training and model before hpc call
|
|
self.restore_weights(model)
|
|
|
|
# when testing requested only run test and return
|
|
if self.testing:
|
|
self.run_evaluation(test=True)
|
|
return
|
|
|
|
# check if we should run validation during training
|
|
self.disable_validation = self.num_val_batches == 0
|
|
|
|
# run tiny validation (if validation defined)
|
|
# to make sure program won't crash during val
|
|
ref_model.on_sanity_check_start()
|
|
ref_model.on_train_start()
|
|
if not self.disable_validation and self.num_sanity_val_steps > 0:
|
|
# init progress bars for validation sanity check
|
|
pbar = tqdm.tqdm(desc='Validation sanity check',
|
|
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
|
|
leave=False, position=2 * self.process_position,
|
|
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
|
|
self.main_progress_bar = pbar
|
|
# dummy validation progress bar
|
|
self.val_progress_bar = tqdm.tqdm(disable=True)
|
|
|
|
self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
|
|
|
|
# close progress bars
|
|
self.main_progress_bar.close()
|
|
self.val_progress_bar.close()
|
|
|
|
# init progress bar
|
|
pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
|
|
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
|
|
file=sys.stdout)
|
|
self.main_progress_bar = pbar
|
|
|
|
# clear cache before training
|
|
if self.on_gpu:
|
|
torch.cuda.empty_cache()
|
|
|
|
# CORE TRAINING LOOP
|
|
self.train()
|
|
|
|
def test(self, model):
|
|
self.testing = True
|
|
self.fit(model)
|
|
|
|
@property
|
|
def training_tqdm_dict(self):
|
|
tqdm_dict = {
|
|
'step': '{}'.format(self.global_step),
|
|
}
|
|
tqdm_dict.update(self.tqdm_metrics)
|
|
return tqdm_dict
|
|
|
|
# --------------------
|
|
# restore ckpt
|
|
# --------------------
|
|
def restore_weights(self, model):
|
|
"""
|
|
To restore weights we have two cases.
|
|
First, attempt to restore hpc weights. If successful, don't restore
|
|
other weights.
|
|
|
|
Otherwise, try to restore actual weights
|
|
:param model:
|
|
:return:
|
|
"""
|
|
# clear cache before restore
|
|
if self.on_gpu:
|
|
torch.cuda.empty_cache()
|
|
|
|
if self.resume_from_checkpoint is not None:
|
|
self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
|
|
else:
|
|
# restore weights if same exp version
|
|
self.restore_state_if_checkpoint_exists(model)
|
|
|
|
# wait for all models to restore weights
|
|
if self.use_ddp:
|
|
# wait for all processes to catch up
|
|
dist.barrier()
|
|
|
|
# clear cache after restore
|
|
if self.on_gpu:
|
|
torch.cuda.empty_cache()
|
|
|
|
def restore_state_if_checkpoint_exists(self, model):
|
|
did_restore = False
|
|
|
|
        # do nothing if there's no checkpoint dir or callback
|
|
no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback)
|
|
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
|
|
return did_restore
|
|
|
|
# restore trainer state and model if there is a weight for this experiment
|
|
last_steps = -1
|
|
last_ckpt_name = None
|
|
|
|
# find last epoch
|
|
checkpoints = os.listdir(self.checkpoint_callback.filepath)
|
|
for name in checkpoints:
|
|
if '.ckpt' in name and not name.endswith('part'):
|
|
if 'steps_' in name:
|
|
steps = name.split('steps_')[1]
|
|
steps = int(re.sub('[^0-9]', '', steps))
|
|
|
|
if steps > last_steps:
|
|
last_steps = steps
|
|
last_ckpt_name = name
|
|
|
|
# restore last checkpoint
|
|
if last_ckpt_name is not None:
|
|
last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
|
|
self.restore(last_ckpt_path, self.on_gpu)
|
|
logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}')
|
|
did_restore = True
|
|
|
|
return did_restore
|
|
|
|
def restore(self, checkpoint_path, on_gpu):
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
|
|
# load model state
|
|
model = self.get_model()
|
|
|
|
# load the state_dict on the model automatically
|
|
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
|
if on_gpu:
|
|
model.cuda(self.root_gpu)
|
|
# load training state (affects trainer only)
|
|
self.restore_training_state(checkpoint)
|
|
model.global_step = self.global_step
|
|
del checkpoint
|
|
|
|
try:
|
|
if dist.is_initialized() and dist.get_rank() > 0:
|
|
return
|
|
except Exception as e:
|
|
print(e)
|
|
return
|
|
|
|
def restore_training_state(self, checkpoint):
|
|
"""
|
|
Restore trainer state.
|
|
        Model will get its chance to update
|
|
:param checkpoint:
|
|
:return:
|
|
"""
|
|
if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
|
|
# return allowing checkpoints with meta information (global_step, etc)
|
|
self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
|
|
|
|
self.global_step = checkpoint['global_step']
|
|
self.current_epoch = checkpoint['epoch']
|
|
|
|
if self.testing:
|
|
return
|
|
|
|
# restore the optimizers
|
|
optimizer_states = checkpoint['optimizer_states']
|
|
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
|
|
if optimizer is None:
|
|
return
|
|
optimizer.load_state_dict(opt_state)
|
|
|
|
# move optimizer to GPU 1 weight at a time
|
|
# avoids OOM
|
|
if self.root_gpu is not None:
|
|
for state in optimizer.state.values():
|
|
for k, v in state.items():
|
|
if isinstance(v, torch.Tensor):
|
|
state[k] = v.cuda(self.root_gpu)
|
|
|
|
# restore the lr schedulers
|
|
lr_schedulers = checkpoint['lr_schedulers']
|
|
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
|
|
scheduler.load_state_dict(lrs_state)
|
|
|
|
# --------------------
|
|
# MODEL SAVE CHECKPOINT
|
|
# --------------------
|
|
def _atomic_save(self, checkpoint, filepath):
|
|
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
|
|
|
|
        This will create a temporary checkpoint with a suffix of ``.part``, then move it to the final location once
|
|
saving is finished.
|
|
|
|
Args:
|
|
checkpoint (object): The object to save.
|
|
Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
|
|
accepts.
|
|
filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
|
|
This points to the file that the checkpoint will be stored in.
|
|
"""
|
|
tmp_path = str(filepath) + ".part"
|
|
torch.save(checkpoint, tmp_path)
|
|
os.replace(tmp_path, filepath)
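
        # Note: os.replace is an atomic rename on the same filesystem (on most platforms), so a
        # crash mid-save leaves either the previous checkpoint or the new one at `filepath`,
        # never a half-written file; at worst a stray `.part` file remains.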
|
|
|
|
def save_checkpoint(self, filepath):
|
|
checkpoint = self.dump_checkpoint()
|
|
self._atomic_save(checkpoint, filepath)
|
|
|
|
def dump_checkpoint(self):
|
|
|
|
checkpoint = {
|
|
'epoch': self.current_epoch,
|
|
'global_step': self.global_step
|
|
}
|
|
|
|
if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
|
|
checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
|
|
|
|
# save optimizers
|
|
optimizer_states = []
|
|
for i, optimizer in enumerate(self.optimizers):
|
|
if optimizer is not None:
|
|
optimizer_states.append(optimizer.state_dict())
|
|
|
|
checkpoint['optimizer_states'] = optimizer_states
|
|
|
|
# save lr schedulers
|
|
lr_schedulers = []
|
|
for i, scheduler in enumerate(self.lr_schedulers):
|
|
lr_schedulers.append(scheduler.state_dict())
|
|
|
|
checkpoint['lr_schedulers'] = lr_schedulers
|
|
|
|
# add the hparams and state_dict from the model
|
|
model = self.get_model()
|
|
checkpoint['state_dict'] = model.state_dict()
|
|
# give the model a chance to add a few things
|
|
model.on_save_checkpoint(checkpoint)
|
|
|
|
return checkpoint
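
    # Sketch of the checkpoint layout produced above (model hooks may add more keys via
    # `on_save_checkpoint`):
    #   {'epoch': ..., 'global_step': ..., 'checkpoint_callback_best': ...,
    #    'optimizer_states': [...], 'lr_schedulers': [...], 'state_dict': {...}}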
|
|
|
|
def copy_trainer_model_properties(self, model):
|
|
if isinstance(model, DP):
|
|
ref_model = model.module
|
|
elif isinstance(model, DDP):
|
|
ref_model = model.module
|
|
else:
|
|
ref_model = model
|
|
|
|
for m in [model, ref_model]:
|
|
m.trainer = self
|
|
m.on_gpu = self.on_gpu
|
|
m.use_dp = self.use_dp
|
|
m.use_ddp = self.use_ddp
|
|
m.testing = self.testing
|
|
m.single_gpu = self.single_gpu
|
|
|
|
def transfer_batch_to_gpu(self, batch, gpu_id):
|
|
# base case: object can be directly moved using `cuda` or `to`
|
|
if callable(getattr(batch, 'cuda', None)):
|
|
return batch.cuda(gpu_id, non_blocking=True)
|
|
|
|
elif callable(getattr(batch, 'to', None)):
|
|
return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
|
|
|
|
# when list
|
|
elif isinstance(batch, list):
|
|
for i, x in enumerate(batch):
|
|
batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
|
|
return batch
|
|
|
|
# when tuple
|
|
elif isinstance(batch, tuple):
|
|
batch = list(batch)
|
|
for i, x in enumerate(batch):
|
|
batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
|
|
return tuple(batch)
|
|
|
|
# when dict
|
|
elif isinstance(batch, dict):
|
|
for k, v in batch.items():
|
|
batch[k] = self.transfer_batch_to_gpu(v, gpu_id)
|
|
|
|
return batch
|
|
|
|
# nothing matches, return the value as is without transform
|
|
return batch
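
    # Illustrative sketch: the method above recurses through lists, tuples, and dicts, so a
    # nested batch such as
    #   batch = {'mel': mel_tensor, 'meta': [lengths_tensor, 'utt_id_string']}
    #   batch = trainer.transfer_batch_to_gpu(batch, gpu_id=0)
    # moves every tensor to cuda:0 and returns non-tensor leaves (here the string) unchanged.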
|
|
|
|
def set_distributed_mode(self, distributed_backend):
|
|
# skip for CPU
|
|
if self.num_gpus == 0:
|
|
return
|
|
|
|
        # single GPU case: run directly on that GPU, without a DP/DDP wrapper
|
|
elif self.num_gpus == 1:
|
|
self.single_gpu = True
|
|
self.use_dp = False
|
|
self.use_ddp = False
|
|
self.root_gpu = 0
|
|
self.data_parallel_device_ids = [0]
|
|
else:
|
|
if distributed_backend is not None:
|
|
self.use_dp = distributed_backend == 'dp'
|
|
self.use_ddp = distributed_backend == 'ddp'
|
|
elif distributed_backend is None:
|
|
self.use_dp = True
|
|
self.use_ddp = False
|
|
|
|
logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')
|
|
|
|
def ddp_train(self, gpu_idx, model):
|
|
"""
|
|
Entry point into a DP thread
|
|
:param gpu_idx:
|
|
:param model:
|
|
:param cluster_obj:
|
|
:return:
|
|
"""
|
|
# otherwise default to node rank 0
|
|
self.node_rank = 0
|
|
|
|
# show progressbar only on progress_rank 0
|
|
self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0
|
|
|
|
# determine which process we are and world size
|
|
if self.use_ddp:
|
|
self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
|
|
self.world_size = self.num_gpus
|
|
|
|
# let the exp know the rank to avoid overwriting logs
|
|
if self.logger is not None:
|
|
self.logger.rank = self.proc_rank
|
|
|
|
# set up server using proc 0's ip address
|
|
# try to init for 20 times at max in case ports are taken
|
|
# where to store ip_table
|
|
model.trainer = self
|
|
model.init_ddp_connection(self.proc_rank, self.world_size)
|
|
|
|
# CHOOSE OPTIMIZER
|
|
# allow for lr schedulers as well
|
|
model.svc_model = model.build_model()
|
|
if not self.testing:
|
|
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
|
|
|
|
# MODEL
|
|
# copy model to each gpu
|
|
if self.distributed_backend == 'ddp':
|
|
torch.cuda.set_device(gpu_idx)
|
|
model.cuda(gpu_idx)
|
|
|
|
# set model properties before going into wrapper
|
|
self.copy_trainer_model_properties(model)
|
|
|
|
# override root GPU
|
|
self.root_gpu = gpu_idx
|
|
|
|
if self.distributed_backend == 'ddp':
|
|
device_ids = [gpu_idx]
|
|
else:
|
|
device_ids = None
|
|
|
|
# allow user to configure ddp
|
|
model = model.configure_ddp(model, device_ids)
|
|
|
|
# continue training routine
|
|
self.run_pretrain_routine(model)
|
|
|
|
def resolve_root_node_address(self, root_node):
|
|
if '[' in root_node:
|
|
name = root_node.split('[')[0]
|
|
number = root_node.split(',')[0]
|
|
if '-' in number:
|
|
number = number.split('-')[0]
|
|
|
|
number = re.sub('[^0-9]', '', number)
|
|
root_node = name + number
|
|
|
|
return root_node
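
    # Hedged example of the SLURM-style hostname expansion above:
    #   resolve_root_node_address('host[5-9,12]')  -> 'host5'
    #   resolve_root_node_address('node01')        -> 'node01'  (no brackets: returned unchanged)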
|
|
|
|
def log_metrics(self, metrics, grad_norm_dic, step=None):
|
|
"""Logs the metric dict passed in.
|
|
|
|
:param metrics:
|
|
:param grad_norm_dic:
|
|
"""
|
|
# added metrics by Lightning for convenience
|
|
metrics['epoch'] = self.current_epoch
|
|
|
|
# add norms
|
|
metrics.update(grad_norm_dic)
|
|
|
|
# turn all tensors to scalars
|
|
scalar_metrics = self.metrics_to_scalars(metrics)
|
|
|
|
step = step if step is not None else self.global_step
|
|
# log actual metrics
|
|
if self.proc_rank == 0 and self.logger is not None:
|
|
self.logger.log_metrics(scalar_metrics, step=step)
|
|
self.logger.save()
|
|
|
|
def add_tqdm_metrics(self, metrics):
|
|
for k, v in metrics.items():
|
|
if type(v) is torch.Tensor:
|
|
v = v.item()
|
|
|
|
self.tqdm_metrics[k] = v
|
|
|
|
def metrics_to_scalars(self, metrics):
|
|
new_metrics = {}
|
|
for k, v in metrics.items():
|
|
if isinstance(v, torch.Tensor):
|
|
v = v.item()
|
|
|
|
if type(v) is dict:
|
|
v = self.metrics_to_scalars(v)
|
|
|
|
new_metrics[k] = v
|
|
|
|
return new_metrics
|
|
|
|
def process_output(self, output, train=False):
|
|
"""Reduces output according to the training mode.
|
|
|
|
Separates loss from logging and tqdm metrics
|
|
:param output:
|
|
:return:
|
|
"""
|
|
# ---------------
|
|
# EXTRACT CALLBACK KEYS
|
|
# ---------------
|
|
# all keys not progress_bar or log are candidates for callbacks
|
|
callback_metrics = {}
|
|
for k, v in output.items():
|
|
if k not in ['progress_bar', 'log', 'hiddens']:
|
|
callback_metrics[k] = v
|
|
|
|
if train and self.use_dp:
|
|
num_gpus = self.num_gpus
|
|
callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
|
|
|
|
for k, v in callback_metrics.items():
|
|
if isinstance(v, torch.Tensor):
|
|
callback_metrics[k] = v.item()
|
|
|
|
# ---------------
|
|
# EXTRACT PROGRESS BAR KEYS
|
|
# ---------------
|
|
try:
|
|
progress_output = output['progress_bar']
|
|
|
|
# reduce progress metrics for tqdm when using dp
|
|
if train and self.use_dp:
|
|
num_gpus = self.num_gpus
|
|
progress_output = self.reduce_distributed_output(progress_output, num_gpus)
|
|
|
|
progress_bar_metrics = progress_output
|
|
except Exception:
|
|
progress_bar_metrics = {}
|
|
|
|
# ---------------
|
|
# EXTRACT LOGGING KEYS
|
|
# ---------------
|
|
# extract metrics to log to experiment
|
|
try:
|
|
log_output = output['log']
|
|
|
|
            # reduce log metrics when using dp
|
|
if train and self.use_dp:
|
|
num_gpus = self.num_gpus
|
|
log_output = self.reduce_distributed_output(log_output, num_gpus)
|
|
|
|
log_metrics = log_output
|
|
except Exception:
|
|
log_metrics = {}
|
|
|
|
# ---------------
|
|
# EXTRACT LOSS
|
|
# ---------------
|
|
# if output dict doesn't have the keyword loss
|
|
# then assume the output=loss if scalar
|
|
loss = None
|
|
if train:
|
|
try:
|
|
loss = output['loss']
|
|
except Exception:
|
|
if type(output) is torch.Tensor:
|
|
loss = output
|
|
else:
|
|
raise RuntimeError(
|
|
'No `loss` value in the dictionary returned from `model.training_step()`.'
|
|
)
|
|
|
|
# when using dp need to reduce the loss
|
|
if self.use_dp:
|
|
loss = self.reduce_distributed_output(loss, self.num_gpus)
|
|
|
|
# ---------------
|
|
# EXTRACT HIDDEN
|
|
# ---------------
|
|
hiddens = output.get('hiddens')
|
|
|
|
# use every metric passed in as a candidate for callback
|
|
callback_metrics.update(progress_bar_metrics)
|
|
callback_metrics.update(log_metrics)
|
|
|
|
# convert tensors to numpy
|
|
for k, v in callback_metrics.items():
|
|
if isinstance(v, torch.Tensor):
|
|
callback_metrics[k] = v.item()
|
|
|
|
return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
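
    # Illustrative sketch of the step-output dict that `process_output` expects (key names and
    # values are made up): everything outside 'progress_bar', 'log', and 'hiddens' also becomes
    # a callback metric, and 'loss' is the tensor used for the backward pass.
    #
    #   {
    #       'loss': total_loss,
    #       'progress_bar': {'mel': mel_loss.item()},   # shown in the tqdm postfix
    #       'log': {'train/mel': mel_loss.item()},      # sent to the logger
    #   }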
|
|
|
|
def reduce_distributed_output(self, output, num_gpus):
|
|
if num_gpus <= 1:
|
|
return output
|
|
|
|
# when using DP, we get one output per gpu
|
|
# average outputs and return
|
|
if type(output) is torch.Tensor:
|
|
return output.mean()
|
|
|
|
for k, v in output.items():
|
|
            # recurse on nested dicts
|
|
if isinstance(output[k], dict):
|
|
output[k] = self.reduce_distributed_output(output[k], num_gpus)
|
|
|
|
# do nothing when there's a scalar
|
|
elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
|
|
pass
|
|
|
|
# reduce only metrics that have the same number of gpus
|
|
elif output[k].size(0) == num_gpus:
|
|
reduced = torch.mean(output[k])
|
|
output[k] = reduced
|
|
return output
|
|
|
|
def clip_gradients(self):
|
|
if self.gradient_clip_val > 0:
|
|
model = self.get_model()
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
|
|
|
|
def print_nan_gradients(self):
|
|
model = self.get_model()
|
|
for param in model.parameters():
|
|
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
|
|
                logging.info(f'NaN gradient found: {param}, grad: {param.grad}')
|
|
|
|
def configure_accumulated_gradients(self, accumulate_grad_batches):
|
|
self.accumulate_grad_batches = None
|
|
|
|
if isinstance(accumulate_grad_batches, dict):
|
|
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
|
|
elif isinstance(accumulate_grad_batches, int):
|
|
schedule = {1: accumulate_grad_batches}
|
|
self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
|
|
else:
|
|
raise TypeError("Gradient accumulation supports only int and dict types")
|
|
|
|
def get_dataloaders(self, model):
|
|
if not self.testing:
|
|
self.init_train_dataloader(model)
|
|
self.init_val_dataloader(model)
|
|
else:
|
|
self.init_test_dataloader(model)
|
|
|
|
if self.use_ddp:
|
|
dist.barrier()
|
|
if not self.testing:
|
|
self.get_train_dataloader()
|
|
self.get_val_dataloaders()
|
|
else:
|
|
self.get_test_dataloaders()
|
|
|
|
def init_train_dataloader(self, model):
|
|
        self.first_epoch = True
|
|
self.get_train_dataloader = model.train_dataloader
|
|
if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader):
|
|
self.num_training_batches = len(self.get_train_dataloader())
|
|
self.num_training_batches = int(self.num_training_batches)
|
|
else:
|
|
self.num_training_batches = float('inf')
|
|
self.is_iterable_train_dataloader = True
|
|
if isinstance(self.val_check_interval, int):
|
|
self.val_check_batch = self.val_check_interval
|
|
else:
|
|
self._percent_range_check('val_check_interval')
|
|
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
|
|
self.val_check_batch = max(1, self.val_check_batch)
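
    # Hedged example of `val_check_interval`: with 1000 training batches per epoch, a float of
    # 0.25 gives `val_check_batch == 250` (validation roughly four times per epoch), while an
    # int such as 200 validates every 200 global steps directly.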
|
|
|
|
def init_val_dataloader(self, model):
|
|
self.get_val_dataloaders = model.val_dataloader
|
|
self.num_val_batches = 0
|
|
if self.get_val_dataloaders() is not None:
|
|
if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader):
|
|
self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
|
|
self.num_val_batches = int(self.num_val_batches)
|
|
else:
|
|
self.num_val_batches = float('inf')
|
|
|
|
def init_test_dataloader(self, model):
|
|
self.get_test_dataloaders = model.test_dataloader
|
|
if self.get_test_dataloaders() is not None:
|
|
if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader):
|
|
self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
|
|
self.num_test_batches = int(self.num_test_batches)
|
|
else:
|
|
self.num_test_batches = float('inf')
|
|
|
|
def evaluate(self, model, dataloaders, max_batches, test=False):
|
|
"""Run evaluation code.
|
|
|
|
:param model: PT model
|
|
:param dataloaders: list of PT dataloaders
|
|
:param max_batches: Scalar
|
|
:param test: boolean
|
|
:return:
|
|
"""
|
|
# enable eval mode
|
|
model.zero_grad()
|
|
model.eval()
|
|
|
|
# copy properties for forward overrides
|
|
self.copy_trainer_model_properties(model)
|
|
|
|
# disable gradients to save memory
|
|
torch.set_grad_enabled(False)
|
|
|
|
if test:
|
|
self.get_model().test_start()
|
|
# bookkeeping
|
|
outputs = []
|
|
|
|
        # run evaluation over each dataloader
|
|
for dataloader_idx, dataloader in enumerate(dataloaders):
|
|
dl_outputs = []
|
|
for batch_idx, batch in enumerate(dataloader):
|
|
|
|
if batch is None: # pragma: no cover
|
|
continue
|
|
|
|
# stop short when on fast_dev_run (sets max_batch=1)
|
|
if batch_idx >= max_batches:
|
|
break
|
|
|
|
# -----------------
|
|
# RUN EVALUATION STEP
|
|
# -----------------
|
|
output = self.evaluation_forward(model,
|
|
batch,
|
|
batch_idx,
|
|
dataloader_idx,
|
|
test)
|
|
|
|
# track outputs for collation
|
|
dl_outputs.append(output)
|
|
|
|
# batch done
|
|
if test:
|
|
self.test_progress_bar.update(1)
|
|
else:
|
|
self.val_progress_bar.update(1)
|
|
outputs.append(dl_outputs)
|
|
|
|
# with a single dataloader don't pass an array
|
|
if len(dataloaders) == 1:
|
|
outputs = outputs[0]
|
|
|
|
        # give the model a chance to aggregate the outputs (if the method is defined)
|
|
model = self.get_model()
|
|
if test:
|
|
eval_results_ = model.test_end(outputs)
|
|
else:
|
|
eval_results_ = model.validation_end(outputs)
|
|
eval_results = eval_results_
|
|
|
|
# enable train mode again
|
|
model.train()
|
|
|
|
        # re-enable gradients for training
|
|
torch.set_grad_enabled(True)
|
|
|
|
return eval_results
|
|
|
|
def run_evaluation(self, test=False):
|
|
# when testing make sure user defined a test step
|
|
model = self.get_model()
|
|
model.on_pre_performance_check()
|
|
|
|
# select dataloaders
|
|
if test:
|
|
dataloaders = self.get_test_dataloaders()
|
|
max_batches = self.num_test_batches
|
|
else:
|
|
# val
|
|
dataloaders = self.get_val_dataloaders()
|
|
max_batches = self.num_val_batches
|
|
|
|
# init validation or test progress bar
|
|
# main progress bar will already be closed when testing so initial position is free
|
|
position = 2 * self.process_position + (not test)
|
|
desc = 'Testing' if test else 'Validating'
|
|
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
|
|
disable=not self.show_progress_bar, dynamic_ncols=True,
|
|
unit='batch', file=sys.stdout)
|
|
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
|
|
|
|
# run evaluation
|
|
eval_results = self.evaluate(self.model,
|
|
dataloaders,
|
|
max_batches,
|
|
test)
|
|
if eval_results is not None:
|
|
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
|
|
eval_results)
|
|
|
|
# add metrics to prog bar
|
|
self.add_tqdm_metrics(prog_bar_metrics)
|
|
|
|
# log metrics
|
|
self.log_metrics(log_metrics, {})
|
|
|
|
# track metrics for callbacks
|
|
self.callback_metrics.update(callback_metrics)
|
|
|
|
# hook
|
|
model.on_post_performance_check()
|
|
|
|
# add model specific metrics
|
|
tqdm_metrics = self.training_tqdm_dict
|
|
if not test:
|
|
self.main_progress_bar.set_postfix(**tqdm_metrics)
|
|
|
|
# close progress bar
|
|
if test:
|
|
self.test_progress_bar.close()
|
|
else:
|
|
self.val_progress_bar.close()
|
|
|
|
# model checkpointing
|
|
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
|
|
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
|
|
logs=self.callback_metrics)
|
|
|
|
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
|
|
# make dataloader_idx arg in validation_step optional
|
|
args = [batch, batch_idx]
|
|
|
|
if test and len(self.get_test_dataloaders()) > 1:
|
|
args.append(dataloader_idx)
|
|
|
|
elif not test and len(self.get_val_dataloaders()) > 1:
|
|
args.append(dataloader_idx)
|
|
|
|
# handle DP, DDP forward
|
|
if self.use_ddp or self.use_dp:
|
|
output = model(*args)
|
|
return output
|
|
|
|
# single GPU
|
|
if self.single_gpu:
|
|
# for single GPU put inputs on gpu manually
|
|
root_gpu = 0
|
|
if isinstance(self.data_parallel_device_ids, list):
|
|
root_gpu = self.data_parallel_device_ids[0]
|
|
batch = self.transfer_batch_to_gpu(batch, root_gpu)
|
|
args[0] = batch
|
|
|
|
# CPU
|
|
if test:
|
|
output = model.test_step(*args)
|
|
else:
|
|
output = model.validation_step(*args)
|
|
|
|
return output
|
|
|
|
def train(self):
|
|
model = self.get_model()
|
|
# run all epochs
|
|
for epoch in range(self.current_epoch, 1000000):
|
|
# set seed for distributed sampler (enables shuffling for each epoch)
|
|
if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
|
|
self.get_train_dataloader().sampler.set_epoch(epoch)
|
|
|
|
# get model
|
|
model = self.get_model()
|
|
|
|
# update training progress in trainer and model
|
|
model.current_epoch = epoch
|
|
self.current_epoch = epoch
|
|
|
|
total_val_batches = 0
|
|
if not self.disable_validation:
|
|
# val can be checked multiple times in epoch
|
|
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
|
|
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
|
|
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
|
|
total_val_batches = self.num_val_batches * val_checks_per_epoch
|
|
|
|
# total batches includes multiple val checks
|
|
self.total_batches = self.num_training_batches + total_val_batches
|
|
self.batch_loss_value = 0 # accumulated grads
|
|
|
|
if self.is_iterable_train_dataloader:
|
|
# for iterable train loader, the progress bar never ends
|
|
num_iterations = None
|
|
else:
|
|
num_iterations = self.total_batches
|
|
|
|
# reset progress bar
|
|
# .reset() doesn't work on disabled progress bar so we should check
|
|
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
|
|
self.main_progress_bar.set_description(desc)
|
|
|
|
            # update the gradient accumulation factor according to the accumulation_scheduler
|
|
self.accumulation_scheduler.on_epoch_begin(epoch, self)
|
|
|
|
# -----------------
|
|
            # RUN TRAINING EPOCH
|
|
# -----------------
|
|
self.run_training_epoch()
|
|
|
|
# update LR schedulers
|
|
if self.lr_schedulers is not None:
|
|
for lr_scheduler in self.lr_schedulers:
|
|
lr_scheduler.step(epoch=self.current_epoch)
|
|
|
|
self.main_progress_bar.close()
|
|
|
|
model.on_train_end()
|
|
|
|
if self.logger is not None:
|
|
self.logger.finalize("success")
|
|
|
|
def run_training_epoch(self):
|
|
# before epoch hook
|
|
if self.is_function_implemented('on_epoch_start'):
|
|
model = self.get_model()
|
|
model.on_epoch_start()
|
|
|
|
# run epoch
|
|
for batch_idx, batch in enumerate(self.get_train_dataloader()):
|
|
# stop epoch if we limited the number of training batches
|
|
if batch_idx >= self.num_training_batches:
|
|
break
|
|
|
|
self.batch_idx = batch_idx
|
|
|
|
model = self.get_model()
|
|
model.global_step = self.global_step
|
|
|
|
# ---------------
|
|
# RUN TRAIN STEP
|
|
# ---------------
|
|
output = self.run_training_batch(batch, batch_idx)
|
|
batch_result, grad_norm_dic, batch_step_metrics = output
|
|
|
|
# when returning -1 from train_step, we end epoch early
|
|
early_stop_epoch = batch_result == -1
|
|
|
|
# ---------------
|
|
# RUN VAL STEP
|
|
# ---------------
|
|
should_check_val = (
|
|
                not self.disable_validation and self.global_step % self.val_check_batch == 0 and not self.first_epoch)
|
|
            self.first_epoch = False
|
|
|
|
if should_check_val:
|
|
self.run_evaluation(test=self.testing)
|
|
|
|
# when logs should be saved
|
|
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
|
|
if should_save_log:
|
|
if self.proc_rank == 0 and self.logger is not None:
|
|
self.logger.save()
|
|
|
|
# when metrics should be logged
|
|
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
|
|
if should_log_metrics:
|
|
# logs user requested information to logger
|
|
self.log_metrics(batch_step_metrics, grad_norm_dic)
|
|
|
|
self.global_step += 1
|
|
self.total_batch_idx += 1
|
|
|
|
# end epoch early
|
|
# stop when the flag is changed or we've gone past the amount
|
|
# requested in the batches
|
|
if early_stop_epoch:
|
|
break
|
|
            if self.global_step > self.max_updates:
                print("| Training end..")
                sys.exit(0)
|
|
|
|
# epoch end hook
|
|
if self.is_function_implemented('on_epoch_end'):
|
|
model = self.get_model()
|
|
model.on_epoch_end()
|
|
|
|
def run_training_batch(self, batch, batch_idx):
|
|
# track grad norms
|
|
grad_norm_dic = {}
|
|
|
|
# track all metrics for callbacks
|
|
all_callback_metrics = []
|
|
|
|
# track metrics to log
|
|
all_log_metrics = []
|
|
|
|
if batch is None:
|
|
return 0, grad_norm_dic, {}
|
|
|
|
# hook
|
|
if self.is_function_implemented('on_batch_start'):
|
|
model_ref = self.get_model()
|
|
response = model_ref.on_batch_start(batch)
|
|
|
|
if response == -1:
|
|
return -1, grad_norm_dic, {}
|
|
|
|
splits = [batch]
|
|
self.hiddens = None
|
|
for split_idx, split_batch in enumerate(splits):
|
|
self.split_idx = split_idx
|
|
|
|
# call training_step once per optimizer
|
|
for opt_idx, optimizer in enumerate(self.optimizers):
|
|
if optimizer is None:
|
|
continue
|
|
                # make sure only the gradients of the current optimizer's parameters are calculated
|
|
# in the training step to prevent dangling gradients in multiple-optimizer setup.
|
|
if len(self.optimizers) > 1:
|
|
for param in self.get_model().parameters():
|
|
param.requires_grad = False
|
|
for group in optimizer.param_groups:
|
|
for param in group['params']:
|
|
param.requires_grad = True
|
|
|
|
# wrap the forward step in a closure so second order methods work
|
|
def optimizer_closure():
|
|
# forward pass
|
|
with torch.cuda.amp.autocast() if self.use_amp else contextlib.suppress():
|
|
output = self.training_forward(
|
|
split_batch, batch_idx, opt_idx, self.hiddens)
|
|
|
|
closure_loss = output[0]
|
|
progress_bar_metrics = output[1]
|
|
log_metrics = output[2]
|
|
callback_metrics = output[3]
|
|
self.hiddens = output[4]
|
|
if closure_loss is None:
|
|
return None
|
|
|
|
# accumulate loss
|
|
# (if accumulate_grad_batches = 1 no effect)
|
|
closure_loss = closure_loss / self.accumulate_grad_batches
|
|
|
|
# backward pass
|
|
model_ref = self.get_model()
|
|
if closure_loss.requires_grad:
|
|
if self.use_amp:
|
|
self.scaler.scale(closure_loss).backward()
|
|
else:
|
|
model_ref.backward(closure_loss, optimizer)
|
|
|
|
# track metrics for callbacks
|
|
all_callback_metrics.append(callback_metrics)
|
|
|
|
# track progress bar metrics
|
|
self.add_tqdm_metrics(progress_bar_metrics)
|
|
all_log_metrics.append(log_metrics)
|
|
|
|
                    # call the on_after_backward hook (if implemented)
|
|
if self.is_function_implemented('on_after_backward'):
|
|
model_ref = self.get_model()
|
|
model_ref.on_after_backward()
|
|
|
|
return closure_loss
|
|
|
|
# calculate loss
|
|
loss = optimizer_closure()
|
|
if loss is None:
|
|
continue
|
|
|
|
# nan grads
|
|
if self.print_nan_grads:
|
|
self.print_nan_gradients()
|
|
|
|
# track total loss for logging (avoid mem leaks)
|
|
self.batch_loss_value += loss.item()
|
|
|
|
# gradient update with accumulated gradients
|
|
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
|
|
|
|
# track gradient norms when requested
|
|
if batch_idx % self.row_log_interval == 0:
|
|
if self.track_grad_norm > 0:
|
|
model = self.get_model()
|
|
grad_norm_dic = model.grad_norm(
|
|
self.track_grad_norm)
|
|
|
|
# clip gradients
|
|
if self.use_amp:
|
|
self.scaler.unscale_(optimizer)
|
|
self.clip_gradients()
|
|
|
|
# calls .step(), .zero_grad()
|
|
# override function to modify this behavior
|
|
model = self.get_model()
|
|
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, self.use_amp, self.scaler)
|
|
|
|
# calculate running loss for display
|
|
self.running_loss.append(self.batch_loss_value)
|
|
self.batch_loss_value = 0
|
|
self.avg_loss = np.mean(self.running_loss[-100:])
|
|
|
|
# activate batch end hook
|
|
if self.is_function_implemented('on_batch_end'):
|
|
model = self.get_model()
|
|
model.on_batch_end()
|
|
|
|
# update progress bar
|
|
self.main_progress_bar.update(1)
|
|
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
|
|
|
|
# collapse all metrics into one dict
|
|
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
|
|
|
|
# track all metrics for callbacks
|
|
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
|
|
|
|
return 0, grad_norm_dic, all_log_metrics
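
    # Worked example of the accumulation arithmetic above, assuming accumulate_grad_batches == 4:
    # each optimizer_closure() call backpropagates loss / 4, and optimizer_step only runs on
    # every 4th batch, so the effective update approximates one step over a 4x larger batch.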
|
|
|
|
def training_forward(self, batch, batch_idx, opt_idx, hiddens):
|
|
"""
|
|
Handle forward for each training case (distributed, single gpu, etc...)
|
|
:param batch:
|
|
:param batch_idx:
|
|
:return:
|
|
"""
|
|
# ---------------
|
|
# FORWARD
|
|
# ---------------
|
|
# enable not needing to add opt_idx to training_step
|
|
args = [batch, batch_idx, opt_idx]
|
|
|
|
# distributed forward
|
|
if self.use_ddp or self.use_dp:
|
|
output = self.model(*args)
|
|
# single GPU forward
|
|
elif self.single_gpu:
|
|
gpu_id = 0
|
|
if isinstance(self.data_parallel_device_ids, list):
|
|
gpu_id = self.data_parallel_device_ids[0]
|
|
batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
|
|
args[0] = batch
|
|
output = self.model.training_step(*args)
|
|
# CPU forward
|
|
else:
|
|
output = self.model.training_step(*args)
|
|
|
|
# allow any mode to define training_end
|
|
model_ref = self.get_model()
|
|
output_ = model_ref.training_end(output)
|
|
if output_ is not None:
|
|
output = output_
|
|
|
|
# format and reduce outputs accordingly
|
|
output = self.process_output(output, train=True)
|
|
|
|
return output
|
|
|
|
# ---------------
|
|
# Utils
|
|
# ---------------
|
|
def is_function_implemented(self, f_name):
|
|
model = self.get_model()
|
|
f_op = getattr(model, f_name, None)
|
|
return callable(f_op)
|
|
|
|
def _percent_range_check(self, name):
|
|
value = getattr(self, name)
|
|
msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
|
|
if name == "val_check_interval":
|
|
msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
|
|
|
|
if not 0. <= value <= 1.:
|
|
raise ValueError(msg)
|