{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "310c7260-d358-4923-8c3e-4db6387efde4",
   "metadata": {},
   "source": [
    "### Evaluation notebook\n",
    "\n",
    "Evaluates on `HPatches` dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9698474e-5c25-4b26-8b9b-0dd75962045b",
   "metadata": {},
   "source": [
    "For an R2D2-like model, given that you have run the inference script `relfm/inference/r2d2_on_hpatches.py` that generates outputs, you can run this notebook to test rotation equivariance of local feature matching."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a93619-b0c5-484d-83b7-b8f8fddf2100",
   "metadata": {},
   "source": [
    "## Table of Contents\n",
    "\n",
    "* [Imports](#imports)\n",
    "* [Configure inputs](#configure_inputs)\n",
    "* [Generate results](#generate_results)\n",
    "* [Plot results](#plot_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e576b585-19cc-48e9-b513-29893e070efa",
   "metadata": {},
   "source": [
    "> *Warning*: This notebook takes about 9-10 minutes per model to generate results."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90f15c3f-9249-443c-9803-34a9248ff7c9",
   "metadata": {},
   "source": [
    "### Imports <a class=\"anchor\" id=\"imports\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc4d6a9a-86b9-49ca-a8b0-18e9baf3ef60",
   "metadata": {},
   "source": [
    "Basic imports.\n",
    "> Note that you should set `PYTHONPATH=/path/to/repo/:/path/to/repo/lib/r2d2/` before starting the Jupyter kernel, so that the `relfm` and `lib.r2d2` imports below can be resolved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec7fe25-756a-4329-b52e-c8e9ceedd3e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell runs, so edits to\n",
    "# project code take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afdb41fb-93e1-4dab-b692-f91e1f19bc5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "from genericpath import isdir\n",
    "from glob import glob\n",
    "from os.path import join, exists, expanduser\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "from lib.r2d2.extract import extract_keypoints_modified\n",
    "from relfm.utils.paths import REPO_PATH\n",
    "from relfm.utils.log import print_update, tqdm_iterator\n",
    "from relfm.utils.visualize import show_images_with_keypoints, check_kps_with_homography\n",
    "from relfm.utils.matching import evaluate_matching_with_rotation, analyze_result\n",
    "from relfm.utils.geometry import (\n",
    "    append_rotation_to_homography, apply_homography_to_keypoints, resize, apply_clean_rotation,\n",
    ")\n",
    "\n",
    "# NOTE(review): the names below are called later in this notebook but were never\n",
    "# imported, so a fresh kernel failed with NameError. Their module locations are\n",
    "# presumed from the project layout -- confirm before merging.\n",
    "from relfm.utils.visualize import set_latex_fonts, get_colors\n",
    "from relfm.inference.r2d2_on_hpatches import configure_save_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6375b41-76c1-4ae4-99cc-893da92774b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the fonts used by the plots below.\n",
    "# NOTE(review): presumably switches matplotlib to LaTeX-style fonts and, with\n",
    "# show_sample=True, renders a sample figure -- confirm against the helper.\n",
    "set_latex_fonts(show_sample=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c100f01-3c0f-47b0-9ba6-0ca5107eb890",
   "metadata": {},
   "source": [
    "### Configure inputs <a class=\"anchor\" id=\"configure_inputs\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dbc334b-32ee-446d-899f-3d9130a88926",
   "metadata": {},
   "source": [
    "Set the correct data and output paths as well as the model checkpoint that you'd like to evaluate on. Parameters such as `gap_between_rotations` and `imsize` (size of downsized image to evaluate) are hard-coded to be the same as the default values in the inference script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "618ed8d3-8a5c-4e3a-bc99-4978d95d2828",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the HPatches sequences (one sub-folder per sequence).\n",
    "data_dir = join(REPO_PATH, \"data/hpatches-sequences-release/\")\n",
    "\n",
    "# Models to evaluate: display name -> checkpoint path.\n",
    "# Uncomment an entry to include that model. The part of the display name before\n",
    "# \" (\" is what ends up as the legend label in the plot at the end of the notebook.\n",
    "model_ckpt_paths = {\n",
    "    # \"R2D2\": join(REPO_PATH, \"checkpoints/r2d2_WASF_N16.pt\"),\n",
    "    # \"R2D2 - $C_{4}$\": join(REPO_PATH, \"trained_models/epoch_16_test_model.pt\"),\n",
    "    \"R2D2 - $SO_{2}$ (Ep 4)\": join(REPO_PATH, \"trained_models/epoch_3_SO2_4x16_1x32_1x64_2x128.pt\"),\n",
    "    # \"R2D2 - $C_{8}$ (Ep 4)\": join(REPO_PATH, \"trained_models/epoch_3_C8_4x16_1x32_1x64_2x128.pt\"),\n",
    "    # \"R2D2 - $SO_{2}$ (Ep 18)\": join(REPO_PATH, \"trained_models/epoch_17_SO2_4x16_1x32_1x64_2x128.pt\"),\n",
    "}\n",
    "\n",
    "# Root folder under which the saved model outputs are looked up; it must already exist.\n",
    "output_dir = join(expanduser(\"~\"), \"outputs/rotation-equivariant-lfm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ecf5f8-89af-4319-968c-d32e1612ebc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail early, and say which path is wrong, if the inputs are not in place.\n",
    "assert isdir(data_dir), f\"HPatches data directory not found: {data_dir}\"\n",
    "assert isdir(output_dir), f\"Output directory not found: {output_dir}\"\n",
    "# assert exists(model_ckpt_path)\n",
    "\n",
    "# Step (in degrees) between consecutive rotations that are evaluated.\n",
    "gap_between_rotations=15\n",
    "# If True, every image is resized to (imsize, imsize) before evaluation.\n",
    "downsize=True\n",
    "imsize=300\n",
    "\n",
    "# If True, metrics are recomputed even when a cached `metrics.pt` exists.\n",
    "ignore_cache = True\n",
    "# If True, freshly computed metrics are written to `metrics.pt`.\n",
    "overwrite_cache = False\n",
    "\n",
    "# Passed through to the matching routine; `ransac` also controls the plot title.\n",
    "ransac = True\n",
    "ransac_threshold = 3.\n",
    "\n",
    "# NOTE: `sanity_check` is not used by the cells visible in this notebook.\n",
    "sanity_check = False\n",
    "# If True, images are center-cropped after rotation; otherwise the target image\n",
    "# is rotated without cropping.\n",
    "crop_post_rotation = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b67dd26-7d8e-4ab1-a63c-00f092d9d160",
   "metadata": {},
   "source": [
    "### Generate results <a class=\"anchor\" id=\"generate_results\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e4f521-ac39-4fe1-841d-4c694fdcf1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model display name -> {rotation: mean matching accuracy}\n",
    "results = dict()\n",
    "\n",
    "# load image sequences (one directory per sequence); stray non-directory\n",
    "# entries in the data folder are skipped instead of crashing at image loading\n",
    "sequences = sorted(path for path in glob(join(data_dir, \"*\")) if isdir(path))\n",
    "assert len(sequences) > 0, f\"No image sequences found under {data_dir}\"\n",
    "\n",
    "# load rotation values between 0 and 360 degrees\n",
    "rotations = np.arange(0, 360 + 1, gap_between_rotations, dtype=int)\n",
    "\n",
    "# set a (list of) pixel threshold value across which to evaluate rotation robustness\n",
    "thresholds = [3.]\n",
    "\n",
    "# set this to true to see intermediate outputs/messages\n",
    "# NOTE: in verbose mode only the first sequence is evaluated (see the `break` below)\n",
    "verbose = False\n",
    "\n",
    "for model_name, model_ckpt_path in model_ckpt_paths.items():\n",
    "\n",
    "    # directory from which the saved per-image outputs of this model are read\n",
    "    save_dir = configure_save_dir(output_dir, model_ckpt_path, dataset_name=\"hpatches\")\n",
    "\n",
    "    # reuse cached metrics unless recomputation was requested\n",
    "    metrics_path = join(save_dir, \"metrics.pt\")\n",
    "    if exists(metrics_path) and not ignore_cache:\n",
    "        print_update(f\"Loading cached results for {model_name}\")\n",
    "        results[model_name] = torch.load(metrics_path)[\"MMA\"]\n",
    "        continue\n",
    "    \n",
    "    print_update(f\"Generating results for {model_name}\")\n",
    "\n",
    "    # computing Mean Matching Accuracy (MMA): rotation -> list of per-pair accuracies\n",
    "    mma = defaultdict(list)\n",
    "    for counter, sequence in enumerate(sequences, start=1):\n",
    "\n",
    "        # store sequence name\n",
    "        sequence_name = os.path.basename(sequence)\n",
    "\n",
    "        # load source image\n",
    "        img1_path = join(sequence, \"1.ppm\")\n",
    "        img1_raw = Image.open(img1_path)\n",
    "\n",
    "        # define base homography for source image\n",
    "        H1_raw = np.eye(3)\n",
    "\n",
    "        # possible indices of the target images\n",
    "        img2_indices = np.arange(2, 7)\n",
    "\n",
    "        # load all target images at once\n",
    "        img2s_raw = [Image.open(join(sequence, f\"{i}.ppm\")) for i in img2_indices]\n",
    "\n",
    "        # define base homography for each target image\n",
    "        H2s_raw = [np.eye(3) for _ in img2s_raw]\n",
    "        \n",
    "        if downsize:\n",
    "            # downsize the source image to (imsize, imsize)\n",
    "            img1_resized, H1_raw = resize(img1_raw, imsize, imsize)\n",
    "\n",
    "            # downsize the target images to (imsize, imsize)\n",
    "            img2s_resized = []\n",
    "            for j, img2_raw in enumerate(img2s_raw):\n",
    "                img2_resized, H2s_raw[j] = resize(img2_raw, imsize, imsize)\n",
    "                img2s_resized.append(img2_resized)\n",
    "        else:\n",
    "            img1_resized = img1_raw\n",
    "            img2s_resized = img2s_raw\n",
    "            \n",
    "        # load all homographies (files `H_1_<i>` of the sequence)\n",
    "        H1to2s = [np.loadtxt(join(sequence, f\"H_1_{i}\")) for i in img2_indices]\n",
    "\n",
    "        # enumerate every (rotation, target image) combination\n",
    "        rotation_grid, img2_indices_grid  = np.meshgrid(rotations, img2_indices)\n",
    "        rotation_grid, img2_indices_grid = rotation_grid.flatten(), img2_indices_grid.flatten()\n",
    "        iterator = tqdm_iterator(\n",
    "            range(len(rotation_grid)),\n",
    "            desc=f\"Evaluating predictions for {sequence_name} ({counter}/{len(sequences)})\\t\\t\",\n",
    "        )\n",
    "        for i in iterator:\n",
    "            rotation, img2_index = rotation_grid[i], img2_indices_grid[i]\n",
    "\n",
    "            # position of this target image in the per-sequence lists\n",
    "            # (target image indices start at 2)\n",
    "            target_pos = img2_index - 2\n",
    "\n",
    "            # only this one (rotation, target) pair is visualised in verbose mode\n",
    "            show_debug = verbose and rotation == 30 and img2_index == 4\n",
    "\n",
    "            img1 = img1_resized.copy()\n",
    "\n",
    "            # index target image\n",
    "            img2 = img2s_resized[target_pos].copy()\n",
    "\n",
    "            # load base homography for source and target image\n",
    "            H1 = H1_raw.copy()\n",
    "            H2 = H2s_raw[target_pos].copy()\n",
    "\n",
    "            if crop_post_rotation:\n",
    "                # center crop the source image according to the rotation\n",
    "                # NOTE: this does not rotate the image, only crops based on rotation\n",
    "                _, _, img1_transformed, H1 = apply_clean_rotation(\n",
    "                    image=img1, degrees=rotation, H=H1\n",
    "                )\n",
    "\n",
    "                # rotate + center crop the target image\n",
    "                # NOTE: this applies rotation and then cropping\n",
    "                img2_transformed, H2, _, _ = apply_clean_rotation(\n",
    "                    image=img2, degrees=rotation, H=H2,\n",
    "                )\n",
    "            else:\n",
    "                img1_transformed = img1\n",
    "                img2_transformed = img2.rotate(rotation)\n",
    "                H2 = append_rotation_to_homography(H2, rotation, img1.size[0], img1.size[1])\n",
    "            \n",
    "            # transform the homography accordingly: undo the source-side transform,\n",
    "            # apply the sequence homography, then apply the target-side transform\n",
    "            H = H1to2s[target_pos].copy()\n",
    "            H_transformed = H2 @ H @ np.linalg.inv(H1)\n",
    "\n",
    "            # load outputs for source image\n",
    "            save_path = join(\n",
    "                save_dir, sequence_name, f\"1_rotation_{rotation}.npy\",\n",
    "            )\n",
    "            img1_outputs = np.load(save_path, allow_pickle=True).item()\n",
    "            \n",
    "            # load outputs for target image\n",
    "            save_path = join(\n",
    "                save_dir, sequence_name, f\"{img2_index}_rotation_{rotation}.npy\",\n",
    "            )\n",
    "            img2_outputs = np.load(save_path, allow_pickle=True).item()\n",
    "            \n",
    "            # get keypoints and descriptors from the outputs\n",
    "            kps1 = img1_outputs[\"keypoints\"]\n",
    "            des1 = img1_outputs[\"descriptors\"]\n",
    "\n",
    "            kps2 = img2_outputs[\"keypoints\"]\n",
    "            des2 = img2_outputs[\"descriptors\"]\n",
    "\n",
    "            # show detected keypoints\n",
    "            if show_debug:\n",
    "                show_images_with_keypoints(\n",
    "                    [img1_transformed, img2_transformed],\n",
    "                    [kps1, kps2],\n",
    "                    radius=2,\n",
    "                )\n",
    "\n",
    "            # perform matching\n",
    "            # NOTE: rotation=0 is passed because the rotation is already folded into\n",
    "            # H_transformed above. `threshold` is fixed at 3 px here, independent of\n",
    "            # `thresholds`; the accuracy below is recomputed from the raw distances.\n",
    "            width, height = img2_transformed.size\n",
    "            result = evaluate_matching_with_rotation(\n",
    "                kp1=kps1,\n",
    "                des1=des1,\n",
    "                kp2=kps2,\n",
    "                des2=des2,\n",
    "                H=H_transformed,\n",
    "                width=width,\n",
    "                height=height,\n",
    "                rotation=0,\n",
    "                return_metadata=True,\n",
    "                threshold=3.,\n",
    "                ransac=ransac,\n",
    "                ransac_threshold=ransac_threshold,\n",
    "            )\n",
    "\n",
    "            # show matching results\n",
    "            if show_debug:\n",
    "                analyze_result(\n",
    "                    img1_transformed, img2_transformed, result, K=10, radius=5,\n",
    "                )\n",
    "\n",
    "            # compute accuracy across various thresholds\n",
    "            _match_accu = []\n",
    "            for threshold in thresholds:\n",
    "                _match_accu.append(np.mean(result[\"distances\"] < threshold))\n",
    "\n",
    "            mma[rotation].append(np.mean(_match_accu))\n",
    "\n",
    "        if verbose:\n",
    "            break\n",
    "\n",
    "    # compute the mean matching accuracy (MMA) for every rotation value\n",
    "    mma_avg = {angle: np.mean(accuracies) for angle, accuracies in mma.items()}\n",
    "    \n",
    "    if overwrite_cache:\n",
    "        # save metrics\n",
    "        metrics = {\n",
    "            \"MMA\": mma_avg,\n",
    "        }\n",
    "        torch.save(metrics, metrics_path)\n",
    "    \n",
    "    # collect results\n",
    "    results[model_name] = mma_avg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c260fb-ffc8-42ce-9f92-d18a74ce1080",
   "metadata": {},
   "source": [
    "### Plot results <a class=\"anchor\" id=\"plot_results\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884f8dbf-5b6f-4f08-8032-5ca588360914",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot MMA against rotation angle, one line per evaluated model.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(14, 8))\n",
    "\n",
    "ax.grid(alpha=0.5)\n",
    "# ax.set_ylim((0., 1.))\n",
    "ransac_suffix = \" (with RANSAC) \" if ransac else \"\"\n",
    "ax.set_title(f\"Rotation-equivariance on HPatches dataset {ransac_suffix}\", fontsize=23)\n",
    "ax.set_xlabel(\"Rotation angle \", fontsize=18)\n",
    "ax.set_ylabel(\"Mean matching accuracy (MMA)\", fontsize=18)\n",
    "\n",
    "# per-model line styles, indexed by the model's position in `results`\n",
    "colors = [\"blue\", \"gold\", \"#b355ed\", \"green\", \"red\"]\n",
    "markers = [\"o\", \"D\", \"^\", \"X\", \"s\"]\n",
    "linestyles = [\"solid\", \"dashed\", \"dashed\", \"dashed\", \"solid\"]\n",
    "fillstyles = [\"full\", \"none\", \"none\", \"none\", \"full\"]\n",
    "num_styles = len(colors)\n",
    "\n",
    "for i, (model_name, mma_avg) in enumerate(results.items()):\n",
    "    \n",
    "    # NOTE(review): the third entry of `results` is left out of the figure;\n",
    "    # behaviour kept as-is -- confirm that this is still intended.\n",
    "    if i == 2:\n",
    "        continue\n",
    "\n",
    "    # wrap around so that more models than styles does not raise IndexError\n",
    "    style = i % num_styles\n",
    "\n",
    "    ax.plot(\n",
    "        list(mma_avg.keys()),\n",
    "        list(mma_avg.values()),\n",
    "        # legend label is the display name without its parenthesised suffix\n",
    "        label=model_name.split(\" (\")[0],\n",
    "        markersize=8,\n",
    "        linewidth=2.,\n",
    "        color=colors[style],\n",
    "        marker=markers[style],\n",
    "        linestyle=linestyles[style],\n",
    "        fillstyle=fillstyles[style],\n",
    "    )\n",
    "    ax.set_xticks(list(mma_avg.keys()))\n",
    "\n",
    "ax.legend(fontsize=17, bbox_to_anchor=(1., 0.95), title=\"Method\", title_fontsize=18)\n",
    "ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "\n",
    "# name the output file after the RANSAC setting that was actually used, so it\n",
    "# agrees with the figure title\n",
    "os.makedirs(\"../Figures/\", exist_ok=True)\n",
    "ransac_tag = \"with_RANSAC\" if ransac else \"without_RANSAC\"\n",
    "plt.savefig(f\"../Figures/mma_hpatches_{ransac_tag}-v1.0.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a361479-b0f4-444f-ad6f-a2ca078b3b3b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
