{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Imports**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data download & file management...\n",
    "import requests\n",
    "import os\n",
    "import glob\n",
    "import zipfile\n",
    "import json\n",
    "\n",
    "# data manipulation...\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.ndimage import gaussian_filter\n",
    "from math import ceil\n",
    "\n",
    "# geospatial vector...\n",
    "import geopandas as gpd\n",
    "import fiona\n",
    "from shapely.geometry import box\n",
    "\n",
    "# geospatial image...\n",
    "import rasterio\n",
    "from rasterio.plot import show\n",
    "from rasterio.windows import from_bounds\n",
    "from rasterio.warp import calculate_default_transform, reproject, Resampling\n",
    "from rasterio.merge import merge\n",
    "from rasterio.mask import mask\n",
    "from rasterio.transform import from_origin\n",
    "from rasterio.features import rasterize\n",
    "\n",
    "# plotting...\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Data Download Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_zip(url, output_dir):\n",
    "    \"\"\"\n",
    "    Function to download zip file, extract contents in the specified directory, and delete the zip file.\n",
    "\n",
    "    Failures are reported with print() rather than raised, so a caller looping over many URLs keeps going after a bad download.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    url : str\n",
    "        Download URL for zip file.\n",
    "    output_dir : str\n",
    "        Directory path to save zip file and extract contents.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    zip_path = os.path.join(output_dir, 'download.zip')\n",
    "\n",
    "    # only delete the temporary archive if this call is the one that created it\n",
    "    archive_written = False\n",
    "\n",
    "    try:\n",
    "        response = requests.get(url)\n",
    "        response.raise_for_status()\n",
    "        if response.status_code == 200:\n",
    "            archive_written = True\n",
    "            with open(zip_path, 'wb') as zip_file:\n",
    "                zip_file.write(response.content)\n",
    "            with zipfile.ZipFile(zip_path, 'r') as archive:\n",
    "                archive.extractall(output_dir)\n",
    "        else:\n",
    "            print('Reponse code not 200 for downloading .zip...')\n",
    "    # catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still stop the run\n",
    "    except Exception as error:\n",
    "        print(f\"Error downloading .zip...\\n{url}\\n{error}\")\n",
    "    finally:\n",
    "        # remove the temporary archive even when extraction fails part-way\n",
    "        if archive_written and os.path.isfile(zip_path):\n",
    "            os.remove(zip_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_tif(url, output_path):\n",
    "    \"\"\"\n",
    "    Function to download TIFF file from a specified URL.\n",
    "\n",
    "    Failures are reported with print() rather than raised, so a caller looping over many URLs keeps going after a bad download.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    url : str\n",
    "        Download URL for GeoTIFF file.\n",
    "    output_path : str\n",
    "        Path to save GeoTIFF.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    started_writing = False\n",
    "\n",
    "    try:\n",
    "        response = requests.get(url)\n",
    "        response.raise_for_status()\n",
    "        if response.status_code == 200:\n",
    "            started_writing = True\n",
    "            with open(output_path, 'wb') as tif:\n",
    "                tif.write(response.content)\n",
    "        else:\n",
    "            print('Reponse code for URL not 200...')\n",
    "    # catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still stop the run\n",
    "    except Exception as error:\n",
    "        print(f\"Error connecting to URL...\\n{url}\\n{error}\")\n",
    "        # do not leave a truncated file behind when the write itself failed\n",
    "        if started_writing and os.path.isfile(output_path):\n",
    "            os.remove(output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_data_tiles(index_path, id_field, url_field, output_dir):\n",
    "    \"\"\"\n",
    "    Function to read KyFromAbove Tile Index GeoJSON, download relevant GeoTIFFs using the download URLs from a specified attribute, and then save each GeoTIFF to the specified output directory.\n",
    "\n",
    "    A tile is skipped when the output directory already holds a file whose name contains the tile ID. The download type is chosen from the end of the URL (case-insensitive): 'tif'/'tiff' is saved as <tile_id>.tif and 'zip' is downloaded and extracted.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    index_path : str\n",
    "        Path to GeoJSON.\n",
    "    id_field : str\n",
    "        Attribute name of GeoJSON containing unique ID for file naming.\n",
    "    url_field : str\n",
    "        Attribute name of GeoJSON containing the download URL.\n",
    "    output_dir : str\n",
    "        Directory where TIFF(s) will be downloaded. Created if it does not exist.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    gdf = gpd.read_file(index_path)\n",
    "\n",
    "    # make sure the destination exists before any tile is written\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    for tile_id, url in zip(gdf[id_field], gdf[url_field]):\n",
    "\n",
    "        # tiles without a usable URL (None/NaN) cannot be downloaded\n",
    "        if not isinstance(url, str):\n",
    "            print(f\"No download URL for tile {tile_id}...\")\n",
    "            continue\n",
    "\n",
    "        # skip tiles that already have a matching file in the output directory\n",
    "        if len(glob.glob(f\"{output_dir}/*{tile_id}*\")) > 0:\n",
    "            continue\n",
    "\n",
    "        # compare in lower case so '.TIF' and '.ZIP' links are recognized too\n",
    "        url_lower = url.lower()\n",
    "\n",
    "        if url_lower.endswith(('tif', 'tiff')):\n",
    "            output_path = f\"{output_dir}/{tile_id}.tif\"\n",
    "            download_tif(url, output_path)\n",
    "\n",
    "        elif url_lower.endswith('zip'):\n",
    "            download_zip(url, output_dir)\n",
    "\n",
    "        else:\n",
    "            print('Download is not .tif or .zip...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_aoi_index_polygons(input_path, boundary_path, output_dir):\n",
    "    \"\"\"\n",
    "    Function to select the tiles of the DEM and aerial imagery index layers that intersect a boundary and save each selection as a GeoJSON.\n",
    "\n",
    "    Layers are matched by name: a layer whose name contains 'dem' is written to dem_index.geojson, otherwise a layer whose name contains 'aerial' is written to aerial_index.geojson. An existing output file is never overwritten.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_path : str\n",
    "        Path to the multi-layer tile index (e.g. geodatabase) readable by fiona/geopandas.\n",
    "    boundary_path : str\n",
    "        Path to the boundary polygon file used to select tiles.\n",
    "    output_dir : str\n",
    "        Directory where the index GeoJSON file(s) will be written.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    # read boundary into geodataframe\n",
    "    boundary = gpd.read_file(boundary_path)\n",
    "\n",
    "    # keyword looked for in the layer name -> output file name (checked in this order)\n",
    "    outputs = {'dem': 'dem_index.geojson', 'aerial': 'aerial_index.geojson'}\n",
    "\n",
    "    # iterate through layers of the index\n",
    "    for layer in fiona.listlayers(input_path):\n",
    "\n",
    "        # first matching keyword wins; other layers are ignored\n",
    "        layer_lower = layer.lower()\n",
    "        file_name = next((name for key, name in outputs.items() if key in layer_lower), None)\n",
    "        if file_name is None:\n",
    "            continue\n",
    "\n",
    "        # nothing to do when the selection was already written; checked before the costly read\n",
    "        output_path = f\"{output_dir}/{file_name}\"\n",
    "        if os.path.isfile(output_path):\n",
    "            continue\n",
    "\n",
    "        tile_index = gpd.read_file(input_path, layer=layer)\n",
    "\n",
    "        # a spatial join is only meaningful when both layers share a CRS\n",
    "        aoi = boundary\n",
    "        if tile_index.crs is not None and boundary.crs is not None and tile_index.crs != boundary.crs:\n",
    "            aoi = boundary.to_crs(tile_index.crs)\n",
    "\n",
    "        # keep only the tiles that intersect the boundary\n",
    "        intersect = gpd.sjoin(left_df=tile_index, right_df=aoi, how='inner')\n",
    "\n",
    "        # write selected tiles to GeoJSON\n",
    "        intersect.to_file(output_path, driver='GeoJSON')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **GIS Vector Manipulation Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gis_to_image(input_path, output_path, output_resolution, attribute, mapper=None):\n",
    "    \"\"\"\n",
    "    Function to convert vector geospatial file to GeoTIFF image file with a given resolution and categorical attribute. Output GeoTIFF file is of float32 dtype with NaN representing nodata values. The category-to-integer mapping is saved as a JSON file next to the GeoTIFF (same name, .json extension).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_path : str\n",
    "        Path to input GeoJSON or Shapefile.\n",
    "    output_path : str\n",
    "        Path for output GeoTIFF.\n",
    "    output_resolution : int\n",
    "        Resolution of GeoTIFF in native spatial units of input GIS file.\n",
    "    attribute : str\n",
    "        Name of categorical attribute in GIS file for assigning pixel values.\n",
    "    mapper : dict (optional)\n",
    "        Mapping of attribute category to pixel value. Categories missing from the mapping become NaN. Default is the built-in mapping {'af1': 1, 'Qal': 2, 'Qaf': 3, 'Qat': 4, 'Qc': 5, 'Qca': 6, 'Qr': 7}.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    if mapper is None:\n",
    "        mapper = {'af1': 1, 'Qal': 2, 'Qaf': 3, 'Qat': 4, 'Qc': 5, 'Qca': 6, 'Qr': 7}\n",
    "\n",
    "    # read input GIS file as geodataframe\n",
    "    gdf = gpd.read_file(input_path)\n",
    "\n",
    "    # if input contains polygons or multipolygons, apply a small positive buffer (0.1 map units)\n",
    "    # NOTE(review): a zero-width buffer is the usual geometry-repair idiom; confirm 0.1 is intended\n",
    "    if gdf.geom_type.isin(['Polygon', 'MultiPolygon']).any():\n",
    "        gdf['geometry'] = gdf['geometry'].buffer(0.1)\n",
    "\n",
    "    # get bounding coordinates & output width and height (using desired resolution)\n",
    "    minx, miny, maxx, maxy = gdf.total_bounds\n",
    "    width = ceil((maxx - minx) / output_resolution)\n",
    "    height = ceil((maxy - miny) / output_resolution)\n",
    "\n",
    "    # calculate transform for output image (origin at the upper-left corner)\n",
    "    transform = from_origin(west=minx, north=maxy, xsize=output_resolution, ysize=output_resolution)\n",
    "\n",
    "    # create new geodataframe attribute of categorical integer assignments\n",
    "    value_column = f\"{attribute}_int\"\n",
    "    gdf[value_column] = gdf[attribute].apply(lambda x: mapper.get(x, np.nan))\n",
    "\n",
    "    # get list of geometries and associated values\n",
    "    shapes = list(zip(gdf.geometry, gdf[value_column]))\n",
    "\n",
    "    # rasterize shapes using output height, width, and transform\n",
    "    output_image = rasterize(shapes = shapes, \n",
    "                             out_shape = (height, width), \n",
    "                             transform = transform, \n",
    "                             all_touched = True, \n",
    "                             fill = np.nan, \n",
    "                             dtype = rasterio.float32)\n",
    "\n",
    "    # create metadata for output image (a layer without CRS yields an image without CRS)\n",
    "    output_meta = {'driver': 'GTiff', \n",
    "                   'height': height, \n",
    "                   'width': width, \n",
    "                   'transform': transform, \n",
    "                   'count': 1, \n",
    "                   'dtype': output_image.dtype, \n",
    "                   'nodata': np.nan, \n",
    "                   'crs': gdf.crs.to_string() if gdf.crs is not None else None}\n",
    "\n",
    "    # write image and metadata to GeoTIFF\n",
    "    with rasterio.open(output_path, 'w', **output_meta) as dst:\n",
    "        dst.write(output_image, 1)\n",
    "\n",
    "    # write mapping dictionary of integers and categories to JSON\n",
    "    # swap only the file extension so the JSON can never land on top of the GeoTIFF\n",
    "    output_json_path = os.path.splitext(output_path)[0] + '.json'\n",
    "    with open(output_json_path, 'w') as file:\n",
    "        json.dump(mapper, file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clip_gis_to_boundary(input_path, boundary_path, output_path, gdb_layer=None):\n",
    "    \"\"\"\n",
    "    Function to clip vector GIS features to an area of interest polygon and write the clipped result to a new GeoJSON file.\n",
    "\n",
    "    Multi-part features are split into single parts before clipping, and the input is reprojected to the CRS of the boundary when the two differ.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_path : str\n",
    "        Path to the GIS file to clip. For a geodatabase (.gdb) the layer is selected with gdb_layer.\n",
    "    boundary_path : str\n",
    "        Path to the area of interest polygon file.\n",
    "    output_path : str\n",
    "        Path of the GeoJSON file to write.\n",
    "    gdb_layer : str (optional)\n",
    "        Name of the geodatabase layer to clip. Default is None.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    # a layer name is only passed to the reader when one was supplied\n",
    "    read_options = {'layer': gdb_layer} if gdb_layer else {}\n",
    "    features = gpd.read_file(input_path, **read_options)\n",
    "\n",
    "    # split multi-part geometries into individual rows\n",
    "    features = features.explode(ignore_index=True, index_parts=False)\n",
    "\n",
    "    aoi = gpd.read_file(boundary_path)\n",
    "\n",
    "    # clipping needs both layers in the same CRS; the boundary CRS is the one kept\n",
    "    if features.crs != aoi.crs:\n",
    "        features = features.to_crs(aoi.crs)\n",
    "\n",
    "    clipped = gpd.clip(features, mask=aoi)\n",
    "    clipped.to_file(output_path, driver='GeoJSON')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multiple_gis_to_reference_image(input_paths, reference_path, output_path, binary=True):\n",
    "    \"\"\"\n",
    "    Function to combine multiple geospatial vector GIS features into a new GeoTIFF image aligned with a reference image. In the case of overlapping features, priority for pixel values in the final image will be given to the last feature. Background space will be given a value of 0. Pixels that are invalid in the reference image are set to nodata.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_paths : list or tuple\n",
    "        List of path to vector GIS features in GeoJSON(s) and/or Shapefile(s).\n",
    "    reference_path : str\n",
    "        Path to reference GeoTIFF image.\n",
    "    output_path : str\n",
    "        Path to output GeoTIFF image.\n",
    "    binary : bool (optional)\n",
    "        If True (default), every feature is given a value of 1. If False, features are given sequential integers (1, 2, ...) in the order of input_paths and the value mapping is saved as a JSON file next to the output image (same name, .json extension).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    with rasterio.open(reference_path) as src:\n",
    "\n",
    "        shapes_all = []\n",
    "        features = ['background']\n",
    "\n",
    "        for val, path in enumerate(input_paths, start=1):\n",
    "\n",
    "            # feature name is the file name without directory or extension\n",
    "            feature = os.path.splitext(os.path.basename(path))[0]\n",
    "            features.append(feature)\n",
    "\n",
    "            gdf = gpd.read_file(path)\n",
    "\n",
    "            # geometries must be in the CRS of the reference grid\n",
    "            if gdf.crs != src.crs:\n",
    "                gdf = gdf.to_crs(src.crs)\n",
    "\n",
    "            burn_value = 1 if binary else val\n",
    "            shapes_all.extend((geom, burn_value) for geom in gdf.geometry)\n",
    "\n",
    "        if shapes_all:\n",
    "            output_image = rasterize(shapes=shapes_all, \n",
    "                                     out_shape=(src.height, src.width), \n",
    "                                     transform=src.transform, \n",
    "                                     fill=0, \n",
    "                                     all_touched=True, \n",
    "                                     dtype=rasterio.float32)\n",
    "        else:\n",
    "            # rasterize() rejects an empty shape list; nothing to burn means all background\n",
    "            output_image = np.zeros((src.height, src.width), dtype=rasterio.float32)\n",
    "\n",
    "        # True where the reference image holds valid data\n",
    "        valid = src.dataset_mask().astype(bool)\n",
    "\n",
    "        # a reference without a nodata value only needs one (NaN) when it has invalid pixels\n",
    "        nodata = src.nodata\n",
    "        if nodata is None and not valid.all():\n",
    "            nodata = np.nan\n",
    "        if nodata is not None:\n",
    "            output_image = np.where(valid, output_image, nodata)\n",
    "\n",
    "        output_meta = src.meta.copy()\n",
    "        output_meta.update({'driver': 'GTiff', \n",
    "                            'count': 1, \n",
    "                            'dtype': rasterio.float32, \n",
    "                            'nodata': nodata})\n",
    "\n",
    "        with rasterio.open(output_path, 'w', **output_meta) as dst:\n",
    "            dst.write(output_image.astype(rasterio.float32), 1)\n",
    "\n",
    "        if not binary:\n",
    "            mapper = {name: value for value, name in enumerate(features)}\n",
    "            # swap only the file extension so the JSON can never land on top of the GeoTIFF\n",
    "            output_json_path = os.path.splitext(output_path)[0] + '.json'\n",
    "            with open(output_json_path, 'w') as meta:\n",
    "                json.dump(mapper, meta, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_image_patches(reference_path, patch_size, patch_overlap, boundary_path, output_path, name_prefix=None):\n",
    "    \"\"\"\n",
    "    Function to create geospatial polygons that represent square image patch locations saved as a GeoJSON. The size of the image patches (assumed to be square) and the proportion of overlap between adjacent patches is specified. Each patch will have a unique id created from the patch_size, patch_overlap (as a whole percentage), and a unique number. Only patches that lie completely within a boundary polygon are kept.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    reference_path : str\n",
    "        Path to a reference GeoTIFF image that represents the area where patches will be created.\n",
    "    patch_size : int or float\n",
    "        Size of the square patch in pixels. Must be greater than 0.\n",
    "    patch_overlap : float\n",
    "        Proportion of overlap between adjacent patches. Must be less than 1.\n",
    "    boundary_path : str\n",
    "        Path to area of interest boundary GeoJSON file (should be aligned with boundaries of reference_path image) to ensure patch polygons intersect.\n",
    "    output_path : str\n",
    "        Path for output patch polygon GeoJSON file.\n",
    "    name_prefix : str (optional)\n",
    "        Text placed in front of every patch id. Default is None (no prefix).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If patch_size is not greater than 0 or patch_overlap is not less than 1.\n",
    "    \"\"\"\n",
    "    # a step of zero (or a negative step) would make the grid loops below run forever\n",
    "    if patch_size <= 0:\n",
    "        raise ValueError('patch_size must be greater than 0')\n",
    "    if patch_overlap >= 1:\n",
    "        raise ValueError('patch_overlap must be less than 1')\n",
    "\n",
    "    boundary = gpd.read_file(boundary_path)\n",
    "\n",
    "    with rasterio.open(reference_path) as src:\n",
    "        bounds = src.bounds\n",
    "        res_x, res_y = src.res\n",
    "        crs = src.crs\n",
    "\n",
    "    # patch footprint and step between patch origins, in map units per axis\n",
    "    patch_width = patch_size * res_x\n",
    "    patch_height = patch_size * res_y\n",
    "    step_x = patch_width * (1 - patch_overlap)\n",
    "    step_y = patch_height * (1 - patch_overlap)\n",
    "\n",
    "    # the containment test is only meaningful when boundary and image share a CRS\n",
    "    if boundary.crs is not None and crs is not None and boundary.crs != crs:\n",
    "        boundary = boundary.to_crs(crs)\n",
    "    boundary_geoms = boundary.geometry\n",
    "\n",
    "    # walk the image extent from the lower-left corner, column by column\n",
    "    patches = []\n",
    "    x = bounds.left\n",
    "    while x < bounds.right:\n",
    "        y = bounds.bottom\n",
    "        while y < bounds.top:\n",
    "            patch = box(x, y, x + patch_width, y + patch_height)\n",
    "\n",
    "            # keep the patch when any boundary polygon fully contains it\n",
    "            if boundary_geoms.contains(patch).any():\n",
    "                patches.append(patch)\n",
    "            y += step_y\n",
    "        x += step_x\n",
    "\n",
    "    gdf = gpd.GeoDataFrame(geometry=patches, crs=crs)\n",
    "\n",
    "    # round before truncating: 0.29 * 100 is 28.999... in floating point\n",
    "    overlap_pct = int(round(patch_overlap * 100))\n",
    "    prefix = f\"{name_prefix}_\" if name_prefix else ''\n",
    "    gdf['patch_id'] = [f\"{prefix}{patch_size}_{overlap_pct}_{i}\" for i in range(1, len(gdf) + 1)]\n",
    "\n",
    "    gdf.to_file(output_path, driver='GeoJSON')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Image Manipulation Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mosaic_image_tiles(tile_paths, output_path, band_number, resample=None):\n",
    "    \"\"\"\n",
    "    Function to create a new single GeoTIFF mosaic from multiple smaller image tiles.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tile_paths : list or tuple\n",
    "        List of paths to GeoTIFF tiles. Metadata (dtype, nodata, CRS) of the mosaic is taken from the first tile.\n",
    "    output_path : str\n",
    "        Path for new output mosaic GeoTIFF.\n",
    "    band_number : int\n",
    "        Band (channel) to mosaic.\n",
    "    resample : int or float (optional)\n",
    "        Resolution of output image (bilinear resampling). If not provided, output image will have the same resolution as input image tiles.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    images = []\n",
    "\n",
    "    try:\n",
    "        # open inside the try block so already-opened tiles are closed if a later one fails\n",
    "        for tile_path in tile_paths:\n",
    "            images.append(rasterio.open(tile_path))\n",
    "\n",
    "        if resample:\n",
    "            mosaic, mosaic_transform = merge(images, indexes=[band_number], res=resample, resampling=Resampling.bilinear)\n",
    "        else:\n",
    "            mosaic, mosaic_transform = merge(images, indexes=[band_number])\n",
    "\n",
    "        # mosaic is shaped (bands, rows, columns)\n",
    "        mosaic_meta = images[0].meta.copy()\n",
    "        mosaic_meta.update({'driver': 'GTiff', \n",
    "                            'height': mosaic.shape[1], \n",
    "                            'width': mosaic.shape[2], \n",
    "                            'transform': mosaic_transform, \n",
    "                            'crs': images[0].crs, \n",
    "                            'count': mosaic.shape[0]})\n",
    "\n",
    "        with rasterio.open(output_path, 'w', **mosaic_meta) as output:\n",
    "            for i in range(mosaic.shape[0]):\n",
    "                output.write(mosaic[i, :, :], i+1)\n",
    "    finally:\n",
    "        # always release the tile handles, also when merging or writing raised\n",
    "        for src in images:\n",
    "            src.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def resample_image(input_path, new_resolution, output_path):\n",
    "    \"\"\"\n",
    "    Function to resample a GeoTIFF image to a new resolution and save as a new GeoTIFF. Every band of the input image is resampled with cubic resampling; the CRS is unchanged.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_path : str\n",
    "        Path to the input GeoTIFF image to be resampled.\n",
    "    new_resolution : int or float\n",
    "        Resolution for the new, resampled image.\n",
    "    output_path : str\n",
    "        Path for the new, resampled GeoTIFF image.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "\n",
    "    with rasterio.open(input_path) as src:\n",
    "\n",
    "        # calculate the new transform and dimensions based on the new resolution\n",
    "        dst_transform, dst_width, dst_height = calculate_default_transform(src.crs,      # source CRS\n",
    "                                                                           src.crs,      # destination CRS\n",
    "                                                                           src.width,    # source width\n",
    "                                                                           src.height,   # source height\n",
    "                                                                           *src.bounds,  # source left, bottom, right, top coordinates \n",
    "                                                                           resolution=new_resolution)     # destination resolution\n",
    "\n",
    "        # create metadata for new resampled image (band count is kept from the source)\n",
    "        dst_meta = src.meta.copy()\n",
    "        dst_meta.update({'driver': 'GTiff', \n",
    "                         'width': dst_width, \n",
    "                         'height': dst_height, \n",
    "                         'transform': dst_transform})\n",
    "\n",
    "        # write new image to file with new transform & metadata & resolution\n",
    "        with rasterio.open(output_path, 'w', **dst_meta) as dst:\n",
    "\n",
    "            # resample every band so the output matches the band count in its metadata\n",
    "            for band_index in range(1, src.count + 1):\n",
    "                reproject(source=rasterio.band(src, band_index), \n",
    "                          destination=rasterio.band(dst, band_index), \n",
    "                          src_transform=src.transform, \n",
    "                          src_crs=src.crs, \n",
    "                          dst_transform=dst_transform, \n",
    "                          dst_crs=src.crs, \n",
    "                          resampling=Resampling.cubic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_image(input_path, sigma):\n",
    "    \"\"\"\n",
    "    Function to apply a Gaussian filter to the first band of an input image. See scipy.ndimage.gaussian_filter for more information regarding filter.\n",
    "\n",
    "    The input file is overwritten with the filtered result. Masked (nodata) pixels are excluded from the smoothing and keep their original value, so nodata does not bleed into neighbouring valid pixels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_path : str\n",
    "        Path to input image. This file is overwritten.\n",
    "    sigma : int, float\n",
    "        Standard deviation for Gaussian function.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "\n",
    "    with rasterio.open(input_path) as src:\n",
    "        data = src.read(1, masked=True)\n",
    "        dst_meta = src.meta.copy()\n",
    "\n",
    "    invalid = np.ma.getmaskarray(data)\n",
    "    raw = np.ma.getdata(data)\n",
    "\n",
    "    if not invalid.any():\n",
    "        # no nodata present: plain Gaussian filter\n",
    "        dst_data = gaussian_filter(input=raw, sigma=sigma)\n",
    "    else:\n",
    "        # normalized convolution: smooth the valid values and the valid-pixel weights\n",
    "        # separately, then divide, so masked pixels contribute nothing to their neighbours\n",
    "        values = np.where(invalid, 0.0, raw).astype('float64')\n",
    "        weights = (~invalid).astype('float64')\n",
    "        smoothed = gaussian_filter(input=values, sigma=sigma)\n",
    "        norm = gaussian_filter(input=weights, sigma=sigma)\n",
    "        averaged = np.divide(smoothed, norm, out=np.zeros_like(smoothed), where=norm > 0)\n",
    "\n",
    "        # integer images are rounded rather than truncated when cast back\n",
    "        if np.issubdtype(raw.dtype, np.integer):\n",
    "            averaged = np.rint(averaged)\n",
    "\n",
    "        # masked pixels keep their original stored value\n",
    "        dst_data = np.where(invalid, raw, averaged).astype(raw.dtype)\n",
    "\n",
    "    output_path = input_path\n",
    "\n",
    "    # NOTE(review): metadata keeps the source band count but only band 1 is written; confirm inputs are single-band\n",
    "    with rasterio.open(output_path, 'w', **dst_meta) as dst:\n",
    "        dst.write(dst_data, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_to_reference_image(input_path, reference_path, output_path=None):\n",
    "    \"\"\"\n",
    "    Function to register and align an input image to a reference image then save the new aligned GeoTIFF. If the output path is not provided, the original input image is overwritten.\n",
    "\n",
    "    The first band of the input image is reprojected (bilinear resampling) onto the grid of the reference image. The output is a single-band GeoTIFF that takes its remaining metadata (size, transform, CRS, dtype, nodata) from the reference image.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_path : str\n",
    "        Path to input image to be reprojected and aligned.\n",
    "    reference_path : str\n",
    "        Path to reference image to match alignment.\n",
    "    output_path : str (optional)\n",
    "        Path for output GeoTIFF. If not provided, the input image is overwritten.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "\n",
    "    with rasterio.open(input_path) as src:\n",
    "        src_profile = src.profile\n",
    "        src_data = src.read(1)\n",
    "\n",
    "    with rasterio.open(reference_path) as ref:\n",
    "        ref_profile = ref.profile\n",
    "        ref_data = ref.read(1, masked=True)\n",
    "\n",
    "        # capture the metadata while the dataset is still open; reading it after\n",
    "        # the with-block would rely on attributes of a closed dataset\n",
    "        dst_meta = ref.meta.copy()\n",
    "\n",
    "    # only one band is written below\n",
    "    dst_meta.update({'count': 1})\n",
    "\n",
    "    # destination array has the shape and dtype of the reference band\n",
    "    # NOTE(review): the output therefore uses the reference dtype/nodata, not the input's; confirm this is intended\n",
    "    dst_data = np.empty_like(ref_data)\n",
    "\n",
    "    reproject(source=src_data, \n",
    "              destination=dst_data, \n",
    "              src_transform=src_profile['transform'], \n",
    "              src_crs=src_profile['crs'], \n",
    "              dst_transform=ref_profile['transform'], \n",
    "              dst_crs=ref_profile['crs'], \n",
    "              dst_res=ref_profile['transform'][0], \n",
    "              resampling=Resampling.bilinear)\n",
    "\n",
    "    if not output_path:\n",
    "        output_path = input_path\n",
    "\n",
    "    with rasterio.open(output_path, 'w', **dst_meta) as dst:\n",
    "        dst.write(dst_data, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_patch(image_path, patches_gdf, output_dir, nodata=-999999):\n",
    "    \"\"\"\n",
    "    Function to extract image patches from a geodataframe of patch polygons. One GeoTIFF is written per patch.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    image_path : str\n",
    "        Path to image to extract patch.\n",
    "    patches_gdf : geodataframe\n",
    "        Geodataframe of patch polygons with a 'patch_id' attribute. Reprojected to the CRS of the image when the two differ.\n",
    "    output_dir : str\n",
    "        Directory for output image patches. Files are named <patch_id>_<image name>.tif.\n",
    "    nodata : int or float (optional)\n",
    "        Value given to pixels of the cropped window that fall outside the patch polygon. Default is -999999.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    image_name = os.path.basename(image_path)\n",
    "    image_name = os.path.splitext(image_name)[0]\n",
    "\n",
    "    with rasterio.open(image_path) as src:\n",
    "\n",
    "        # patch polygons must be in the CRS of the image to land on the right pixels\n",
    "        patches = patches_gdf\n",
    "        if patches.crs is not None and src.crs is not None and patches.crs != src.crs:\n",
    "            patches = patches.to_crs(src.crs)\n",
    "\n",
    "        for patch_id, geom in zip(patches['patch_id'], patches['geometry']):\n",
    "\n",
    "            # crop the image to the patch; pixels outside the polygon are filled with nodata\n",
    "            dst_image, dst_transform = mask(src, shapes=[geom], crop=True, filled=True, nodata=nodata)\n",
    "\n",
    "            # NOTE(review): the nodata entry of the metadata stays that of the source image, which can differ from the fill value above\n",
    "            dst_meta = src.meta.copy()\n",
    "            dst_meta.update({'driver':'GTiff', \n",
    "                             'height':dst_image.shape[1], \n",
    "                             'width':dst_image.shape[2], \n",
    "                             'transform':dst_transform})\n",
    "\n",
    "            output_path = f\"{output_dir}/{patch_id}_{image_name}.tif\"\n",
    "\n",
    "            with rasterio.open(output_path, 'w', **dst_meta) as dst:\n",
    "                dst.write(dst_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Plotting & Data Check Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_multi_terrain_features(mdhs_path, terrain_paths, bounds, cmap, title):\n",
    "    \"\"\"\n",
    "    Function to plot six terrain features from the same defined area. Terrain features have 50% transparency overlaying a multi-directional hillshade image. Each panel gets a small colorbar labelled with the minimum and maximum valid value of its terrain feature.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mdhs_path : str\n",
    "        Path to multi-directional hillshade GeoTIFF.\n",
    "    terrain_paths : iterable\n",
    "        List or tuple of paths terrain features at multiple resolutions (at most six; the figure has six panels).\n",
    "    bounds : iterable\n",
    "        List or tuple of bounding coordinates (left, bottom, right, top) of area of interest.\n",
    "    cmap : str or variable\n",
    "        Name of Matplotlib colormap or custom colormap.\n",
    "    title : str\n",
    "        Title of terrain feature plot.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None.\n",
    "    \"\"\"\n",
    "\n",
    "    # set up plot assuming six scales/terrain features\n",
    "    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(12,8), sharex=True, sharey=True)\n",
    "    fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "    ax = ax.ravel()\n",
    "\n",
    "    with rasterio.open(mdhs_path) as mdhs:\n",
    "\n",
    "        # the hillshade base layer is the same for every panel, so read it only once\n",
    "        mdhs_window = from_bounds(*bounds, mdhs.transform)\n",
    "        mdhs_data = mdhs.read(1, window=mdhs_window)\n",
    "        mdhs_transform = mdhs.window_transform(mdhs_window)\n",
    "\n",
    "        # iterate through each terrain feature (six total)\n",
    "        for idx, path in enumerate(terrain_paths):\n",
    "            with rasterio.open(path) as src:\n",
    "\n",
    "                # set up window for feature, get transform, and data\n",
    "                window = from_bounds(*bounds, src.transform)\n",
    "                transform = src.window_transform(window)\n",
    "\n",
    "                # mask nodata and NaN/inf so they do not distort the color scale or its labels\n",
    "                data = np.ma.masked_invalid(src.read(1, window=window, masked=True))\n",
    "\n",
    "                # colorbar labels; left to Matplotlib when the window has no valid pixel\n",
    "                if data.count() > 0:\n",
    "                    ticks = [float(data.min()), float(data.max())]\n",
    "                else:\n",
    "                    ticks = None\n",
    "\n",
    "                # plot feature; this will be hidden and is only for colorbar\n",
    "                hidden = ax[idx].imshow(data, cmap=cmap)\n",
    "\n",
    "                # plot multi-directional hillshade as base layer (on top of hidden)\n",
    "                show(mdhs_data, ax=ax[idx], cmap='binary_r', transform=mdhs_transform)\n",
    "\n",
    "                # plot terrain feature with transparency (to overlay on hillshade)\n",
    "                show(data, ax=ax[idx], cmap=cmap, transform=transform, alpha=0.5)\n",
    "\n",
    "                # plot custom color bar\n",
    "                cax = inset_axes(ax[idx], width='5%', height='40%', loc='lower right')\n",
    "                fig.colorbar(hidden, cax=cax, ticks=ticks)\n",
    "                cax.yaxis.set_ticks_position('left')\n",
    "\n",
    "                # customize plot elements\n",
    "                ax[idx].tick_params(axis='both', which='major', labelsize=8)\n",
    "                ax[idx].tick_params(axis='x', labelrotation=60)\n",
    "                ax[idx].ticklabel_format(style='plain')\n",
    "                ax[idx].set_title(os.path.basename(path), style='italic', fontsize=10)\n",
    "\n",
    "    plt.suptitle(title, y=0.96)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_image_alignment(input_paths, target='geology'):\n",
    "    \"\"\"\n",
    "    Function to check alignment and registration of images in regards to the target image. An image counts as aligned when its resolution, width, height, and bounding coordinates are all equal to those of the target image.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_paths : list, tuple\n",
    "        List or tuple of paths to images to check for alignment, including target image.\n",
    "    target : str\n",
    "        Name of target image (file name without directory or extension) that all other images should be aligned to.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Dataframe of image names, paths, and alignment metrics.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no image in input_paths is named like target.\n",
    "    \"\"\"\n",
    "\n",
    "    # column order matters: everything from 'resolution_x' onward is compared with the target\n",
    "    columns = ['image', 'path', 'dtype', 'aligned', 'resolution_x', 'resolution_y', 'width', 'height', 'left', 'bottom', 'right', 'top']\n",
    "\n",
    "    # one record per path, so images sharing a file name cannot overwrite each other\n",
    "    records = []\n",
    "    for path in input_paths:\n",
    "        name = os.path.splitext(os.path.basename(path))[0]\n",
    "        with rasterio.open(path) as src:\n",
    "            records.append({'image': name, \n",
    "                            'path': path, \n",
    "                            'dtype': src.meta['dtype'], \n",
    "                            'aligned': False, \n",
    "                            'resolution_x': src.res[0], \n",
    "                            'resolution_y': src.res[1], \n",
    "                            'width': src.width, \n",
    "                            'height': src.height, \n",
    "                            'left': src.bounds[0], \n",
    "                            'bottom': src.bounds[1], \n",
    "                            'right': src.bounds[2], \n",
    "                            'top': src.bounds[3]})\n",
    "\n",
    "    df = pd.DataFrame(records, columns=columns)\n",
    "\n",
    "    # get alignment values of the target (first match when the name occurs more than once)\n",
    "    target_rows = df.loc[df['image'] == target, 'resolution_x':]\n",
    "    if target_rows.empty:\n",
    "        raise ValueError(f\"Target image '{target}' not found in input_paths\")\n",
    "    target_alignment = target_rows.iloc[0].values\n",
    "\n",
    "    # check if other images are aligned to target\n",
    "    df['aligned'] = (df.loc[:, 'resolution_x':] == target_alignment).all(axis=1)\n",
    "\n",
    "    return df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cs612",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
