Source code for solaris.preproc.image

import gdal
import math
import matplotlib.pyplot as plt
import numpy as np
import os
from osgeo import gdal_array
import pandas as pd
import uuid
import warnings

from .pipesegment import PipeSegment, LoadSegment, MergeSegment


class Image:
    """
    Container pairing raster pixel data with a name and a metadata dict.

    Attributes:
        name: label used when the image is printed.
        metadata: dict of auxiliary information (the loaders in this module
            store georeferencing and GDAL metadata here). The dict passed in
            is stored as-is, not copied.
        data: pixel array, expected to be (bands, rows, cols); 2-D numpy
            arrays are promoted by adding a leading band axis.
    """
    def __init__(self, data, name='image', metadata=None):
        self.name = name
        # Use a fresh dict per instance: a shared `{}` default would leak keys
        # written into one image's metadata into every other image created
        # without explicit metadata.
        self.metadata = {} if metadata is None else metadata
        self.set_data(data)

    def set_data(self, data):
        """Store 'data', adding a leading band axis to 2-D numpy arrays."""
        if isinstance(data, np.ndarray) and data.ndim == 2:
            data = np.expand_dims(data, axis=0)
        self.data = data

    def __str__(self):
        """
        Summarize as 'name: B bands, RxC, dtype, metadata'.

        Raises:
            Exception: if the data has fewer than three dimensions.
        """
        if self.data.ndim < 3:
            raise Exception('! Image data has too few dimensions.')
        metastring = str(self.metadata)
        # Keep long metadata (e.g. projection strings) from flooding output.
        if len(metastring) > 400:
            metastring = metastring[:360] + '...'
        return '%s: %d bands, %dx%d, %s, %s' % (self.name,
                                                *np.shape(self.data),
                                                str(self.data.dtype),
                                                metastring)


class Identity(PipeSegment):
    """
    Alias of the PipeSegment base class with no behavior of its own.

    The separate name exists only to make the segment's role as the
    identity element of a pipeline explicit.
    """
class LoadImageFromDisk(LoadSegment):
    """
    Read an image file through GDAL and wrap it in an Image object so it can
    be fed into subsequent PipeSegments.
    """
    def __init__(self, pathstring, name=None, verbose=False):
        super().__init__()
        self.pathstring = pathstring
        self.name = name
        self.verbose = verbose

    def load(self):
        """Load the file configured at construction time."""
        return self.load_from_disk(self.pathstring, self.name, self.verbose)

    def load_from_disk(self, pathstring, name=None, verbose=False):
        """
        Open 'pathstring' with GDAL and return it as an Image.

        Single-band (2-D) arrays get a leading band axis. The geotransform,
        projection, GCPs, GCP projection, dataset metadata and per-band
        metadata are stored in the Image's metadata dict. When 'name' is None
        the file's base name without extension is used.

        Raises:
            Exception: if GDAL cannot open the file.
        """
        source = gdal.Open(pathstring)
        if source is None:
            raise Exception('! Image file ' + pathstring + ' not found.')
        pixels = source.ReadAsArray()
        if pixels.ndim == 2:
            pixels = np.expand_dims(pixels, axis=0)
        metadata = dict(
            geotransform=source.GetGeoTransform(),
            projection_ref=source.GetProjectionRef(),
            gcps=source.GetGCPs(),
            gcp_projection=source.GetGCPProjection(),
            meta=source.GetMetadata(),
            band_meta=[source.GetRasterBand(index).GetMetadata()
                       for index in range(1, source.RasterCount + 1)],
        )
        if name is None:
            name = os.path.splitext(os.path.basename(pathstring))[0]
        # Drop the reference so the GDAL dataset handle is released.
        source = None
        image = Image(pixels, name, metadata)
        if verbose:
            print(image)
        return image
class LoadImageFromMemory(LoadSegment):
    """
    Feed an existing 'Image'-class object into subsequent PipeSegments.
    """
    def __init__(self, imageobj, name=None, verbose=False):
        super().__init__()
        self.imageobj = imageobj
        self.name = name
        self.verbose = verbose

    def load(self):
        """Return the image configured at construction time."""
        return self.load_from_memory(self.imageobj, self.name, self.verbose)

    def load_from_memory(self, imageobj, name=None, verbose=False):
        """
        Return 'imageobj' itself (not a copy), renaming it if 'name' is given.

        The rename mutates the object that was passed in.

        Raises:
            Exception: unless 'imageobj' is exactly of type Image (subclasses
                are rejected too).
        """
        if type(imageobj) is Image:
            if name is not None:
                imageobj.name = name
            if verbose:
                print(imageobj)
            return imageobj
        raise Exception('! Invalid input type in LoadImageFromMemory.')
