Source code for solaris.nets.callbacks

import numpy as np
from tensorflow import keras
from tensorflow.keras.callbacks import Callback
from .torch_callbacks import torch_callback_dict
import torch


[docs]def get_callbacks(framework, config): """Load callbacks based on a config file for a specific framework. Usage ----- Note that this function is primarily intended for use with Keras. PyTorch does not use the same object-oriented training approach as Keras, and therefore doesn't generally have the same checkpointing objects to pass to model compilers - instead these are defined in model training code. See solaris.nets.train for examples of this. The only torch callback instantiated here is a learning rate scheduler. Arguments --------- framework : str Deep learning framework used for the model. Options are ``['keras', 'torch']`` . config : dict Configuration dict generated from the YAML config file. Returns ------- callbacks : list A `list` of callbacks to pass to the compiler (Keras) or to wrap the optimizer (torch learning rate scheduling) for model training. """ callbacks = [] if framework == 'keras': for callback, params in config['training']['callbacks'].items(): if callback == 'lr_schedule': callbacks.append(get_lr_schedule(framework, config)) else: callbacks.append(keras_callbacks[callback](**params)) elif framework == 'torch': for callback, params in config['training']['callbacks'].items(): if callback == 'lr_schedule': callbacks.append(get_lr_schedule(framework, config)) else: callbacks.append(torch_callback_dict[callback](**params)) return callbacks
[docs]class KerasTerminateOnMetricNaN(Callback): """Callback to stop training if a metric has value NaN or infinity. Notes ----- Instantiate as you would any other keras callback. For example, to end training if a validation metric called `f1_score` reaches value NaN:: m = Model(inputs, outputs) m.compile() m.fit(X, y, callbacks=[TerminateOnMetricNaN('val_f1_score')]) Attributes ---------- metric : str, optional Name of the metric being monitored. checkpoint : str, optional One of ``['epoch', 'batch']``: Should the metric be checked at the end of every epoch (default) or every batch? Methods ------- on_epoch_end : operations to complete at the end of each epoch. on_batch_end : operations to complete at the end of each batch. """ def __init__(self, metric=None, checkpoint='epoch'): """ Parameters ---------- metric (str): The name of the metric to be tested for NaN value. checkpoint (['epoch', 'batch']): Should the metric be checked at the end of every epoch (default) or every batch? """ super(KerasTerminateOnMetricNaN, self).__init__() self.metric = metric self.ckpt = checkpoint def on_epoch_end(self, epoch, logs=None): if self.ckpt == 'epoch': logs = logs or {} metric_score = logs.get(self.metric) if self.metric is not None: if np.isnan(metric_score) or np.isinf(metric_score): print('Epoch {}: Invalid score for metric {}, terminating' ' training'.format(epoch, self.metric)) self.model.stop_training = True def on_batch_end(self, batch, logs=None): if self.ckpt == 'batch': logs = logs or {} metric_score = logs.get(self.metric) print('metric score: {}'.format(metric_score)) if np.isnan(metric_score) or np.isinf(metric_score): print('Batch {}: Invalid score for metric' '{}, terminating training'.format(batch, self.metric)) self.model.stop_training = True
keras_callbacks = { 'terminate_on_nan': keras.callbacks.TerminateOnNaN, 'terminate_on_metric_nan': KerasTerminateOnMetricNaN, 'model_checkpoint': keras.callbacks.ModelCheckpoint, 'early_stopping': keras.callbacks.EarlyStopping, 'reduce_lr_on_plateau': keras.callbacks.ReduceLROnPlateau, 'csv_logger': keras.callbacks.CSVLogger }
[docs]def get_lr_schedule(framework, config): """Get a LR scheduling function for model training. Arguments --------- framework : str Deep learning framework used for the model. Options are ``['keras', 'torch']`` . config : dict Configuration dict generated from the YAML config file. Returns ------- lr_scheduler : :class:`tensorflow.keras.callbacks.LearningRateScheduler` or ``torch.optim.lr_schedule`` scheduler class A scheduler to provide during training. For Keras, this takes the form of a callback passed to the optimizer; for PyTorch, it's a class object that wraps the optimizer. Because the torch version must wrap the optimizer, it's not instantiated here - the class is returned instead. """ schedule_type = config['training'][ 'callbacks']['lr_schedule']['schedule_type'] initial_lr = config['training']['lr'] update_frequency = config['training']['callbacks']['lr_schedule'].get( 'update_frequency', 1) factor = config['training']['callbacks']['lr_schedule'].get( 'factor', 0) schedule_dict = config['training']['callbacks']['lr_schedule'].get( 'schedule_dict', None) if framework == 'keras': lr_sched_func = keras_lr_schedule(schedule_type, initial_lr, update_frequency, factor, schedule_dict) lr_scheduler = keras.callbacks.LearningRateScheduler(lr_sched_func) elif framework == 'torch': # just get the class itself to use; don't instantiate until the # optimizer has been created. if config['training'][ 'callbacks']['lr_schedule']['schedule_type'] == 'linear': lr_scheduler = torch.optim.lr_scheduler.StepLR elif config['training'][ 'callbacks']['lr_schedule']['schedule_type'] == 'exponential': lr_scheduler = torch.optim.lr_scheduler.ExponentialLR elif config['training'][ 'callbacks']['lr_schedule']['schedule_type'] == 'arbitrary': lr_scheduler = torch.optim.lr_scheduler.MultiStepLR return lr_scheduler
[docs]def keras_lr_schedule(schedule_type, initial_lr=0.001, update_frequency=1, factor=0, schedule_dict=None): """Create a learning rate schedule for Keras from a schedule dict. Arguments --------- schedule_type : str Type of learning rate schedule to use. Options are: ``['arbitrary', 'exponential', 'linear']`` . initial_lr : float, optional The initial learning rate to use. Defaults to ``0.001`` . update_frequency : int, optional How frequently should learning rate be reduced (or increased)? Defaults to ``1`` (every epoch). Has no effect if ``schedule_type='arbitrary'``. factor : float, optional The magnitude by which learning rate should be changed at each update. Use a positive number to increase learning rate and a negative number to decrease learning rate. See Usage for more details. schedule_dict : dict, optional A dictionary with ``{epoch: learning rate}`` pairs. The learning rate defined in each pair will be used beginning at the specified epoch and continuing until the next highest epoch number is reached during training. Returns ------- lr_schedule : func a function that takes epoch number integers as an argument and returns a learning rate. Usage ----- ``schedule_type='arbitrary'`` usage is documented in the arguments above. For ``schedule_type='exponential'``, the following equation is applied to determine the learning rate at each epoch: .. math:: lr = initial_lr*e^{factor\times(floor(epoch/update_frequency))} For ``schedule_type='linear'``, the following equation is applied: .. math:: lr = initial_lr\times(1+factor\times(floor(epoch/update_frequency))) """ if schedule_type == 'arbitrary': if schedule_dict is None: raise ValueError('If using an arbitrary schedule, an epoch: lr ' 'dict must be provided.') lookup_dict = {} epoch_vals = np.array(list(schedule_dict.keys())) for e in range(0, epoch_vals.max() + 1): if e < epoch_vals.min(): lookup_dict[e] = schedule_dict[epoch_vals.min()] else: # get all the epochs from the dict <= e lower_epochs = epoch_vals[epoch_vals <= e] # get the LR for the highest epoch number <= e lookup_dict[e] = schedule_dict[lower_epochs.max()] def lr_schedule(epoch): if epoch < epoch_vals.min(): return initial_lr elif epoch > epoch_vals.max(): return lookup_dict[epoch_vals.max()] else: return lookup_dict[epoch] elif schedule_type == 'exponential': def lr_schedule(epoch): if not np.floor(epoch/update_frequency): return initial_lr else: return initial_lr*factor/np.floor(epoch/update_frequency) elif schedule_type == 'linear': def lr_schedule(epoch): return initial_lr*(1+factor*np.floor(epoch/update_frequency)) return lr_schedule