Source code for solaris.nets.model_io

import os
from tensorflow import keras
import torch
from warnings import warn
import requests
import numpy as np
from tqdm.auto import tqdm
from ..nets import weights_dir
from .zoo import model_dict


[docs]def get_model(model_name, framework, model_path=None, pretrained=False, custom_model_dict=None, num_classes=1): """Load a model from a file based on its name.""" if custom_model_dict is not None: md = custom_model_dict else: md = model_dict.get(model_name, None) if md is None: # if the model's not provided by solaris raise ValueError(f"{model_name} can't be found in solaris and no " "custom_model_dict was provided. Check your " "model_name in the config file and/or provide a " "custom_model_dict argument to Trainer(). ") if model_path is None or custom_model_dict is not None: model_path = md.get('weight_path') if num_classes == 1: model = md.get('arch')(pretrained=pretrained) else: model = md.get('arch')(num_classes=num_classes, pretrained=pretrained) if model is not None and pretrained: try: model = _load_model_weights(model, model_path, framework) except (OSError, FileNotFoundError): warn(f'The model weights file {model_path} was not found.' ' Attempting to download from the SpaceNet repository.') weight_path = _download_weights(md) model = _load_model_weights(model, weight_path, framework) return model
def _load_model_weights(model, path, framework): """Backend for loading the model.""" if framework.lower() == 'keras': try: model.load_weights(path) except OSError: # first, check to see if the weights are in the default sol dir default_path = os.path.join(weights_dir, os.path.split(path)[1]) try: model.load_weights(default_path) except OSError: # if they can't be found anywhere, raise the error. raise FileNotFoundError("{} doesn't exist.".format(path)) elif framework.lower() in ['torch', 'pytorch']: # pytorch already throws the right error on failed load, so no need # to fix exception if torch.cuda.is_available(): try: loaded = torch.load(path) except FileNotFoundError: # first, check to see if the weights are in the default sol dir default_path = os.path.join(weights_dir, os.path.split(path)[1]) loaded = torch.load(path) else: try: loaded = torch.load(path, map_location='cpu') except FileNotFoundError: default_path = os.path.join(weights_dir, os.path.split(path)[1]) loaded = torch.load(path, map_location='cpu') if isinstance(loaded, torch.nn.Module): # if it's a full model already model.load_state_dict(loaded.state_dict()) else: model.load_state_dict(loaded) return model
[docs]def reset_weights(model, framework): """Re-initialize model weights for training. Arguments --------- model : :class:`tensorflow.keras.Model` or :class:`torch.nn.Module` A pre-trained, compiled model with weights saved. framework : str The deep learning framework used. Currently valid options are ``['torch', 'keras']`` . Returns ------- reinit_model : model object The model with weights re-initialized. Note this model object will also lack an optimizer, loss function, etc., which will need to be added. """ if framework == 'keras': model_json = model.to_json() reinit_model = keras.models.model_from_json(model_json) elif framework == 'torch': reinit_model = model.apply(_reset_torch_weights) return reinit_model
def _reset_torch_weights(torch_layer): if isinstance(torch_layer, torch.nn.Conv2d) or \ isinstance(torch_layer, torch.nn.Linear): torch_layer.reset_parameters() def _download_weights(model_dict): """Download pretrained weights for a model.""" weight_url = model_dict.get('weight_url', None) weight_dest_path = model_dict.get('weight_path', os.path.join( weights_dir, weight_url.split('/')[-1])) if weight_url is None: raise KeyError("Can't find the weights file.") else: r = requests.get(weight_url, stream=True) if r.status_code != 200: raise ValueError('The file could not be downloaded. Check the URL' ' and network connections.') total_size = int(r.headers.get('content-length', 0)) block_size = 1024 with open(weight_dest_path, 'wb') as f: for chunk in tqdm(r.iter_content(block_size), total=np.ceil(total_size//block_size), unit='KB', unit_scale=False): if chunk: f.write(chunk) return weight_dest_path