class LoadImage(LoadImageFromDisk, LoadImageFromMemory):
    """
    Makes an image available to subsequent PipeSegments, whether the image
    is in the filesystem (in which case 'imageinput' is the path) or an
    Image-class variable (in which case 'imageinput' is the variable name).

    Paths may be given as ``str`` (including ``numpy.str_``) or as any
    ``os.PathLike`` object such as ``pathlib.Path``.
    """
    def __init__(self, imageinput, name=None, verbose=False):
        # Initialize the PipeSegment base directly: the two loader parents'
        # __init__ methods each store their input under a different
        # attribute name, and neither is wanted here.
        PipeSegment.__init__(self)
        self.imageinput = imageinput
        self.name = name
        self.verbose = verbose

    def load(self):
        """
        Dispatch to the in-memory or on-disk loader based on the input type.

        Raises:
            Exception: if 'imageinput' is neither an Image nor a path.
        """
        imageinput = self.imageinput
        if type(imageinput) is Image:
            return self.load_from_memory(imageinput, self.name, self.verbose)
        if isinstance(imageinput, (str, os.PathLike)):
            # os.fspath turns path objects into plain strings (str inputs are
            # returned unchanged) for GDAL and the os.path-based name
            # derivation in load_from_disk.
            return self.load_from_disk(os.fspath(imageinput), self.name,
                                       self.verbose)
        raise Exception('! Invalid input type in LoadImage.')
class SaveImage(PipeSegment):
    """
    Save an image to disk using GDAL.

    Parameters:
        pathstring: destination path handed to the GDAL driver.
        driver: GDAL driver short name (default 'GTiff'). With 'MEM' (any
            case) the created GDAL dataset is returned instead of the image.
        return_image: if True (and the driver is not 'MEM'), return the input
            image so a pipeline can continue; otherwise return None.
        save_projection: copy the georeferencing (affine geotransform or
            GCPs) stored in the image metadata.
        save_metadata: copy the dataset-level 'meta' dict, if present.
        no_data_value: no-data value set on every band; the string 'nan'
            (any case) means NaN, and None leaves it unset.

    Raises:
        Exception: if the driver is unknown or the dataset cannot be created.
    """
    def __init__(self, pathstring, driver='GTiff', return_image=True,
                 save_projection=True, save_metadata=True,
                 no_data_value=None):
        super().__init__()
        self.pathstring = pathstring
        self.driver = driver
        self.return_image = return_image
        self.save_projection = save_projection
        self.save_metadata = save_metadata
        self.no_data_value = no_data_value

    def transform(self, pin):
        driver = gdal.GetDriverByName(self.driver)
        if driver is None:
            # Fail clearly instead of with an AttributeError on None.
            raise Exception('! SaveImage could not find GDAL driver '
                            + str(self.driver) + '.')
        datatype = self._gdal_datatype(pin.data.dtype)
        numbands = pin.data.shape[0]
        dataset = driver.Create(self.pathstring, pin.data.shape[2],
                                pin.data.shape[1], numbands, datatype)
        if dataset is None:
            raise Exception('! SaveImage could not create '
                            + str(self.pathstring) + '.')
        no_data = self._resolved_no_data_value()
        for band in range(numbands):
            bandptr = dataset.GetRasterBand(band + 1)
            bandptr.WriteArray(pin.data[band, :, :])
            if no_data is not None:
                bandptr.SetNoDataValue(no_data)
            bandptr.FlushCache()
        if self.save_projection:
            self._write_georeference(dataset, pin.metadata)
        if self.save_metadata and 'meta' in pin.metadata:
            dataset.SetMetadata(pin.metadata['meta'])
        dataset.FlushCache()
        # Optionally return image
        if self.driver.lower() == 'mem':
            return dataset
        return pin if self.return_image else None

    @staticmethod
    def _gdal_datatype(dtype):
        """Map a numpy dtype to a GDAL type code (bool -> Byte, unmatched -> Float32)."""
        datatype = gdal_array.NumericTypeCodeToGDALTypeCode(dtype)
        if datatype is None:
            if dtype in (bool, np.dtype('bool')):
                datatype = gdal.GDT_Byte
            else:
                warnings.warn('! SaveImage did not find data type match; saving as float.')
                datatype = gdal.GDT_Float32
        return datatype

    def _resolved_no_data_value(self):
        """Return the configured no-data value, with the string 'nan' mapped to math.nan."""
        if isinstance(self.no_data_value, str) \
                and self.no_data_value.lower() == 'nan':
            return math.nan
        return self.no_data_value

    @staticmethod
    def _write_georeference(dataset, metadata):
        """
        Copy either the affine georeference or the GCPs onto 'dataset'.

        Whichever has the longer non-empty projection string wins; a tie goes
        to the affine geotransform. Nothing is written if both are empty or
        absent.
        """
        ref_len = len(metadata['projection_ref']) \
            if 'projection_ref' in metadata else 0
        gcp_len = len(metadata['gcp_projection']) \
            if 'gcp_projection' in metadata else 0
        if ref_len > 0 and ref_len >= gcp_len:
            dataset.SetGeoTransform(metadata['geotransform'])
            dataset.SetProjection(metadata['projection_ref'])
        elif gcp_len > 0:
            dataset.SetGCPs(metadata['gcps'], metadata['gcp_projection'])
