"""Training code for `solaris` models."""
import numpy as np
import pandas as pd
from .model_io import get_model, reset_weights
from .datagen import make_data_generator
from .losses import get_loss
from .optimizers import get_optimizer
from .callbacks import get_callbacks
from .torch_callbacks import TorchEarlyStopping, TorchTerminateOnNaN
from .torch_callbacks import TorchModelCheckpoint
from .metrics import get_metrics
import torch
from torch.optim.lr_scheduler import _LRScheduler
import tensorflow as tf
class Trainer(object):
    """Object for training `solaris` models using PyTorch or Keras.

    Arguments
    ---------
    config : dict
        The loaded configuration dict for model training. The constructor
        reads ``'pretrained'``, ``'batch_size'``, ``'nn_framework'``,
        ``'model_name'``, the optional ``'model_path'``, ``'data_specs'`` and
        ``'training'`` from it.
    custom_model_dict : optional
        Forwarded unchanged to ``get_model``.
    custom_losses : optional
        Forwarded unchanged to ``get_loss``.
    """

    # Values of ``config['nn_framework']`` that select the PyTorch code paths.
    _TORCH_FRAMEWORKS = ('torch', 'pytorch')

    def __init__(self, config, custom_model_dict=None, custom_losses=None):
        self.config = config
        self.pretrained = self.config['pretrained']
        self.batch_size = self.config['batch_size']
        self.framework = self.config['nn_framework']
        self.model_name = self.config['model_name']
        self.model_path = self.config.get('model_path', None)
        try:
            self.num_classes = self.config['data_specs']['num_classes']
        except KeyError:
            # fall back to a single class when the config does not say
            self.num_classes = 1
        self.model = get_model(self.model_name, self.framework,
                               self.model_path, self.pretrained,
                               custom_model_dict, self.num_classes)
        self.train_df, self.val_df = get_train_val_dfs(self.config)
        self.train_datagen = make_data_generator(self.framework, self.config,
                                                 self.train_df, stage='train')
        self.val_datagen = make_data_generator(self.framework, self.config,
                                               self.val_df, stage='validate')
        self.epochs = self.config['training']['epochs']
        self.optimizer = get_optimizer(self.framework, self.config)
        self.lr = self.config['training']['lr']
        self.custom_losses = custom_losses
        self.loss = get_loss(self.framework,
                             self.config['training'].get('loss'),
                             self.config['training'].get('loss_weights'),
                             self.custom_losses)
        self.checkpoint_frequency = self.config['training'].get(
            'checkpoint_frequency')
        self.callbacks = get_callbacks(self.framework, self.config)
        self.metrics = get_metrics(self.framework, self.config)
        self.verbose = self.config['training']['verbose']

        # Defaults first, so both attributes exist whatever the framework is.
        self.gpu_available = False
        self.gpu_count = 0
        if self.framework in self._TORCH_FRAMEWORKS:
            self.gpu_available = torch.cuda.is_available()
            if self.gpu_available:
                self.gpu_count = torch.cuda.device_count()
        elif self.framework == 'keras':
            self.gpu_available = tf.test.is_gpu_available()

        self.is_initialized = False
        self.stop = False
        self.initialize_model()

    def initialize_model(self):
        """Load in and create all model training elements.

        Resets the weights when the model is not pretrained, then either
        compiles the Keras model or, for PyTorch, moves the model to the GPU
        (when one is available) and instantiates the optimizer. Sets
        ``self.is_initialized`` to ``True`` when done.
        """
        if not self.pretrained:
            self.model = reset_weights(self.model, self.framework)

        if self.framework == 'keras':
            # `compile` configures the model in place and returns None, so
            # its result must not be assigned back to `self.model`.
            self.model.compile(optimizer=self.optimizer,
                               loss=self.loss,
                               metrics=self.metrics['train'])

        elif self.framework in self._TORCH_FRAMEWORKS:
            if self.gpu_available:
                self.model = self.model.cuda()
                if self.gpu_count > 1:
                    self.model = torch.nn.DataParallel(self.model)
            # create optimizer; a missing or None `opt_args` means no extras
            opt_args = self.config['training'].get('opt_args') or {}
            self.optimizer = self.optimizer(
                self.model.parameters(), lr=self.lr, **opt_args
            )
            # wrap in lr_scheduler if one was created
            # NOTE(review): this branch calls `cb`, yet only matches
            # *instances* of _LRScheduler, and it replaces `self.optimizer`
            # with the result even though `train` later calls
            # `zero_grad()`/`step()` on it. Confirm what `get_callbacks`
            # returns for 'lr_schedule' before relying on this path.
            for cb in self.callbacks:
                if isinstance(cb, _LRScheduler):
                    self.optimizer = cb(
                        self.optimizer,
                        **self.config['training']['callbacks'][
                            'lr_schedule'].get('schedule_dict', {})
                    )
                    # drop the LRScheduler callback from the list
                    self.callbacks = [i for i in self.callbacks if i != cb]

        self.is_initialized = True

    def _prepare_batch(self, batch):
        """Pull the model input(s) and float target out of one batch.

        When ``config['data_specs']['additional_inputs']`` is set, the input
        is a list of tensors (``'image'`` first, then each additional input);
        otherwise it is ``batch['image']`` alone. Both input and target are
        moved to the GPU when ``self.gpu_available`` is true, matching where
        ``initialize_model`` placed the model.
        """
        extra_inputs = self.config['data_specs'].get('additional_inputs',
                                                     None)
        if extra_inputs is not None:
            data = [torch.Tensor(batch[key])
                    for key in ['image'] + extra_inputs]
            if self.gpu_available:
                data = [tensor.cuda() for tensor in data]
        else:
            data = batch['image']
            if self.gpu_available:
                data = data.cuda()

        target = batch['mask']
        if self.gpu_available:
            target = target.cuda()
        return data, target.float()

    def train(self):
        """Run training on the model.

        Keras models are trained with ``fit_generator``. PyTorch models run
        an explicit loop: one pass over ``self.train_datagen`` followed by
        one pass over ``self.val_datagen`` per epoch, then the torch
        callbacks, which may end training early. The final model is saved
        once the loop ends.
        """
        if not self.is_initialized:
            self.initialize_model()

        if self.framework == 'keras':
            self.model.fit_generator(self.train_datagen,
                                     validation_data=self.val_datagen,
                                     epochs=self.epochs,
                                     callbacks=self.callbacks)

        elif self.framework in self._TORCH_FRAMEWORKS:
            for epoch in range(self.epochs):
                if self.verbose:
                    print('Beginning training epoch {}'.format(epoch))
                # TRAINING
                self.model.train()
                for batch_idx, batch in enumerate(self.train_datagen):
                    data, target = self._prepare_batch(batch)
                    self.optimizer.zero_grad()
                    output = self.model(data)
                    loss = self.loss(output, target)
                    loss.backward()
                    self.optimizer.step()

                    if self.verbose and batch_idx % 10 == 0:
                        print(' loss at batch {}: {}'.format(
                            batch_idx, loss), flush=True)

                # VALIDATION
                with torch.no_grad():
                    self.model.eval()
                    torch.cuda.empty_cache()
                    val_loss = []
                    for batch in self.val_datagen:
                        data, target = self._prepare_batch(batch)
                        val_output = self.model(data)
                        val_loss.append(self.loss(val_output, target))
                    # mean of the per-batch validation losses
                    val_loss = torch.mean(torch.stack(val_loss))

                if self.verbose:
                    print()
                    print(' Validation loss at epoch {}: {}'.format(
                        epoch, val_loss))
                    print()

                # `loss` here is the loss of the epoch's last training batch
                check_continue = self._run_torch_callbacks(
                    loss.detach().cpu().numpy(),
                    val_loss.detach().cpu().numpy())
                if not check_continue:
                    break

            self.save_model()

    def _run_torch_callbacks(self, loss, val_loss):
        """Run every torch callback once, at the end of an epoch.

        Arguments
        ---------
        loss
            Training loss value handed to checkpoint callbacks that monitor
            ``'loss'``.
        val_loss
            Validation loss value handed to the stopping callbacks and to
            checkpoint callbacks that monitor ``'val_loss'``.

        Returns
        -------
        bool
            ``False`` as soon as a stopping callback sets its ``stop`` flag,
            ``True`` otherwise.
        """
        for cb in self.callbacks:
            if isinstance(cb, (TorchEarlyStopping, TorchTerminateOnNaN)):
                cb(val_loss)
                if cb.stop:
                    if self.verbose:
                        print('Early stopping triggered - '
                              'ending training')
                    return False

            elif isinstance(cb, TorchModelCheckpoint):
                if cb.monitor == 'loss':
                    cb(self.model, loss_value=loss)
                elif cb.monitor == 'val_loss':
                    cb(self.model, loss_value=val_loss)
                elif cb.monitor == 'periodic':
                    # no loss value is passed for periodic checkpoints
                    cb(self.model)

        return True

    def save_model(self):
        """Save the final model output.

        Writes to ``config['training']['model_dest_path']``. Keras models are
        saved with ``model.save``; for PyTorch the ``state_dict`` is saved,
        taken from the wrapped module when the model is in ``DataParallel``.
        """
        dest_path = self.config['training']['model_dest_path']
        if self.framework == 'keras':
            self.model.save(dest_path)
        elif self.framework in self._TORCH_FRAMEWORKS:
            if isinstance(self.model, torch.nn.DataParallel):
                torch.save(self.model.module.state_dict(), dest_path)
            else:
                torch.save(self.model.state_dict(), dest_path)
