{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3020c373-5e1d-4d0f-8ddc-03eed5d6897f",
   "metadata": {},
   "source": [
    "# OOD Detection\n",
    "\n",
    "This set of experiments evaluates a `ResNet50` backbone, pretrained on ImageNet with counterfactual (CF) data, on out-of-distribution (OOD) benchmarks such as `ImageNet Sketch`. If the findings are promising, the evaluation can be extended to other datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1705ee2d-5a18-4459-b2cb-549ca8f0fd7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell executes, so edits\n",
    "# to the project helper packages are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "daec1f12-2dd6-4d56-9f73-0792036c4089",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from os.path import join, basename, dirname\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from natsort import natsorted\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "import torch\n",
    "import torchvision\n",
    "from torchvision.io import read_image\n",
    "from torchvision.utils import make_grid\n",
    "from torchvision import transforms\n",
    "\n",
    "from gradcam.utils import visualize_cam\n",
    "from gradcam import GradCAM, GradCAMpp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e4876252-32f4-4a55-9566-0e63a52b6fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project bootstrap helpers from experiment_utils.\n",
    "# NOTE(review): set_env() presumably configures sys.path / the working directory so\n",
    "# that the later `experiments.*` and `imagenet.*` imports and the relative\n",
    "# checkpoint path resolve -- confirm in experiment_utils.\n",
    "from experiment_utils import set_env, REPO_PATH, seed_everything\n",
    "set_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "a30083ca-a5fb-407e-8cc3-cd2b9da90a17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.image_utils import denormalize, show_single_image\n",
    "from experiments.imagenet_utils import IMModel, AverageEnsembleModel\n",
    "from experiments.ood_utils import validate as ood_validate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d03662f2-099c-4d6a-b086-5c22b8dbb10a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Typeset all figure text with LaTeX in a Computer Modern serif face.\n",
    "# (usetex rendering needs a working LaTeX installation on the machine.)\n",
    "plt.rcParams[\"text.usetex\"] = True\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.serif\"] = [\"Computer Modern Roman\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "856e9c98-3203-48e3-a5e8-f9ea82566fa9",
   "metadata": {},
   "source": [
    "### Set environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "201421c4-86df-4d01-85dc-8d142f78765a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed through the project helper (seed value 0) so repeated runs are comparable;\n",
    "# which RNGs it covers is defined in experiment_utils.\n",
    "seed_everything(0)\n",
    "\n",
    "# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda') if use_cuda else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d07a4a21-e2c8-44f1-8c11-2d73cc5b4e8a",
   "metadata": {},
   "source": [
    "### Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ee03a674-254a-4c16-9b9b-1c3985d8a420",
   "metadata": {},
   "outputs": [],
   "source": [
    "from imagenet.models.classifier_ensemble import InvariantEnsemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2ca30391-f793-44d9-bb0f-05b89807f637",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the ensemble classifier for the \"resnet50\" architecture, requesting\n",
    "# pretrained weights; the checkpoint loaded below overwrites its state dict.\n",
    "model = InvariantEnsemble(\"resnet50\", pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "135e5d40-e67b-46c6-bc06-6afcfe500eda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load weights from a checkpoint\n",
    "# The path is relative, so it is resolved against the current working directory.\n",
    "# NOTE: torch.load unpickles the file -- only point this at checkpoints you trust.\n",
    "ckpt_path = \"imagenet/experiments/classifier_2022_01_19_15_36_sample_run/model_best.pth\"\n",
    "ckpt = torch.load(ckpt_path, map_location=\"cpu\")  # deserialize onto the CPU regardless of where it was saved\n",
    "\n",
    "# Checkpoint keys may carry a leading \"module.\" (presumably added by a\n",
    "# DataParallel-style wrapper during training -- confirm against the training script).\n",
    "# Strip it only as a prefix: a blanket str.replace would also mangle keys that\n",
    "# merely contain \"module.\" somewhere else, e.g. \"submodule.weight\".\n",
    "wrapper_prefix = \"module.\"\n",
    "ckpt_state_dict = {\n",
    "    (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v\n",
    "    for k, v in ckpt[\"state_dict\"].items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "77bc8dc4-dc84-44a4-8f53-d5878078604c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Copy the checkpoint tensors into the model. The call's return value is the\n",
    "# cell output, so the key-matching report is displayed below.\n",
    "model.load_state_dict(ckpt_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2d5aac87-4ea8-4797-bea2-f66beb9c554a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to evaluation mode (affects layers such as dropout and batch norm).\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "99eab105-0797-46f7-969f-516f7904cbcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the device selected earlier in the notebook.\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "73d65d99-d245-47f5-bb53-e0ae761386db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two wrappers around the loaded ensemble (their semantics live in\n",
    "# experiments.imagenet_utils) plus a stock torchvision ResNet-50 as the baseline.\n",
    "shape_model = IMModel(base_model=model, mode=\"shape\")\n",
    "avg_model = AverageEnsembleModel(base_model=model)\n",
    "\n",
    "# A freshly constructed torchvision model is in training mode and on the CPU.\n",
    "# Put the baseline in eval mode (so batch norm uses its stored statistics) and on\n",
    "# the same device as `model`, so all three models are evaluated under the same conditions.\n",
    "pytorch_model = torchvision.models.resnet50(pretrained=True)\n",
    "pytorch_model = pytorch_model.eval().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "0a997ee1-2ea4-4b89-8c36-343d7b434899",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check models: each one must map a single 3x224x224 input to a (1, 1000) output\n",
    "x = torch.randn((1, 3, 224, 224))\n",
    "\n",
    "# Feed every model an input on the device its weights live on; passing the CPU\n",
    "# tensor directly raises a device-mismatch error once `model` sits on the GPU.\n",
    "ensemble_device = next(model.parameters()).device\n",
    "baseline_device = next(pytorch_model.parameters()).device\n",
    "models_to_check = [\n",
    "    (\"shape_model\", shape_model, ensemble_device),\n",
    "    (\"avg_model\", avg_model, ensemble_device),\n",
    "    (\"pytorch_model\", pytorch_model, baseline_device),\n",
    "]\n",
    "\n",
    "with torch.no_grad():  # shape check only, no autograd graph needed\n",
    "    for name, net, net_device in models_to_check:\n",
    "        y = net(x.to(net_device))\n",
    "        assert y.shape == torch.Size([1, 1000]), f\"{name}: unexpected output shape {tuple(y.shape)}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a59f818e-cb26-48e9-8035-e862ec0b3351",
   "metadata": {},
   "source": [
    "### Download data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ec14455-7225-44d1-a245-2b4bf848d1f5",
   "metadata": {},
   "source": [
    "Use `download_imagenet_sketch.sh` to download the data into the folder `cgn_framework/imagenet/data/sketch/`. The validation loader below reads images from its `val/` subfolder."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7da15043-9286-4fb4-a2a4-62c6bde7e01e",
   "metadata": {},
   "source": [
    "### Define the validation loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ca9a0770-4a52-4dd2-9df4-edbdecad09c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the downloaded sketch data inside the repository, then its validation split.\n",
    "sketch_root = join(REPO_PATH, \"cgn_framework\", \"imagenet\", \"data/sketch\")\n",
    "valdir = join(sketch_root, 'val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4bedbc16-fec7-4a3d-8fe3-be133c96f228",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation preprocessing: resize to 256, center-crop to 224, convert to a tensor,\n",
    "# then normalize each channel with the mean/std commonly used for\n",
    "# ImageNet-pretrained torchvision models.\n",
    "preprocess_steps = [\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(\n",
    "        mean=[0.485, 0.456, 0.406],\n",
    "        std=[0.229, 0.224, 0.225],\n",
    "    ),\n",
    "]\n",
    "combined_transform = transforms.Compose(preprocess_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "202ac501-94cc-4923-99db-d98afc9b0622",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset over the validation folder, with the evaluation preprocessing applied per image.\n",
    "sketch_val_dataset = torchvision.datasets.ImageFolder(valdir, combined_transform)\n",
    "\n",
    "# Fixed-order, full-coverage loading: no shuffling and the last partial batch is kept.\n",
    "loader_options = dict(\n",
    "    batch_size=36,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    "    num_workers=4,\n",
    "    pin_memory=True,\n",
    ")\n",
    "val_loader = torch.utils.data.DataLoader(sketch_val_dataset, **loader_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "eb916755-4e53-4383-9ea7-0453d16ff170",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1414"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of batches in one pass over the validation loader.\n",
    "len(val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae6666cc-9c84-470f-b7d2-dd8eb2ce78f5",
   "metadata": {},
   "source": [
    "### Benchmark performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "070f84a2-beea-46e0-a064-bed5f89ff8e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test: [   0/1414]\tTime 14.957 (14.957)\tAcc@1   5.56 (  5.56)\n",
      "Test: [  10/1414]\tTime  9.909 (10.761)\tAcc@1   5.56 ( 13.64)\n",
      "Test: [  20/1414]\tTime 15.492 (11.546)\tAcc@1   0.00 ( 15.48)\n",
      "Test: [  30/1414]\tTime 11.206 (11.750)\tAcc@1  11.11 ( 18.28)\n",
      "Test: [  40/1414]\tTime  9.406 (11.604)\tAcc@1  11.11 ( 16.73)\n",
      "Test: [  50/1414]\tTime 10.695 (11.611)\tAcc@1  19.44 ( 17.16)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the stock torchvision baseline on the sketch validation loader.\n",
    "# NOTE(review): the saved output shows this run was interrupted after roughly 50 of\n",
    "# 1414 batches at ~11 s per batch, so `acc1` was never assigned in that session.\n",
    "# Re-run to completion (ideally on a GPU) before relying on the number.\n",
    "acc1 = ood_validate(val_loader, pytorch_model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a7b1fb-58b0-4953-9f20-72c76958a03e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