class ShowImage(PipeSegment):
    """
    Display an image using matplotlib.

    Parameters:
        show_text: print the image summary.
        show_image: draw the image.
        cmap, vmin, vmax: passed to plt.imshow. When the selection has more
            than one band and both vmin and vmax are set, the values are
            rescaled to [0, 1] and clipped before plotting.
        bands: a band index, or a sequence of band indices, to display; None
            shows all bands.
        caption: text printed before anything else.
        width, height: figure size; if only one is given the figure is square.

    The input image is returned unchanged.
    """
    def __init__(self, show_text=False, show_image=True, cmap='gray',
                 vmin=None, vmax=None, bands=None, caption=None,
                 width=None, height=None):
        super().__init__()
        self.show_text = show_text
        self.show_image = show_image
        self.cmap = cmap
        self.vmin = vmin
        self.vmax = vmax
        self.bands = bands
        self.caption = caption
        self.width = width
        self.height = height

    def transform(self, pin):
        if self.caption is not None:
            print(self.caption)
        if self.show_text:
            print(pin)
        if self.show_image:
            # Select data, and format it for matplotlib
            if self.bands is None:
                image_formatted = pin.data
            else:
                # Always index with a list. A bare int would drop the band
                # axis, so the moveaxis below would transpose the image. A
                # tuple would be read as a multi-axis index.
                if hasattr(self.bands, '__iter__'):
                    bands = list(self.bands)
                else:
                    bands = [self.bands]
                image_formatted = pin.data[bands]
            # matplotlib expects (rows, cols[, channels]); squeeze drops a
            # single selected band.
            pyplot_formatted = np.squeeze(np.moveaxis(image_formatted, 0, -1))
            if np.ndim(pyplot_formatted) == 3 and self.vmin is not None \
                    and self.vmax is not None:
                # imshow ignores vmin/vmax for multi-channel data, so apply
                # the scaling manually.
                pyplot_formatted = np.clip(
                    (pyplot_formatted - self.vmin) / (self.vmax - self.vmin),
                    0., 1.)
            # Select image size
            if self.height is None and self.width is None:
                rc = {}
            elif self.height is None and self.width is not None:
                rc = {'figure.figsize': [self.width, self.width]}
            elif self.height is not None and self.width is None:
                rc = {'figure.figsize': [self.height, self.height]}
            else:
                rc = {'figure.figsize': [self.width, self.height]}
            # Show image
            with plt.rc_context(rc):
                plt.imshow(pyplot_formatted, cmap=self.cmap,
                           vmin=self.vmin, vmax=self.vmax)
                plt.show()
        return pin
