Source code for solaris.utils.core
import os
import numpy as np
from shapely.wkt import loads
from shapely.geometry import Point
from shapely.geometry.base import BaseGeometry
import pandas as pd
import geopandas as gpd
import pyproj
import rasterio
from distutils.version import LooseVersion
import skimage
from fiona._err import CPLE_OpenFailedError
from fiona.errors import DriverError
from warnings import warn
def _check_rasterio_im_load(im):
    """Return `im` as an open rasterio dataset, opening it if necessary.

    Arguments
    ---------
    im : str or :class:`rasterio.DatasetReader`
        Path to an image file, or a dataset that has already been opened.

    Returns
    -------
    :class:`rasterio.DatasetReader`
        The result of ``rasterio.open(im)`` when `im` is a string; `im`
        itself when it is already a :class:`rasterio.DatasetReader`.

    Raises
    ------
    ValueError
        If `im` is neither a string nor a :class:`rasterio.DatasetReader`.
    """
    # a string is treated as a path and handed to rasterio to open
    if isinstance(im, str):
        return rasterio.open(im)
    # an already-open dataset is passed straight through
    if isinstance(im, rasterio.DatasetReader):
        return im
    message = "{} is not an accepted image format for rasterio.".format(im)
    raise ValueError(message)
def _check_skimage_im_load(im):
    """Return `im` as a numpy array, reading it with scikit-image if needed.

    Arguments
    ---------
    im : str or :class:`numpy.ndarray`
        Path to an image file, or an image array that is already loaded.

    Returns
    -------
    :class:`numpy.ndarray`
        The result of ``skimage.io.imread(im)`` when `im` is a string; `im`
        itself (not a copy) when it is already an array.

    Raises
    ------
    ValueError
        If `im` is neither a string nor a :class:`numpy.ndarray`.
    """
    if isinstance(im, str):
        # Import the submodule explicitly: a bare ``import skimage`` does not
        # guarantee that ``skimage.io`` is bound as an attribute of the
        # package, in which case ``skimage.io.imread`` raises AttributeError.
        from skimage import io as skimage_io
        return skimage_io.imread(im)
    if isinstance(im, np.ndarray):
        return im
    raise ValueError(
        "{} is not an accepted image format for scikit-image.".format(im))
def _check_df_load(df):
    """Return `df` as a DataFrame, loading it from disk if a path is given.

    Arguments
    ---------
    df : str or :class:`pandas.DataFrame`
        Path to a file, or a DataFrame that is already loaded.

    Returns
    -------
    :class:`pandas.DataFrame`
        `df` itself when it is already a DataFrame. For a string path whose
        lower-cased form ends in ``'json'`` the result of
        ``_check_gdf_load(df)``; for any other string path the result of
        ``pandas.read_csv(df)``.

    Raises
    ------
    ValueError
        If `df` is neither a string nor a :class:`pandas.DataFrame`.
    """
    if isinstance(df, str):
        # json paths go through the geo loader; everything else is read as
        # a CSV
        looks_like_json = df.lower().endswith('json')
        return _check_gdf_load(df) if looks_like_json else pd.read_csv(df)
    if isinstance(df, pd.DataFrame):
        return df
    raise ValueError(f"{df} is not an accepted DataFrame format.")
def _check_gdf_load(gdf):
    """Return `gdf` as a GeoDataFrame, loading it from disk if a path is given.

    Arguments
    ---------
    gdf : str or :class:`geopandas.GeoDataFrame`
        Path to a vector file, or a GeoDataFrame that is already loaded.

    Returns
    -------
    :class:`geopandas.GeoDataFrame`
        `gdf` itself when it is already a GeoDataFrame; otherwise the result
        of ``geopandas.read_file``. For non-CSV paths, an empty GeoDataFrame
        is returned (with a warning) when the read raises ``DriverError`` or
        ``CPLE_OpenFailedError``.

    Raises
    ------
    ValueError
        If `gdf` is neither a string nor a :class:`geopandas.GeoDataFrame`.
    """
    if not isinstance(gdf, str):
        if isinstance(gdf, gpd.GeoDataFrame):
            return gdf
        raise ValueError(f"{gdf} is not an accepted GeoDataFrame format.")

    # as of geopandas 0.6.2, using the OGR CSV driver requires some add'nal
    # kwargs to create a valid geodataframe with a geometry column. see
    # https://github.com/geopandas/geopandas/issues/1234
    if gdf.lower().endswith('csv'):
        # note: this read is outside the try block below, so read errors
        # for CSV paths propagate to the caller
        csv_driver_kwargs = {"GEOM_POSSIBLE_NAMES": "geometry",
                             "KEEP_GEOM_COLUMNS": "NO"}
        return gpd.read_file(gdf, **csv_driver_kwargs)

    try:
        loaded = gpd.read_file(gdf)
    except (DriverError, CPLE_OpenFailedError):
        warn(f"GeoDataFrame couldn't be loaded: either {gdf} isn't a valid"
             " path or it isn't a valid vector file. Returning an empty"
             " GeoDataFrame.")
        loaded = gpd.GeoDataFrame()
    return loaded
