{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dd785c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pathlib import Path\n",
    "import sys\n",
    "import os\n",
    "import ipywidgets as widgets\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "from basicsr.models import create_model\n",
    "from basicsr.utils import tensor2img\n",
    "from basicsr.utils.options import parse\n",
    "from basicsr.data import create_dataset\n",
    "import matplotlib.pyplot as plt\n",
    "from basicsr.metrics import calculate_psnr\n",
    "from basicsr.train import parse_options\n",
    "\n",
    "from basicsr.utils import (get_env_info, get_root_logger, get_time_str,\n",
    "                           make_exp_dirs)\n",
    "from os import path as osp\n",
    "import logging\n",
    "from basicsr.utils.options import dict2str\n",
    "import matplotlib.image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f6c2eae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(yml_path, model_path):\n",
    "    \"\"\"Create a model for inference from an options file and a checkpoint path.\n",
    "\n",
    "    Args:\n",
    "        yml_path: path of the YAML options file handed to ``parse``.\n",
    "        model_path: value stored under ``opt['path']['pretrain_network_g']``.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``create_model`` builds from the adjusted options.\n",
    "    \"\"\"\n",
    "    options = parse(Path(yml_path))  # your *training* options file\n",
    "\n",
    "    # Force a non-distributed, non-training setup and point at the checkpoint.\n",
    "    options['dist'] = False\n",
    "    options['path']['pretrain_network_g'] = model_path\n",
    "    options['is_train'] = False\n",
    "\n",
    "    return create_model(options)\n",
    " \n",
    "\n",
    "def draw_stuff(model, dataset, i, name):\n",
    "    \"\"\"Plot the noisy input, the network output and the ground truth side by side.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing the network to run as ``model.net_g``.\n",
    "        dataset: indexable collection whose items provide 'lq' and 'gt' entries;\n",
    "            a batch dimension is added to each with ``unsqueeze(0)``.\n",
    "        i: index of the sample to visualise.\n",
    "        name: label for the model; accepted for callers but not used in the figure.\n",
    "\n",
    "    The PSNR between the output and the ground truth (no border cropped,\n",
    "    rounded to two decimals) is shown in the title of the middle panel.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(16, 6))\n",
    "    for ax in axs.flatten():\n",
    "        ax.axis('off')\n",
    "\n",
    "    # Read the sample once instead of indexing the dataset separately per key.\n",
    "    sample = dataset[i]\n",
    "    lq = sample['lq'].unsqueeze(0)\n",
    "    gt = sample['gt'].unsqueeze(0)\n",
    "\n",
    "    # Run the input on the device that holds the weights; a network without\n",
    "    # parameters simply keeps the input where it is.\n",
    "    net = model.net_g\n",
    "    first_param = next(net.parameters(), None)\n",
    "    device = lq.device if first_param is None else first_param.device\n",
    "\n",
    "    # Inference only: eval mode and no autograd graph. The previous mode is\n",
    "    # restored afterwards so the caller's model is left as it was found.\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            y = net(lq.to(device)).cpu()\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "\n",
    "    psnr = np.round(calculate_psnr(y, gt, crop_border=0), 2)\n",
    "\n",
    "    axs[0].imshow(tensor2img(lq, rgb2bgr=False))\n",
    "    axs[0].set_title('noisy')\n",
    "\n",
    "    axs[1].imshow(tensor2img(y, rgb2bgr=False))\n",
    "    axs[1].set_title(f'output, PSNR={psnr}')\n",
    "\n",
    "    axs[2].imshow(tensor2img(gt, rgb2bgr=False))\n",
    "    axs[2].set_title('gt')\n",
    "\n",
    "    plt.suptitle(f'SIDD_val img {i}', fontsize=14)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8e95a4c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-18 13:40:54,016 INFO: Dataset PairedImageDataset - SIDD is created.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ecbb9aa2f6f4f2b95b6aefb16ab81e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=639, description='i', max=1279), Dropdown(description='version', options…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# load dataset\n",
    "\n",
    "# SIDD validation pairs: ground-truth and input crops, both read from LMDB.\n",
    "dataset = create_dataset({\n",
    "    'name': 'SIDD',\n",
    "    'type': 'PairedImageDataset',\n",
    "    'dataroot_gt': '/CascadedGaze/datasets/SIDD/val/gt_crops.lmdb',\n",
    "    'dataroot_lq': '/CascadedGaze/datasets/SIDD/val/input_crops.lmdb',\n",
    "    'filename_tmpl': '{}',\n",
    "    'io_backend': {'type': 'lmdb'},\n",
    "    'scale': None,\n",
    "    'phase': None,\n",
    "})\n",
    "\n",
    "# view dataset\n",
    "@widgets.interact\n",
    "def f(i=(0, len(dataset)-1), version=['lq', 'gt']):\n",
    "    \"\"\"Show entry ``version`` ('lq' or 'gt') of sample ``i`` from ``dataset``.\"\"\"\n",
    "    img = tensor2img(dataset[i][version], rgb2bgr=False)\n",
    "    plt.imshow(img)\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f9db1889",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index into `dataset` of the sample that is passed to draw_stuff below.\n",
    "i = 541"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13204006",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Options file and checkpoint used to build the model for the comparison figure.\n",
    "yml_path = \"/CascadedGaze/options/test/SIDD/CascadedGaze-SIDD.yml\"\n",
    "# NOTE(review): this looks like a placeholder - replace it with the path of the\n",
    "# trained checkpoint before running the cell.\n",
    "model_path = \"trained model path\"\n",
    "name = \"CascadedGaze\"\n",
    "# Build the model, then plot input / output / ground truth for sample `i`.\n",
    "model = load_model(yml_path, model_path)\n",
    "draw_stuff(model, dataset, i, name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