class ImageStats(PipeSegment):
    """
    Compute per-band descriptive statistics of an image.

    The statistics are min, max, mean, optional median, standard deviation,
    and counts of positive, zero, negative and NaN pixels. The min, max,
    mean, median and std are NaN-ignoring. They are collected in a pandas
    DataFrame with one row per band.

    The return value depends on the flags: the image, the DataFrame, the
    tuple (image, DataFrame), or None.
    """
    def __init__(self, print_desc=True, print_props=True, return_image=True,
                 return_props=False, median=True, caption=None):
        super().__init__()
        self.print_desc = print_desc
        self.print_props = print_props
        self.return_image = return_image
        self.return_props = return_props
        self.median = median
        self.caption = caption

    def transform(self, pin):
        if self.caption is not None:
            print(self.caption)
        if self.print_desc:
            print(pin)
            print()
        pixels = pin.data
        per_band = (1, 2)
        props = pd.DataFrame({
            'min': np.nanmin(pixels, per_band),
            'max': np.nanmax(pixels, per_band),
            'mean': np.nanmean(pixels, per_band),
            'std': np.nanstd(pixels, per_band),
            # NaN is replaced by a value of the opposite sign so that it is
            # never counted as positive (or negative, below).
            'pos': np.count_nonzero(np.nan_to_num(pixels, nan=-1.) > 0, per_band),
            'zero': np.count_nonzero(pixels == 0, per_band),
            'neg': np.count_nonzero(np.nan_to_num(pixels, nan=1.) < 0, per_band),
            'nan': np.count_nonzero(np.isnan(pixels), per_band),
        })
        if self.median:
            # Place the median column directly after the mean.
            props.insert(3, 'median', np.nanmedian(pixels, per_band))
        if self.print_props:
            print(props)
            print()
        wanted = [value for value, flag in ((pin, self.return_image),
                                            (props, self.return_props))
                  if flag]
        if not wanted:
            return None
        if len(wanted) == 1:
            return wanted[0]
        return tuple(wanted)
class MergeToStack(PipeSegment):
    """
    Concatenate the bands of several equal-sized images into one image.

    The output takes its name and metadata (the same dict object, not a copy)
    from the image at index 'master'.
    """
    def __init__(self, master=0):
        super().__init__()
        self.master = master

    def transform(self, pin):
        stacked = np.concatenate([item.data for item in pin], axis=0)
        reference = pin[self.master]
        pout = Image(None, reference.name, reference.metadata)
        # The stacked array is assigned directly rather than via set_data.
        pout.data = stacked
        return pout
class MergeToSum(PipeSegment):
    """
    Combine an iterable of images by summing the corresponding bands.
    Assumes that images are of equal size and have equal numbers of bands.

    'pin' must support len() and indexing. The sum is accumulated in place
    into a copy of the master image's array, so the result keeps the
    master's dtype. Numpy raises if another image's dtype cannot be cast to
    it under its same-kind rule. The output takes the master's name and
    metadata. 'master' may be negative, as in normal Python indexing.
    """
    def __init__(self, master=0):
        super().__init__()
        self.master = master

    def transform(self, pin):
        # Normalize the master index. A negative value such as -1 otherwise
        # never equals a loop index, so the master would be added twice.
        # Out-of-range values raise IndexError, as pin[master] would.
        master = range(len(pin))[self.master]
        total = pin[master].data.copy()
        for i, item in enumerate(pin):
            if i != master:
                total += item.data
        return Image(total, pin[master].name, pin[master].metadata)
class MergeToProduct(PipeSegment):
    """
    Combine an iterable of images by multiplying the corresponding bands.
    Assumes that images are of equal size and have equal numbers of bands.

    'pin' must support len() and indexing. The product is accumulated in
    place into a copy of the master image's array, so the result keeps the
    master's dtype. Numpy raises if another image's dtype cannot be cast to
    it under its same-kind rule. The output takes the master's name and
    metadata. 'master' may be negative, as in normal Python indexing.
    """
    def __init__(self, master=0):
        super().__init__()
        self.master = master

    def transform(self, pin):
        # Normalize the master index. A negative value such as -1 otherwise
        # never equals a loop index, so the master would be multiplied in
        # twice. Out-of-range values raise IndexError, as pin[master] would.
        master = range(len(pin))[self.master]
        product = pin[master].data.copy()
        for i, item in enumerate(pin):
            if i != master:
                product *= item.data
        return Image(product, pin[master].name, pin[master].metadata)
class SelectItem(PipeSegment):
    """
    Pick a single element out of an indexable input.

    Handy for passing one item of a list into, or out of, a custom class.
    'index' follows normal Python indexing.
    """
    def __init__(self, index=0):
        super().__init__()
        self.index = index

    def transform(self, pin):
        chosen = pin[self.index]
        return chosen