def _check_geom(geom):
    """Coerce `geom` to a shapely geometry object.

    Arguments
    ---------
    geom : :class:`shapely.geometry.base.BaseGeometry`, str, or list
        A shapely geometry (returned unchanged), a string (assumed to be WKT
        and parsed with ``shapely.wkt.loads``), or a two-element list of
        coordinates (converted to a :class:`shapely.geometry.Point`).

    Returns
    -------
    :class:`shapely.geometry.base.BaseGeometry` or None
        The geometry, or ``None`` when `geom` matches none of the accepted
        forms above.
    """
    if isinstance(geom, BaseGeometry):
        return geom
    if isinstance(geom, str):  # assume it's a wkt
        return loads(geom)
    # only a list of exactly two items is treated as point coordinates
    is_coordinate_pair = isinstance(geom, list) and len(geom) == 2
    return Point(geom) if is_coordinate_pair else None
def _check_crs(input_crs, return_rasterio=False):
    """Convert CRS to the ``pyproj.CRS`` object passed by ``solaris``.

    Arguments
    ---------
    input_crs : :class:`pyproj.CRS`, None, or any input ``pyproj.CRS()`` accepts
        The coordinate reference system to normalize.
    return_rasterio : bool, optional
        If ``True``, return a :class:`rasterio.crs.CRS` built from the WKT of
        the pyproj CRS instead of the pyproj object. Defaults to ``False``.

    Returns
    -------
    :class:`pyproj.CRS`, :class:`rasterio.crs.CRS`, or None
        The normalized CRS. ``None`` is passed through unchanged for either
        value of `return_rasterio`.
    """
    # Nothing to convert. Returning here also avoids calling ``.to_wkt()`` on
    # ``None`` when ``return_rasterio=True``.
    if input_crs is None:
        return None

    if isinstance(input_crs, pyproj.CRS):
        out_crs = input_crs
    else:
        out_crs = pyproj.CRS(input_crs)

    if return_rasterio:
        # with GDAL older than 3.0.0 the WKT is exported as "WKT1_GDAL"
        if LooseVersion(rasterio.__gdal_version__) >= LooseVersion("3.0.0"):
            wkt = out_crs.to_wkt()
        else:
            wkt = out_crs.to_wkt("WKT1_GDAL")
        out_crs = rasterio.crs.CRS.from_wkt(wkt)
    return out_crs
def get_data_paths(path, infer=False):
    """Load an image/label reference CSV into a pandas DataFrame.

    The CSV is expected to contain an ``'image'`` column holding paths to
    images and, unless it is used for inference, a ``'label'`` column holding
    the path to the label file that corresponds to each image. Any other
    columns in the file are dropped.

    Arguments
    ---------
    path : str
        Path to a .CSV-formatted reference file defining the location of
        training, validation, or inference data.
    infer : bool, optional
        If ``True``, only the ``'image'`` column is returned, even if a
        ``'label'`` column is present. Defaults to ``False``.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        A DataFrame with the ``'image'`` and ``'label'`` columns from the CSV
        at `path`, or only the ``'image'`` column when ``infer=True``.
    """
    # labels are not needed for inference; everything else is extraneous
    wanted_columns = ['image'] if infer else ['image', 'label']
    reference = pd.read_csv(path)
    return reference[wanted_columns]
def get_files_recursively(path, traverse_subdirs=False, extension='.tif'):
    """Get paths under `path` whose names end with `extension`.

    The extension comparison is case-insensitive in both modes, so
    ``'.tif'`` matches ``a.tif`` and ``B.TIF`` alike.

    Arguments
    ---------
    path : str
        Directory to search.
    traverse_subdirs : bool, optional
        If ``True``, walk `path` and all of its subdirectories with
        ``os.walk``. If ``False`` (the default), only the entries directly
        inside `path` (as listed by ``os.listdir``) are considered.
    extension : str, optional
        Filename suffix to match. Defaults to ``'.tif'``.

    Returns
    -------
    list of str
        Each matching name joined to the directory it was found in.
    """
    # normalize once so that both the file names and the requested extension
    # are compared in lower case
    suffix = extension.lower()

    if traverse_subdirs:
        return [os.path.join(dirpath, fname)
                for dirpath, _, fnames in os.walk(path)
                for fname in fnames
                if fname.lower().endswith(suffix)]

    return [os.path.join(path, fname)
            for fname in os.listdir(path)
            if fname.lower().endswith(suffix)]