Source code for solaris.utils.raster

import torch
import numpy as np
import tensorflow as tf


def reorder_axes(arr, target='tensorflow'):
    """Check order of axes in an array or tensor and convert to desired format.

    The current layout is inferred with a heuristic: the channel axis is
    assumed to be shorter than both spatial axes. For a 3-D ``[a, b, c]``
    (or 4-D ``[N, a, b, c]``) input, ``a < c`` is read as channels-first
    (``[C, Y, X]``) and ``c < a`` as channels-last (``[Y, X, C]``). Inputs
    of any other rank, inputs with ``a == c``, and inputs already in the
    requested order are returned unchanged.

    Arguments
    ---------
    arr : :class:`numpy.array` or :class:`torch.Tensor` or :class:`tensorflow.Tensor`
        The array to reorder. Objects of any other type are returned
        unchanged.
    target : str, optional
        Desired axis order type. Possible values:

        - ``'tensorflow'`` (default): ``[N, Y, X, C]`` or ``[Y, X, C]``
        - ``'torch'`` : ``[N, C, Y, X]`` or ``[C, Y, X]``

    Returns
    -------
    out_arr : an object of the same class as `arr` with axes in the desired
        order.

    Raises
    ------
    ValueError
        If `target` is not ``'tensorflow'`` or ``'torch'``.
    """
    if target not in ('tensorflow', 'torch'):
        raise ValueError(
            "target must be 'tensorflow' or 'torch', got {!r}".format(target))

    if isinstance(arr, torch.Tensor):
        perm = _axis_permutation(list(arr.shape), target)
        return arr if perm is None else arr.permute(*perm)

    if isinstance(arr, np.ndarray):
        perm = _axis_permutation(arr.shape, target)
        return arr if perm is None else np.transpose(arr, perm)

    if isinstance(arr, tf.Tensor):
        shape = arr.get_shape().as_list()
        if len(shape) in (3, 4) and None in (shape[-3], shape[-1]):
            # The axes the heuristic compares are not known statically
            # (typically a graph-mode tensor), so decide on the evaluated
            # value as earlier versions of this function always did.
            np_version = reorder_axes(arr.eval(), target=target)
            return tf.convert_to_tensor(np_version)
        # tf.transpose works in both eager and graph mode and keeps the result
        # connected to the graph, unlike evaluating to NumPy and converting back.
        perm = _axis_permutation(shape, target)
        return arr if perm is None else tf.transpose(arr, perm=perm)

    # Unsupported types are passed through untouched.
    return arr


def _axis_permutation(shape, target):
    """Return the axis permutation that puts `shape` into `target` order.

    Returns ``None`` when `shape` is not 3-D or 4-D or no reordering is
    needed. ``shape[-3]`` is the first non-batch axis and ``shape[-1]`` the
    last; whichever is smaller is taken to be the channel axis.
    """
    rank = len(shape)
    if rank not in (3, 4):
        return None
    head, tail = shape[-3], shape[-1]
    if target == 'tensorflow' and head < tail:
        # channels-first -> channels-last: [N, C, Y, X] / [C, Y, X]
        return (0, 2, 3, 1) if rank == 4 else (1, 2, 0)
    if target == 'torch' and tail < head:
        # channels-last -> channels-first: [N, Y, X, C] / [Y, X, C]
        return (0, 3, 1, 2) if rank == 4 else (2, 0, 1)
    return None