def get_train_val_dfs(config):
    """Get the training and validation dfs based on the contents of ``config``.

    The training set is always read from ``config['training_data_csv']``.
    If ``config['data_specs']['val_holdout_frac']`` is set, that fraction of
    the training rows is drawn at random (without replacement) as the
    validation set and removed from the training set. Otherwise the
    validation set is read from ``config['validation_data_csv']``.

    See the docs and the comments in solaris/data/config_skeleton.yml for
    details.

    Arguments
    ---------
    config : dict
        The loaded configuration dict for model training and/or inference.

    Returns
    -------
    train_df, val_df : :class:`tuple` of :class:`pandas.DataFrame` s
        The training and validation tables as read from the CSV(s). Per the
        config documentation these hold ``'image'`` and ``'label'`` columns
        of matching file paths.

    Raises
    ------
    ValueError
        If neither ``val_holdout_frac`` nor ``validation_data_csv`` is
        provided (missing or ``None``).
    """
    train_df = pd.read_csv(config['training_data_csv'])

    # `.get` so that an absent key is treated the same as an explicit None
    val_frac = config['data_specs'].get('val_holdout_frac')
    if val_frac is None:
        val_csv = config.get('validation_data_csv')
        if val_csv is None:
            raise ValueError(
                "If val_holdout_frac isn't specified in config,"
                " validation_data_csv must be.")
        val_df = pd.read_csv(val_csv)
    else:
        val_subset = np.random.choice(train_df.index,
                                      int(len(train_df)*val_frac),
                                      replace=False)
        val_df = train_df.loc[val_subset]
        # remove the validation samples from the training df
        train_df = train_df.drop(index=val_subset)

    return train_df, val_df