{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "de6f4cdb-8291-466d-a541-e4ed3a99241e",
   "metadata": {},
   "source": [
    "### Evaluation notebook\n",
    "\n",
    "Evaluates on `HPatches` dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f339400-4ffa-4890-bfc5-df20dd7b424c",
   "metadata": {},
   "source": [
    "For an R2D2-like model, given that you have run the inference script `relfm/inference/r2d2_on_hpatches.py` that generates outputs, you can run this notebook to test rotation equivariance of local feature matching."
   ]
  },
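  {
   "cell_type": "markdown",
   "id": "sketch-rotation-check",
   "metadata": {},
   "source": [
    "As a rough illustration of what such a test checks (a sketch, not the implementation in `relfm.utils.matching`): a source keypoint is mapped through the homography `H`, rotated about the image center by the angle applied to the target image, and compared with its matched keypoint under a pixel threshold. The helper names `warp_points`, `rotate_points` and `is_correct_match` are hypothetical. The sketch assumes that `H` maps source pixels to the unrotated target image and that the rotation follows the convention of PIL's `Image.rotate` (counter-clockwise about `(width / 2, height / 2)`, without canvas expansion).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def warp_points(points, H):\n",
    "    # Apply a 3x3 homography to an (N, 2) array of (x, y) points.\n",
    "    pts_h = np.hstack([points, np.ones((len(points), 1))])\n",
    "    warped = pts_h @ H.T\n",
    "    return warped[:, :2] / warped[:, 2:3]\n",
    "\n",
    "\n",
    "def rotate_points(points, angle_deg, width, height):\n",
    "    # Rotate (x, y) points about the image center. With the y axis pointing\n",
    "    # down, a positive angle turns the image counter-clockwise on screen.\n",
    "    theta = np.deg2rad(angle_deg)\n",
    "    c, s = np.cos(theta), np.sin(theta)\n",
    "    R = np.array([[c, s], [-s, c]])\n",
    "    center = np.array([width / 2.0, height / 2.0])\n",
    "    return (points - center) @ R.T + center\n",
    "\n",
    "\n",
    "def is_correct_match(kp1, kp2, H, angle_deg, width, height, threshold=10.0):\n",
    "    # A match counts as correct if the transformed source keypoint lands\n",
    "    # within `threshold` pixels of the matched target keypoint.\n",
    "    expected = rotate_points(warp_points(kp1, H), angle_deg, width, height)\n",
    "    return np.linalg.norm(expected - kp2, axis=1) <= threshold\n",
    "```\n",
    "\n",
    "For example, in a 300 x 300 image a point at the middle of the right edge, `(300, 150)`, moves to the middle of the top edge, `(150, 0)`, under a 90 degree rotation."
   ]
  },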
  {
   "cell_type": "markdown",
   "id": "faf62a30-4bcb-4186-b362-7df05cc554d9",
   "metadata": {},
   "source": [
    "## Table of Contents\n",
    "\n",
    "* [Imports](#imports)\n",
    "* [Configure inputs](#configure_inputs)\n",
    "* [Generate results](#generate_results)\n",
    "* [Plot results](#plot_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92b0052e-597c-41d6-b86c-9f36de78ad3d",
   "metadata": {},
   "source": [
    "> *Warning*: This notebook takes about 9-10 minutes for generating results per model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ca8f834-efa6-465b-836a-5937bfc62423",
   "metadata": {},
   "source": [
    "### Imports <a class=\"anchor\" id=\"imports\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1724ddd4-2669-468c-bc49-014adfa8745d",
   "metadata": {},
   "source": [
    "Basic imports.\n",
    "> Note that you should set `PYTHONPATH=/path/to/repo/:/path/to/repo/lib/r2d2/` before running the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad580169-bb37-4d2c-8360-f9eaf2915fc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6767ced9-2cb2-42d4-af91-6cef4eb82be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join, exists, expanduser, basename\n",
    "from genericpath import isdir\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from lib.r2d2.extract import extract_keypoints_modified\n",
    "from relfm.utils.paths import REPO_PATH\n",
    "from relfm.utils.log import print_update, tqdm_iterator\n",
    "from relfm.utils.visualize import show_images_with_keypoints, set_latex_fonts, show_grid_of_images, get_concat_h\n",
    "from relfm.utils.matching import evaluate_matching_with_rotation, analyze_result\n",
    "from relfm.inference.r2d2_on_hpatches import configure_save_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1776b7f-9ba8-44ea-be59-d6a429dc5c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_latex_fonts(show_sample=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5a33052-1555-4cd1-b04c-50620befb29f",
   "metadata": {},
   "source": [
    "### Configure inputs <a class=\"anchor\" id=\"configure_inputs\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9284aaa7-cae1-483f-9bfa-c629115d29b5",
   "metadata": {},
   "source": [
    "Set the correct data and output paths as well as the model checkpoint that you'd like to evaluate on. Parameters such as `gap_between_rotations` and `imsize` (size of downsized image to evaluate) are hard-coded to be the same as the default values in the inference script."
   ]
  },
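  {
   "cell_type": "markdown",
   "id": "sketch-rotation-grid",
   "metadata": {},
   "source": [
    "For reference, with `gap_between_rotations = 15` the rotation grid built later in this notebook, `np.arange(0, 360 + 1, gap_between_rotations)`, contains 25 angles and includes both 0 and 360 degrees. A quick check (the helper name `rotation_grid` is only illustrative):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def rotation_grid(gap_between_rotations):\n",
    "    # Angles in degrees from 0 to 360 inclusive, spaced by the given gap.\n",
    "    return np.arange(0, 360 + 1, gap_between_rotations, dtype=int)\n",
    "\n",
    "\n",
    "rotations = rotation_grid(15)\n",
    "print(len(rotations), rotations[0], rotations[-1])  # 25 0 360\n",
    "```\n",
    "\n",
    "The angle selected for visualization should be one of these values, since the cached outputs are loaded from files named per rotation."
   ]
  },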
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44034c42-4394-4ac9-b73b-4a4801fc764f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = join(REPO_PATH, \"data/hpatches-sequences-release/\")\n",
    "\n",
    "model_ckpt_paths = {\n",
    "    # \"R2D2\": join(REPO_PATH, \"checkpoints/r2d2_WASF_N16.pt\"),\n",
    "    # \"R2D2 - $C_{4}$ (Ep 4)\": join(REPO_PATH, \"trained_models/epoch_16_test_model.pt\"),\n",
    "    # # \"R2D2 - $SO_{2}$ (Ep 4)\": join(REPO_PATH, \"trained_models/epoch_3_SO2_4x16_1x32_1x64_2x128.pt\"),\n",
    "    # \"R2D2 - $C_{8}$ (Ep 4)\": join(REPO_PATH, \"trained_models/epoch_3_C8_4x16_1x32_1x64_2x128.pt\"),\n",
    "    # \"R2D2 - $SO_{2}$ (Ep 18)\": join(REPO_PATH, \"trained_models/epoch_17_SO2_4x16_1x32_1x64_2x128.pt\")\n",
    "    \"R2D2\": join(REPO_PATH, \"trained_models/r2d2_WASF_N16.pt\"),\n",
    "    \"C-3PO - $C_{3}$\": join(REPO_PATH, \"trained_models/finalmodelC3_epoch_2_4x16_1x32_1x64_2x128.pt\"),\n",
    "    \"C-3PO - $C_{4}$\": join(REPO_PATH, \"trained_models/finalmodelC4_epoch_5_4x16_1x32_1x64_2x128.pt\"),\n",
    "    \"C-3PO - $C_{8}$\": join(REPO_PATH, \"trained_models/finalmodelC8_epoch_1_4x16_1x32_1x64_2x128.pt\"),\n",
    "    \"C-3PO - $SO(2)$\": join(REPO_PATH, \"trained_models/finalmodelSO2_epoch_17_4x16_1x32_1x64_2x128.pt\"),\n",
    "}\n",
    "\n",
    "output_dir = join(expanduser(\"~\"), \"outputs/rotation-equivariant-lfm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fda664a-c77a-4864-97a0-935e6de314fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert isdir(data_dir)\n",
    "assert isdir(output_dir)\n",
    "# assert exists(model_ckpt_path)\n",
    "\n",
    "gap_between_rotations=15\n",
    "downsize=True\n",
    "imsize=300\n",
    "\n",
    "ignore_cache = True\n",
    "overwrite_cache = False\n",
    "\n",
    "ransac = True\n",
    "ransac_threshold = 3."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edf131c1-a9e9-44ea-9f71-956522c223a2",
   "metadata": {},
   "source": [
    "### Generate results <a class=\"anchor\" id=\"generate_results\"></a>"
   ]
  },
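  {
   "cell_type": "markdown",
   "id": "sketch-load-cached-outputs",
   "metadata": {},
   "source": [
    "The cells in this section read cached inference outputs with `np.load(path, allow_pickle=True).item()`. That pattern recovers a Python dictionary that was written with `np.save`. Below is a minimal, self-contained sketch of the round trip. The array shapes are made up for illustration, while the keys and the file name pattern mirror the ones read in this notebook.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Dummy outputs; the shapes here are illustrative assumptions only.\n",
    "outputs = {\n",
    "    'keypoints': np.zeros((5, 2)),\n",
    "    'descriptors': np.zeros((5, 128)),\n",
    "    'H': np.eye(3),\n",
    "}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    path = os.path.join(tmp_dir, '2_rotation_75.npy')\n",
    "    # np.save wraps the dict in a 0-d object array (pickled on disk).\n",
    "    np.save(path, outputs)\n",
    "    # .item() unwraps the 0-d array back into the original dict.\n",
    "    loaded = np.load(path, allow_pickle=True).item()\n",
    "\n",
    "print(sorted(loaded.keys()))  # ['H', 'descriptors', 'keypoints']\n",
    "```\n",
    "\n",
    "Because `allow_pickle=True` unpickles arbitrary objects, it is best used only on files that you generated yourself."
   ]
  },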
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6791a528-7575-4f85-abc1-8cc0ce0d69a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load image sequences\n",
    "sequences = sorted(glob(join(data_dir, \"*\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b5226e-7cdc-49f1-ba44-ed87858bf1c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "[basename(x) for x in sequences].index(\"i_castle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1b34a8-2c87-4098-a67b-07e5b2be51b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results = dict()\n",
    "\n",
    "\n",
    "# # select a random sequence to plot results\n",
    "# sequence = np.random.choice(sequences)\n",
    "\n",
    "# choose a sequence to visualize\n",
    "sequence = sequences[7]\n",
    "\n",
    "# load rotation values between 0 and 360 degrees\n",
    "rotations = np.arange(0, 360 + 1, gap_between_rotations, dtype=int)\n",
    "rotation_to_visualize = 45\n",
    "\n",
    "# set a (list of) pixel threshold value across which to evaluate rotation robustness\n",
    "thresholds = [10.]\n",
    "\n",
    "# set this to true to see intermediate outputs/messages\n",
    "verbose = True\n",
    "\n",
    "# select specific target image index to visualize\n",
    "img2_index = 4\n",
    "# select specific rotation value to visualize\n",
    "rotation = 75\n",
    "\n",
    "source_images_with_kps = dict()\n",
    "target_images_with_kps = dict()\n",
    "matched_images = dict()\n",
    "\n",
    "for model_name, model_ckpt_path in model_ckpt_paths.items():\n",
    "\n",
    "    # save directory\n",
    "    save_dir = configure_save_dir(output_dir, model_ckpt_path, dataset_name=\"hpatches\")\n",
    "\n",
    "    print_update(f\"Generating results for {model_name} for sequence {basename(sequence)}\")\n",
    "\n",
    "    # set path to the source image\n",
    "    img1_path = join(sequence, \"1.ppm\")\n",
    "    img1 = Image.open(img1_path)\n",
    "    if downsize:\n",
    "        img1 = img1.resize((imsize, imsize))\n",
    "\n",
    "    # load outputs for source image\n",
    "    sequence_name = os.path.basename(sequence)\n",
    "    save_path = join(save_dir, sequence_name, \"1_rotation_0.npy\")\n",
    "    img1_outputs = np.load(save_path, allow_pickle=True).item()\n",
    "\n",
    "    # possible indices of the target images\n",
    "    img2_indices = np.arange(2, 7)\n",
    "\n",
    "    # load all target images at once\n",
    "    img2s = [Image.open(join(sequence, f\"{i}.ppm\")) for i in img2_indices]\n",
    "    if downsize:\n",
    "        img2s = [img2.resize((imsize, imsize)) for img2 in img2s]\n",
    "\n",
    "    # load all homographies\n",
    "    # NOTE that this is not needed since we save the apt H within outputs itself\n",
    "\n",
    "    rotation_grid, img2_indices_grid  = np.meshgrid(rotations, img2_indices)\n",
    "    rotation_grid, img2_indices_grid = rotation_grid.flatten(), img2_indices_grid.flatten()\n",
    "\n",
    "    # iterator = tqdm_iterator(\n",
    "    #     range(len(rotation_grid)),\n",
    "    #     desc=f\"Generating qualitative results for {sequence_name} \\t:\",\n",
    "    # )\n",
    "    # for i in iterator:\n",
    "    #     rotation, img2_index = rotation_grid[i], img2_indices_grid[i]\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Image {img2_index} with rotation {rotation}\")\n",
    "\n",
    "    img2 = img2s[img2_index - 2]\n",
    "    img2_rotated = img2.rotate(rotation)\n",
    "\n",
    "    save_path = join(save_dir, sequence_name, f\"{img2_index}_rotation_{rotation}.npy\")\n",
    "    img2_outputs = np.load(save_path, allow_pickle=True).item()\n",
    "\n",
    "    # get keypoints and descriptors from the outputs\n",
    "    kps1 = img1_outputs[\"keypoints\"]\n",
    "    des1 = img1_outputs[\"descriptors\"]\n",
    "\n",
    "    kps2 = img2_outputs[\"keypoints\"]\n",
    "    des2 = img2_outputs[\"descriptors\"]\n",
    "    H = img2_outputs[\"H\"]\n",
    "\n",
    "    # show detected keypoints\n",
    "    if verbose:\n",
    "        images_with_kps = show_images_with_keypoints([img1, img2_rotated], [kps1, kps2], radius=2, return_images=True)\n",
    "        source_images_with_kps[model_name] = images_with_kps[0]\n",
    "        target_images_with_kps[model_name] = images_with_kps[-1]\n",
    "\n",
    "    # perform matching\n",
    "    width, height = img2.size\n",
    "    result = evaluate_matching_with_rotation(\n",
    "        kp1=kps1,\n",
    "        des1=des1,\n",
    "        kp2=kps2,\n",
    "        des2=des2,\n",
    "        H=H,\n",
    "        width=width,\n",
    "        height=height,\n",
    "        rotation=rotation,\n",
    "        return_metadata=True,\n",
    "        threshold=10.,\n",
    "    )\n",
    "\n",
    "    # show matching results\n",
    "    if verbose:\n",
    "        matched_image = analyze_result(\n",
    "            img1, img2_rotated, result,\n",
    "            K=50, radius=3, model_name=model_name, save_dir=\"../Figures/\",\n",
    "            return_img=True\n",
    "        )\n",
    "        matched_images[model_name] = matched_image"
   ]
  },
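  {
   "cell_type": "markdown",
   "id": "sketch-mutual-nn-matching",
   "metadata": {},
   "source": [
    "The matching step above is handled by `evaluate_matching_with_rotation`, whose internals are not shown in this notebook. A common baseline for matching two descriptor sets is mutual nearest-neighbor search. The sketch below illustrates that idea only and is not necessarily what the function does. It assumes L2-normalized descriptors stored as `(N, D)` arrays, and the name `mutual_nearest_neighbors` is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mutual_nearest_neighbors(des1, des2):\n",
    "    # Cosine similarity between every pair of (L2-normalized) descriptors.\n",
    "    sim = des1 @ des2.T\n",
    "    nn12 = sim.argmax(axis=1)  # best match in image 2 for each row of des1\n",
    "    nn21 = sim.argmax(axis=0)  # best match in image 1 for each row of des2\n",
    "    idx1 = np.arange(len(des1))\n",
    "    # Keep only the pairs that choose each other.\n",
    "    mutual = nn21[nn12] == idx1\n",
    "    return np.stack([idx1[mutual], nn12[mutual]], axis=1)\n",
    "\n",
    "\n",
    "des1 = np.eye(3)\n",
    "des2 = np.eye(3)[[2, 0, 1]]\n",
    "print(mutual_nearest_neighbors(des1, des2).tolist())  # [[0, 1], [1, 2], [2, 0]]\n",
    "```\n",
    "\n",
    "Each returned row is a pair of indices `(i, j)` such that descriptor `i` in the first image and descriptor `j` in the second image are each other's nearest neighbor."
   ]
  },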
  {
   "cell_type": "markdown",
   "id": "954164a8-be4d-402c-baa1-c9cb2b1b580b",
   "metadata": {},
   "source": [
    "### Show keypoint detection in target images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269ec972-9e33-4ab0-9f46-1f7d67859b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(target_images_with_kps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83bd914e-c503-4d0e-9af0-e777a37ec07a",
   "metadata": {},
   "outputs": [],
   "source": [
    "images = list(target_images_with_kps.values())\n",
    "subtitles = list(target_images_with_kps.keys())\n",
    "subtitles = [x.split(\"(\")[0] for x in subtitles]\n",
    "\n",
    "use_titles = False\n",
    "if not use_titles:\n",
    "    subtitles = None\n",
    "\n",
    "show_grid_of_images(\n",
    "    images,\n",
    "    n_cols=len(model_ckpt_paths), figsize=(20, 5),\n",
    "    subtitles=subtitles,\n",
    "    subtitlesize=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc61d625-b772-4886-b081-42234e3846d0",
   "metadata": {},
   "source": [
    "### Show combined results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe175324-6285-424a-b6be-70a971670ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(len(model_ckpt_paths), 1, figsize=(20, 20))\n",
    "\n",
    "for i, model_name in enumerate(list(model_ckpt_paths.keys())):\n",
    "    ax = axes[i]\n",
    "    source = Image.fromarray(source_images_with_kps[model_name])\n",
    "    target = Image.fromarray(target_images_with_kps[model_name])\n",
    "    matchd = matched_images[model_name]\n",
    "    img = get_concat_h(get_concat_h(source, target), matchd)\n",
    "    ax.imshow(np.asarray(img))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    title = f\"{model_name} at {rotation}$^\\circ$\"\n",
    "    ax.set_title(title, fontsize=20)\n",
    "    \n",
    "fig.tight_layout()\n",
    "plt.savefig(f\"../Figures/qual_results_rotation_{sequence_name}_{rotation}.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f751646-f904-47da-8beb-28e07fb4507d",
   "metadata": {},
   "outputs": [],
   "source": [
    "Image.fromarray(source_images_with_kps[\"R2D2\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98339e14-acd1-4737-8d98-3a83779e9e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "Image.fromarray(target_images_with_kps[\"R2D2\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7899ff7b-e7b4-48b4-851a-3ff82c184b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "matched_images[\"R2D2\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