class SelectBands(PipeSegment):
    """
    Reorganize the bands in an image.  This class can be used to select,
    delete, duplicate, or reorder bands.

    'bands' is a band index or an iterable of band indices; None (the
    default) selects band 0. The output shares the input's name and
    metadata.
    """
    def __init__(self, bands=None):
        super().__init__()
        # None sentinel instead of a mutable `[0]` default, which would be
        # one list object shared by every default-constructed instance.
        if bands is None:
            bands = [0]
        elif not hasattr(bands, '__iter__'):
            # Wrap a single index so indexing keeps the band axis.
            bands = [bands]
        self.bands = bands

    def transform(self, pin):
        return Image(pin.data[self.bands, :, :], pin.name, pin.metadata)
class Bounds(PipeSegment):
    """
    Output the boundary coordinates [xmin, ymin, xmax, ymax] of an image.

    The result is the axis-aligned envelope of the four image corners
    computed from the affine geotransform. It therefore also covers images
    that are rotated, sheared, or not north-up. For plain north-up images it
    equals [left, bottom, right, top].

    Note: Requires the image to have an affine geotransform, not GCPs.
    """
    def transform(self, pin):
        gt = pin.metadata['geotransform']
        numrows = pin.data.shape[1]
        numcols = pin.data.shape[2]
        xs = []
        ys = []
        # Pixel-edge coordinates of the four corners, mapped through
        # Xgeo = gt[0] + col*gt[1] + row*gt[2], Ygeo = gt[3] + col*gt[4] + row*gt[5].
        for col, row in ((0, 0), (numcols, 0), (0, numrows), (numcols, numrows)):
            xs.append(gt[0] + gt[1] * col + gt[2] * row)
            ys.append(gt[3] + gt[4] * col + gt[5] * row)
        return [min(xs), min(ys), max(xs), max(ys)]
class Scale(PipeSegment):
    """
    Multiply every pixel by a constant 'factor'.

    The output is a new Image sharing the input's name and metadata.
    """
    def __init__(self, factor=1.):
        super().__init__()
        self.factor = factor

    def transform(self, pin):
        scaled = self.factor * pin.data
        return Image(scaled, pin.name, pin.metadata)
class Crop(PipeSegment):
    """
    Crop image based on either pixel coordinates or georeferenced
    coordinates.  'bounds' is a list specifying the edges:
    [left, bottom, right, top]

    In pixel mode ('pixel', 'p' or 0) the bounds are inclusive column/row
    indices, with 'top' the smaller row index. In geo mode ('geo', 'g' or 1)
    they are georeferenced coordinates. The crop is done by writing the
    image to a temporary in-memory GeoTIFF and calling gdal.Translate. The
    temporary files are always removed, even if a step fails.

    Raises:
        Exception: if 'mode' is not recognized.
    """
    def __init__(self, bounds, mode='pixel'):
        super().__init__()
        self.bounds = bounds
        self.mode = mode

    def transform(self, pin):
        row_min = self.bounds[3]
        row_max = self.bounds[1]
        col_min = self.bounds[0]
        col_max = self.bounds[2]
        if self.mode in ['pixel', 'p', 0]:
            # srcWin is [xoff, yoff, xsize, ysize]; +1 makes bounds inclusive.
            srcWin = [col_min, row_min,
                      col_max - col_min + 1, row_max - row_min + 1]
            projWin = None
        elif self.mode in ['geo', 'g', 1]:
            srcWin = None
            # projWin is [ulx, uly, lrx, lry].
            projWin = [col_min, row_min, col_max, row_max]
        else:
            raise Exception('! Invalid mode in Crop')
        drivername = 'GTiff'
        srcpath = '/vsimem/crop_input_' + str(uuid.uuid4()) + '.tif'
        dstpath = '/vsimem/crop_output_' + str(uuid.uuid4()) + '.tif'
        try:
            (pin * SaveImage(srcpath, driver=drivername))()
            gdal.Translate(dstpath, srcpath, srcWin=srcWin, projWin=projWin)
            pout = LoadImage(dstpath)()
        finally:
            # Free the in-memory temporaries even when a step above raised.
            # VSIStatL returns None for paths that were never created.
            driver = gdal.GetDriverByName(drivername)
            for path in (srcpath, dstpath):
                if gdal.VSIStatL(path) is not None:
                    driver.Delete(path)
        pout.name = pin.name
        # Boolean masks are stored as bytes on disk; restore the dtype.
        if pin.data.dtype in (bool, np.dtype('bool')):
            pout.data = pout.data.astype('bool')
        return pout
