{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cafc7b69-b27d-40f0-986a-ac44043a1522",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce6708cd-01c1-447e-bb46-3d6a3f191b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join, basename\n",
    "\n",
    "import urllib.request\n",
    "import numpy as np\n",
    "import cv2\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dacb7ff8-c867-4ed7-a786-d9582e34d5e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lib.r2d2.extract import extract_keypoints\n",
    "\n",
    "from relfm.utils.visualize import show_single_image, show_grid_of_images, show_keypoint_matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0921ffe-db16-4e39-aa31-1c394958a7fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_file(url, path):\n",
    "    \"\"\"Downloads file at URL `url` to local path `path`.\"\"\"\n",
    "    urllib.request.urlretrieve(url, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "190c446c-8b84-4281-8d4d-4c1940c2cee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download model checkpoint\n",
    "\n",
    "URL = \"https://github.com/naver/r2d2/raw/master/models/r2d2_WASF_N16.pt\"\n",
    "\n",
    "model_dir = \"../checkpoints/\"\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "path = join(model_dir, basename(URL))\n",
    "download_file(URL, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3ba5c6-abff-4023-bcb8-955182a71368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check downloaded file\n",
    "# ckpt = np.load(path, allow_pickle=True)\n",
    "ckpt = torch.load(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21da0ebc-8820-46d8-87b7-d42e2fac543a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the top-level entries stored in the loaded checkpoint\n",
    "ckpt.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2beff9cf-6ffa-455b-a226-11bef7901f5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "img1_path = \"../sample_images/house1.png\"\n",
    "img2_path = \"../sample_images/house2_rotated.png\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "123f4624-8722-4989-a356-c133c86c76c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = {\n",
    "    \"model\": path,\n",
    "    \"images\": [img1_path, img2_path],\n",
    "    \"tag\": \"r2d2\",\n",
    "    \"top_k\": 5000,\n",
    "    \"scale_f\": 2**0.25,\n",
    "    \"min_size\": 256,\n",
    "    \"max_size\": 1024,\n",
    "    \"min_scale\": 0,\n",
    "    \"max_scale\": 1,\n",
    "    \"reliability_thr\": 0.7,\n",
    "    \"repeatability_thr\": 0.7,\n",
    "    \"gpu\": -1,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed8e9999-d211-4705-a8c9-4b693d8468bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class dotdict(dict):\n",
    "    \"\"\"dot.notation access to dictionary attributes\"\"\"\n",
    "    __getattr__ = dict.get\n",
    "    __setattr__ = dict.__setitem__\n",
    "    __delattr__ = dict.__delitem__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1e3b887-b8fd-447f-847a-8c37bcf8d5d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# allow attribute access to the options (e.g. args.tag) as well as args[\"tag\"]\n",
    "args = dotdict(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e3190e0-8eb3-4ec4-87fd-026652164981",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run R2D2 on each image in args.images; the outputs are presumably saved to\n",
    "# `<image path>.<args.tag>` files, which the next section loads - confirm in lib.r2d2.extract\n",
    "extract_keypoints(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0289320-d6ad-4285-ae4f-ba2dec9c0677",
   "metadata": {},
   "source": [
    "### Visualize keypoints and matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6c21b7-7bf5-4a04-ac43-dc660e6d5081",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_rgb_image(image_path):\n",
    "    \"\"\"Read an image with OpenCV and return it in RGB channel order.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if OpenCV cannot read `image_path` (cv2.imread\n",
    "            returns None in that case instead of raising).\n",
    "    \"\"\"\n",
    "    img_bgr = cv2.imread(image_path)\n",
    "    if img_bgr is None:\n",
    "        raise FileNotFoundError(f\"Could not read image: {image_path}\")\n",
    "    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "\n",
    "img1 = load_rgb_image(img1_path)\n",
    "img2 = load_rgb_image(img2_path)\n",
    "\n",
    "# per-image outputs, presumably written by extract_keypoints above\n",
    "img1_outputs = np.load(img1_path + \".\" + args.tag)\n",
    "img2_outputs = np.load(img2_path + \".\" + args.tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db92e8fe-ddc7-422a-bdcb-a2e65f186f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the two input images side by side\n",
    "show_grid_of_images(\n",
    "    [img1, img2],\n",
    "    n_cols=2,\n",
    "    subtitles=[\"Image 1\", \"Image 2\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c62993-7ee0-4d4c-b6e4-200da24ae58e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_cv_keypoints(coords):\n",
    "    \"\"\"Convert rows of keypoint coordinates into a list of cv2.KeyPoint.\n",
    "\n",
    "    Each row is unpacked positionally into cv2.KeyPoint(x, y, size, ...);\n",
    "    the rows are presumably (x, y, scale) as produced by R2D2 - TODO confirm.\n",
    "    \"\"\"\n",
    "    return [cv2.KeyPoint(*coord) for coord in coords]\n",
    "\n",
    "\n",
    "# the images were converted to RGB above, so this colour draws red markers\n",
    "KEYPOINT_COLOR = (255, 0, 0)\n",
    "\n",
    "kp1 = to_cv_keypoints(img1_outputs[\"keypoints\"])\n",
    "img1_with_kps = cv2.drawKeypoints(img1, kp1, np.array([]), KEYPOINT_COLOR)\n",
    "\n",
    "kp2 = to_cv_keypoints(img2_outputs[\"keypoints\"])\n",
    "img2_with_kps = cv2.drawKeypoints(img2, kp2, np.array([]), KEYPOINT_COLOR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0554931-369f-4b64-80ee-bd32627cfdef",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_grid_of_images([img1_with_kps, img2_with_kps], n_cols=2, subtitles=[\"Image 1\", \"Image 2\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e2a4aef-fed6-48d6-b339-b4729225f57e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mnn_matcher(descriptors_a, descriptors_b):\n",
    "    \"\"\"\n",
    "    Borrowed from: D2Net's HPatches benchmark notebook.\n",
    "    Link: https://github.com/mihaidusmanu/d2-net/blob/master/hpatches_sequences/HPatches-Sequences-Matching-Benchmark.ipynb\n",
    "    \"\"\"\n",
    "    device = descriptors_a.device\n",
    "    sim = descriptors_a @ descriptors_b.t()\n",
    "    nn12 = torch.max(sim, dim=1)[1]\n",
    "    nn21 = torch.max(sim, dim=0)[1]\n",
    "    ids1 = torch.arange(0, sim.shape[0], device=device)\n",
    "    mask = (ids1 == nn21[nn12])\n",
    "    matches = torch.stack([ids1[mask], nn12[mask]])\n",
    "    dist = 1- sim[[matches[0], matches[1]]]\n",
    "    return dist.data.cpu().numpy(), matches.t().data.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dc80695-26da-4138-9f73-94559c8e626b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# find matches and show them\n",
    "\n",
    "# create BFMatcher object\n",
    "bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)\n",
    "\n",
    "# Match descriptors.\n",
    "des1 = img1_outputs[\"descriptors\"]\n",
    "des2 = img2_outputs[\"descriptors\"]\n",
    "dist, _matches = mnn_matcher(torch.from_numpy(des1), torch.from_numpy(des2))\n",
    "\n",
    "# sort based on distances\n",
    "order = np.argsort(dist)\n",
    "_matches = _matches[order]\n",
    "dist = dist[order]\n",
    "\n",
    "matches = []\n",
    "for (x, y), d in zip(_matches, np.round(dist, 1)):\n",
    "    matches.append(cv2.DMatch(_queryIdx=x, _trainIdx=y, _distance=d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbeabe5a-9b34-468c-a866-a6fef32f9b62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw the matched keypoints across the two images\n",
    "matches_img = show_keypoint_matches(\n",
    "    img1, kp1,\n",
    "    img2, kp2,\n",
    "    matches,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e4b26b-93aa-4a13-8117-4076776c8add",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
