{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "310c7260-d358-4923-8c3e-4db6387efde4",
   "metadata": {},
   "source": [
    "### Evaluation notebook\n",
    "\n",
    "Evaluates models on the `HPatches` dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9698474e-5c25-4b26-8b9b-0dd75962045b",
   "metadata": {},
   "source": [
    "For an R2D2-like model, first run the inference script `relfm/inference/r2d2_on_hpatches.py`, which generates the per-image outputs (keypoints, descriptors and homographies); this notebook then loads those outputs to evaluate the rotation equivariance of local feature matching."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a93619-b0c5-484d-83b7-b8f8fddf2100",
   "metadata": {},
   "source": [
    "## Table of Contents\n",
    "\n",
    "* [Imports](#imports)\n",
    "* [Configure inputs](#configure_inputs)\n",
    "* [Generate results](#generate_results)\n",
    "* [Plot results](#plot_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e576b585-19cc-48e9-b513-29893e070efa",
   "metadata": {},
   "source": [
    "> *Warning*: Generating results takes about 9-10 minutes per model; models with cached metrics are loaded instead (unless `ignore_cache` is set)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90f15c3f-9249-443c-9803-34a9248ff7c9",
   "metadata": {},
   "source": [
    "### Imports <a class=\"anchor\" id=\"imports\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc4d6a9a-86b9-49ca-a8b0-18e9baf3ef60",
   "metadata": {},
   "source": [
    "Basic imports.\n",
    "> Note that you should set `PYTHONPATH=/path/to/repo/:/path/to/repo/lib/r2d2/` before launching the Jupyter kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec7fe25-756a-4329-b52e-c8e9ceedd3e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload edited project modules (e.g. `relfm`) automatically before running code\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afdb41fb-93e1-4dab-b692-f91e1f19bc5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join, exists, expanduser\n",
    "from genericpath import isdir\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from lib.r2d2.extract import extract_keypoints_modified\n",
    "from relfm.utils.paths import REPO_PATH\n",
    "from relfm.utils.log import print_update, tqdm_iterator\n",
    "from relfm.utils.visualize import show_images_with_keypoints, set_latex_fonts, get_colors\n",
    "from relfm.utils.matching import evaluate_matching_with_rotation, analyze_result\n",
    "from relfm.inference.r2d2_on_hpatches import configure_save_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6375b41-76c1-4ae4-99cc-893da92774b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# project plotting helper (relfm.utils.visualize); presumably sets LaTeX-style fonts\n",
    "# and, with show_sample=True, renders a preview - TODO confirm\n",
    "set_latex_fonts(show_sample=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c100f01-3c0f-47b0-9ba6-0ca5107eb890",
   "metadata": {},
   "source": [
    "### Configure inputs <a class=\"anchor\" id=\"configure_inputs\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dbc334b-32ee-446d-899f-3d9130a88926",
   "metadata": {},
   "source": [
    "Set the data and output paths, as well as the model checkpoints you'd like to evaluate. Parameters such as `gap_between_rotations` and `imsize` (the size images are downsized to before evaluation) are hard-coded to the default values of the inference script, so keep them in sync with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "618ed8d3-8a5c-4e3a-bc99-4978d95d2828",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HPatches sequences (one sub-folder per sequence)\n",
    "data_dir = join(REPO_PATH, \"data/hpatches-sequences-release/\")\n",
    "\n",
    "# (display name, checkpoint path relative to the repository root);\n",
    "# models are evaluated and plotted in this order\n",
    "_checkpoints = [\n",
    "    (\"R2D2\", \"checkpoints/r2d2_WASF_N16.pt\"),\n",
    "    (\"R2D2 - $C_{4}$\", \"trained_models/epoch_16_test_model.pt\"),\n",
    "    (\"R2D2 - $SO_{2}$ (Ep 4)\", \"trained_models/epoch_3_SO2_4x16_1x32_1x64_2x128.pt\"),\n",
    "    (\"R2D2 - $C_{8}$ (Ep 4)\", \"trained_models/epoch_3_C8_4x16_1x32_1x64_2x128.pt\"),\n",
    "    (\"R2D2 - $SO_{2}$ (Ep 18)\", \"trained_models/epoch_17_SO2_4x16_1x32_1x64_2x128.pt\"),\n",
    "    # (\"R2D2 - $SO_{2}$ (Ep 0)\", \"trained_models/epoch_0_SO2_downsamp_4x16_2x32_1x64_1x128.pt\"),\n",
    "]\n",
    "model_ckpt_paths = {name: join(REPO_PATH, rel_path) for name, rel_path in _checkpoints}\n",
    "\n",
    "# root directory that holds the per-model inference outputs\n",
    "output_dir = join(expanduser(\"~\"), \"outputs/rotation-equivariant-lfm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ecf5f8-89af-4319-968c-d32e1612ebc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fail early, with an actionable message, if the expected inputs are missing\n",
    "assert isdir(data_dir), f\"HPatches sequences not found at {data_dir}\"\n",
    "assert isdir(output_dir), (\n",
    "    f\"Inference outputs not found at {output_dir}; \"\n",
    "    \"run relfm/inference/r2d2_on_hpatches.py first\"\n",
    ")\n",
    "\n",
    "# keep in sync with the default values of the inference script\n",
    "gap_between_rotations = 15\n",
    "downsize = True\n",
    "imsize = 300\n",
    "\n",
    "# ignore_cache: recompute metrics even when a cached metrics.pt exists\n",
    "# overwrite_cache: when True, (re)write metrics.pt with freshly computed metrics\n",
    "ignore_cache = False\n",
    "overwrite_cache = False\n",
    "\n",
    "# forwarded to evaluate_matching_with_rotation\n",
    "ransac = False\n",
    "ransac_threshold = 3."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b67dd26-7d8e-4ab1-a63c-00f092d9d160",
   "metadata": {},
   "source": [
    "### Generate results <a class=\"anchor\" id=\"generate_results\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e4f521-ac39-4fe1-841d-4c694fdcf1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_image(path):\n",
    "    \"\"\"Open `path`, resize it to (imsize, imsize) if `downsize` is set, and close the file.\n",
    "\n",
    "    Returns an in-memory image, so no file handle stays open across the evaluation loop.\n",
    "    \"\"\"\n",
    "    with Image.open(path) as image:\n",
    "        return image.resize((imsize, imsize)) if downsize else image.copy()\n",
    "\n",
    "\n",
    "def _match_accuracy(distances, thresholds):\n",
    "    \"\"\"Fraction of `distances` below each pixel threshold, averaged over `thresholds`.\"\"\"\n",
    "    return np.mean([np.mean(distances < threshold) for threshold in thresholds])\n",
    "\n",
    "\n",
    "results = dict()\n",
    "\n",
    "# image sequences (one sub-folder per HPatches sequence)\n",
    "sequences = sorted(glob(join(data_dir, \"*\")))\n",
    "\n",
    "# rotation angles in degrees; the `+ 1` makes the range include 360\n",
    "rotations = np.arange(0, 360 + 1, gap_between_rotations, dtype=int)\n",
    "\n",
    "# (list of) pixel thresholds across which to evaluate rotation robustness\n",
    "thresholds = [3.]\n",
    "\n",
    "# indices of the target images in each sequence (1.ppm is the source image)\n",
    "img2_indices = np.arange(2, 7)\n",
    "\n",
    "# set this to True to see intermediate outputs/messages\n",
    "verbose = False\n",
    "\n",
    "for model_name, model_ckpt_path in model_ckpt_paths.items():\n",
    "\n",
    "    # directory holding this model's inference outputs and cached metrics\n",
    "    save_dir = configure_save_dir(output_dir, model_ckpt_path, dataset_name=\"hpatches\")\n",
    "\n",
    "    metrics_path = join(save_dir, \"metrics.pt\")\n",
    "    if exists(metrics_path) and not ignore_cache:\n",
    "        print_update(f\"Loading cached results for {model_name}\")\n",
    "        results[model_name] = torch.load(metrics_path)[\"MMA\"]\n",
    "        continue\n",
    "\n",
    "    print_update(f\"Generating results for {model_name}\")\n",
    "\n",
    "    # rotation -> per-pair matching accuracies, averaged into the MMA below\n",
    "    mma = defaultdict(list)\n",
    "    for counter, sequence in enumerate(sequences, start=1):\n",
    "        sequence_name = os.path.basename(sequence)\n",
    "\n",
    "        # the source image and its outputs do not depend on the target, so load them once\n",
    "        img1 = _load_image(join(sequence, \"1.ppm\"))\n",
    "        img1_outputs = np.load(join(save_dir, sequence_name, \"1.npy\"), allow_pickle=True).item()\n",
    "        kps1 = img1_outputs[\"keypoints\"]\n",
    "        des1 = img1_outputs[\"descriptors\"]\n",
    "\n",
    "        # all target images at once; homographies are not loaded here because\n",
    "        # the matching H is saved within each target's outputs\n",
    "        img2s = [_load_image(join(sequence, f\"{i}.ppm\")) for i in img2_indices]\n",
    "\n",
    "        # every (target index, target image, rotation) combination, ordered target-major\n",
    "        pairs = [\n",
    "            (img2_index, img2, rotation)\n",
    "            for img2_index, img2 in zip(img2_indices, img2s)\n",
    "            for rotation in rotations\n",
    "        ]\n",
    "\n",
    "        iterator = tqdm_iterator(\n",
    "            range(len(pairs)),\n",
    "            desc=f\"Evaluating predictions for {sequence_name} ({counter}/{len(sequences)})\\t\\t\",\n",
    "        )\n",
    "        for i in iterator:\n",
    "            img2_index, img2, rotation = pairs[i]\n",
    "            # only this one pair is visualized in verbose mode\n",
    "            show_debug = verbose and rotation == 30 and img2_index == 4\n",
    "\n",
    "            if verbose:\n",
    "                print(f\"Image {img2_index} with rotation {rotation}\")\n",
    "\n",
    "            # allow_pickle is required because each file stores a dict (self-generated outputs)\n",
    "            save_path = join(save_dir, sequence_name, f\"{img2_index}_rotation_{rotation}.npy\")\n",
    "            img2_outputs = np.load(save_path, allow_pickle=True).item()\n",
    "            kps2 = img2_outputs[\"keypoints\"]\n",
    "            des2 = img2_outputs[\"descriptors\"]\n",
    "\n",
    "            if show_debug:\n",
    "                img2_rotated = img2.rotate(rotation)\n",
    "                show_images_with_keypoints([img1, img2_rotated], [kps1, kps2], radius=2)\n",
    "\n",
    "            width, height = img2.size\n",
    "            result = evaluate_matching_with_rotation(\n",
    "                kp1=kps1,\n",
    "                des1=des1,\n",
    "                kp2=kps2,\n",
    "                des2=des2,\n",
    "                H=img2_outputs[\"H\"],\n",
    "                width=width,\n",
    "                height=height,\n",
    "                rotation=rotation,\n",
    "                return_metadata=True,\n",
    "                threshold=3.,\n",
    "                ransac=ransac,\n",
    "                ransac_threshold=ransac_threshold,\n",
    "            )\n",
    "\n",
    "            if show_debug:\n",
    "                analyze_result(img1, img2_rotated, result, K=10, radius=5)\n",
    "\n",
    "            mma[rotation].append(_match_accuracy(result[\"distances\"], thresholds))\n",
    "\n",
    "    # mean matching accuracy (MMA) per rotation, stored as plain int/float so the\n",
    "    # cached metrics contain no numpy objects\n",
    "    mma_avg = {int(rotation): float(np.mean(accuracies)) for rotation, accuracies in mma.items()}\n",
    "\n",
    "    # populate the cache when it is missing; replace an existing one only if overwrite_cache is set\n",
    "    if overwrite_cache or not exists(metrics_path):\n",
    "        torch.save({\"MMA\": mma_avg}, metrics_path)\n",
    "\n",
    "    # collect results\n",
    "    results[model_name] = mma_avg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c260fb-ffc8-42ce-9f92-d18a74ce1080",
   "metadata": {},
   "source": [
    "### Plot results <a class=\"anchor\" id=\"plot_results\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884f8dbf-5b6f-4f08-8032-5ca588360914",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(14, 8))\n",
    "\n",
    "ax.grid(alpha=0.5)\n",
    "# ax.set_ylim((0., 1.))\n",
    "ransac_suffix = \" (with RANSAC)\" if ransac else \"\"\n",
    "ax.set_title(f\"Rotation-equivariance on HPatches dataset{ransac_suffix}\", fontsize=23)\n",
    "ax.set_xlabel(\"Rotation angle (degrees)\", fontsize=18)\n",
    "ax.set_ylabel(\"Mean matching accuracy (MMA)\", fontsize=18)\n",
    "\n",
    "# line styles are assigned by each model's position in `results`\n",
    "# (cycled if there are more models than styles)\n",
    "colors = [\"blue\", \"gold\", \"#b355ed\", \"green\", \"red\"]\n",
    "markers = [\"o\", \"D\", \"^\", \"X\", \"s\"]\n",
    "linestyles = [\"solid\", \"dashed\", \"dashed\", \"dashed\", \"solid\"]\n",
    "fillstyles = [\"full\", \"none\", \"none\", \"none\", \"full\"]\n",
    "\n",
    "# models left out of the figure: with the \" (Ep N)\" suffix stripped from the\n",
    "# legend label, this one would share its label with the Ep 18 SO(2) model\n",
    "skipped_models = {\"R2D2 - $SO_{2}$ (Ep 4)\"}\n",
    "\n",
    "xticks = None\n",
    "for i, (model_name, mma_avg) in enumerate(results.items()):\n",
    "    if model_name in skipped_models:\n",
    "        continue\n",
    "\n",
    "    style = i % len(colors)\n",
    "    ax.plot(\n",
    "        list(mma_avg.keys()),\n",
    "        list(mma_avg.values()),\n",
    "        label=model_name.split(\" (\")[0],\n",
    "        markersize=8,\n",
    "        linewidth=2.,\n",
    "        color=colors[style],\n",
    "        marker=markers[style],\n",
    "        linestyle=linestyles[style],\n",
    "        fillstyle=fillstyles[style],\n",
    "    )\n",
    "    xticks = list(mma_avg.keys())\n",
    "\n",
    "if xticks is not None:\n",
    "    ax.set_xticks(xticks)\n",
    "\n",
    "ax.legend(fontsize=17, bbox_to_anchor=(1., 0.95), title=\"Method\", title_fontsize=18)\n",
    "ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "\n",
    "# name the figure after the RANSAC setting so the two variants do not overwrite each other\n",
    "ransac_tag = \"with\" if ransac else \"without\"\n",
    "os.makedirs(\"../Figures/\", exist_ok=True)\n",
    "fig.savefig(f\"../Figures/mma_hpatches_{ransac_tag}_RANSAC-v1.0.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a361479-b0f4-444f-ad6f-a2ca078b3b3b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