class CropVariable(Crop):
    """
    Variant of 'Crop' whose window is supplied at run time.

    The input is a pair (image, bounds), with 'bounds' in the same
    [left, bottom, right, top] form Crop uses. The bounds are stored on the
    instance before cropping.
    """
    def __init__(self, mode='pixel'):
        # Skip Crop.__init__, which requires bounds up front.
        PipeSegment.__init__(self)
        self.mode = mode

    def transform(self, pin):
        image, window = pin[0], pin[1]
        self.bounds = window
        return super().transform(image)
class Resize(PipeSegment):
    """
    Resize an image to the requested number of pixels.

    The image is written to a temporary in-memory GeoTIFF and resized with
    gdal.Translate (width=cols, height=rows). The temporary files are always
    removed, even if a step fails.
    """
    def __init__(self, rows, cols):
        super().__init__()
        self.rows = rows
        self.cols = cols

    def transform(self, pin):
        return self.resize(pin, self.rows, self.cols)

    def resize(self, pin, rows, cols):
        """Return a copy of 'pin' resampled to 'rows' x 'cols' pixels."""
        drivername = 'GTiff'
        srcpath = '/vsimem/resize_input_' + str(uuid.uuid4()) + '.tif'
        dstpath = '/vsimem/resize_output_' + str(uuid.uuid4()) + '.tif'
        try:
            (pin * SaveImage(srcpath, driver=drivername))()
            gdal.Translate(dstpath, srcpath, width=cols, height=rows)
            pout = LoadImage(dstpath)()
        finally:
            # Free the in-memory temporaries even when a step above raised.
            # VSIStatL returns None for paths that were never created.
            driver = gdal.GetDriverByName(drivername)
            for path in (srcpath, dstpath):
                if gdal.VSIStatL(path) is not None:
                    driver.Delete(path)
        pout.name = pin.name
        # Boolean masks are stored as bytes on disk; restore the dtype.
        if pin.data.dtype in (bool, np.dtype('bool')):
            pout.data = pout.data.astype('bool')
        return pout
class GetMask(PipeSegment):
    """
    Build a single-band Boolean mask from one band of an image.

    With flag='nan' (the default) the mask is True wherever the band is not
    NaN. For any other flag it is True wherever the band equals that flag.
    NOTE(review): the two cases have opposite polarity relative to the flag;
    confirm that this is intended.
    """
    def __init__(self, band=0, flag='nan'):
        super().__init__()
        self.band = band
        self.flag = flag

    def transform(self, pin):
        layer = pin.data[self.band]
        if self.flag == 'nan':
            plane = ~np.isnan(layer)
        else:
            plane = layer == self.flag
        return Image(plane[np.newaxis], pin.name, pin.metadata)
class SetMask(PipeSegment):
    """
    Given an image and a mask, apply the mask to the image.  More
    specifically, set the image's pixel value to NaN (or other specified
    value) for every pixel where the mask value is False.

    The input is a pair (image, mask), or (mask, image) when
    'reverse_order' is True. The mask must have a single band. Non-boolean
    masks (e.g. 0/1 bytes read back from disk) are interpreted by
    truthiness. If 'band' is given, only that band (or those bands) is
    modified. The input image is not modified.
    """
    def __init__(self, flag=math.nan, band=None, reverse_order=False):
        super().__init__()
        self.flag = flag
        self.band = band
        self.reverse_order = reverse_order

    def transform(self, pin):
        if not self.reverse_order:
            img, mask = pin[0], pin[1]
        else:
            mask, img = pin[0], pin[1]
        # Collapse the single mask band to (rows, cols). Reshaping to the
        # last two axes, unlike squeeze, keeps 1-row or 1-column images 2-D.
        maskplane = np.reshape(mask.data, np.shape(mask.data)[-2:])
        # logical_not rather than bitwise invert: np.invert on a uint8 0/1
        # mask yields 255/254, which numpy would treat as integer indices.
        mark = np.logical_not(maskplane)
        data = np.copy(img.data)
        if self.band is None:
            data[:, mark] = self.flag
        else:
            data[self.band, mark] = self.flag
        return Image(data, img.name, img.metadata)
class InvertMask(PipeSegment):
    """
    Sets all True values in a mask to False and vice versa.

    Non-boolean masks are interpreted by truthiness, and the result is always
    boolean. The output shares the input's name and metadata.
    """
    def transform(self, pin):
        # logical_not rather than bitwise invert: np.invert on an integer
        # mask (e.g. 0/1 bytes loaded from disk) gives 255/254, both truthy.
        return Image(np.logical_not(pin.data), pin.name, pin.metadata)