{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1df8c59f-d75b-4fc3-a6fc-6759f51c9da4",
   "metadata": {
    "tags": []
   },
   "source": [
    "## WeiPer on OpenOOD\n",
    "Set `dataset` (in the model-loading cell) to `cifar10` or `cifar100`, and `postprocessor_name` (in the evaluation cell) to `weiper_density` or `weiper_kldiv`, to evaluate WeiPer on CIFAR-10 or CIFAR-100.\n",
    "## Hyperparameters\n",
    "The hyperparameters are defined, and can be changed, in the postprocessor config files:\n",
    "\n",
    "- `./OpenOOD/configs/postprocessors/weiper_density.yml`\n",
    "- `./OpenOOD/configs/postprocessors/weiper_kldiv.yml`\n",
    "\n",
    "Do not enable **APS_mode**: it has not been optimized for WeiPer+KLD yet and would iterate over all 3600 parameter configurations. For most of these parameters, resampling the perturbations and recomputing the perturbed logits (the most time-consuming part) would not actually be necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "28c3e363-e171-42f8-950c-be25893b0f43",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from openood.networks import ResNet18_32x32\n",
    "import torch\n",
    "from openood.evaluation_api import Evaluator\n",
    "from openood.preprocessors import BasePreprocessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1b68c47d-70bd-4272-9d40-d7e7e49494de",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In-distribution dataset: \"cifar10\" or \"cifar100\".\n",
    "dataset = \"cifar10\"\n",
    "num_classes_by_dataset = {\"cifar10\": 10, \"cifar100\": 100}\n",
    "if dataset not in num_classes_by_dataset:\n",
    "    raise ValueError(f\"dataset must be 'cifar10' or 'cifar100', got {dataset!r}\")\n",
    "\n",
    "checkpoint_path = (\n",
    "    f\"./OpenOOD/openood/results/{dataset}_resnet18_32x32_base_e100_lr0.1_default/s0/best.ckpt\"\n",
    ")\n",
    "model = ResNet18_32x32(num_classes=num_classes_by_dataset[dataset])\n",
    "# map_location=\"cpu\" makes loading independent of the device the checkpoint was saved on;\n",
    "# load_state_dict copies the weights into the model's own parameters either way.\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location=\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "df631571-147d-4530-b63b-a700838379ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Postprocessor to evaluate: \"weiper_kldiv\" or \"weiper_density\"; its hyperparameters\n",
    "# are read from ./OpenOOD/configs/postprocessors/<postprocessor_name>.yml.\n",
    "postprocessor_name = \"weiper_kldiv\"\n",
    "# Value written to the postprocessor's `n_repeats` below (100 for CIFAR-10, 30 for CIFAR-100);\n",
    "# presumably the number of sampled weight perturbations -- TODO confirm in the WeiPer postprocessor.\n",
    "n_repeats = 100 if dataset == \"cifar10\" else 30\n",
    "\n",
    "model.to(\"cuda:0\")\n",
    "model.eval()\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    model,\n",
    "    id_name=dataset,\n",
    "    data_root=\"./OpenOOD/data/\",\n",
    "    config_root=\"./OpenOOD/configs/\",\n",
    "    preprocessor=None,\n",
    "    postprocessor_name=postprocessor_name,\n",
    "    verbose=True,\n",
    "    # APS_mode would iterate over every hyperparameter configuration; keep it off.\n",
    "    APS_mode=False,\n",
    ")\n",
    "# NOTE(review): n_repeats is overridden only after the Evaluator is built; confirm the\n",
    "# postprocessor does not already use it while the Evaluator is being constructed.\n",
    "evaluator.postprocessor.n_repeats = n_repeats\n",
    "\n",
    "ood_metrics = evaluator.eval_ood()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "19ef80c3-a8ee-4e02-bfe8-ccb9087d662d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>FPR@95</th>\n",
       "      <th>AUROC</th>\n",
       "      <th>AUPR_IN</th>\n",
       "      <th>AUPR_OUT</th>\n",
       "      <th>ACC</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar100</th>\n",
       "      <td>36.955556</td>\n",
       "      <td>89.570943</td>\n",
       "      <td>87.019529</td>\n",
       "      <td>90.616275</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>tin</th>\n",
       "      <td>32.200000</td>\n",
       "      <td>90.976520</td>\n",
       "      <td>86.827113</td>\n",
       "      <td>92.821833</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nearood</th>\n",
       "      <td>34.577778</td>\n",
       "      <td>90.273732</td>\n",
       "      <td>86.923321</td>\n",
       "      <td>91.719054</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mnist</th>\n",
       "      <td>24.422222</td>\n",
       "      <td>93.487747</td>\n",
       "      <td>98.646319</td>\n",
       "      <td>77.261263</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>svhn</th>\n",
       "      <td>22.688889</td>\n",
       "      <td>92.838046</td>\n",
       "      <td>96.042819</td>\n",
       "      <td>88.824248</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>texture</th>\n",
       "      <td>26.477778</td>\n",
       "      <td>92.271205</td>\n",
       "      <td>84.746163</td>\n",
       "      <td>95.567147</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>places365</th>\n",
       "      <td>30.911111</td>\n",
       "      <td>90.584067</td>\n",
       "      <td>96.229039</td>\n",
       "      <td>80.682955</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>farood</th>\n",
       "      <td>26.125000</td>\n",
       "      <td>92.295266</td>\n",
       "      <td>93.916085</td>\n",
       "      <td>85.583904</td>\n",
       "      <td>95.222222</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              FPR@95      AUROC    AUPR_IN   AUPR_OUT        ACC\n",
       "cifar100   36.955556  89.570943  87.019529  90.616275  95.222222\n",
       "tin        32.200000  90.976520  86.827113  92.821833  95.222222\n",
       "nearood    34.577778  90.273732  86.923321  91.719054  95.222222\n",
       "mnist      24.422222  93.487747  98.646319  77.261263  95.222222\n",
       "svhn       22.688889  92.838046  96.042819  88.824248  95.222222\n",
       "texture    26.477778  92.271205  84.746163  95.567147  95.222222\n",
       "places365  30.911111  90.584067  96.229039  80.682955  95.222222\n",
       "farood     26.125000  92.295266  93.916085  85.583904  95.222222"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# OOD metrics recorded by the evaluator: one row per OOD dataset plus near-/far-OOD aggregates.\n",
    "ood_metrics_table = evaluator.metrics[\"ood\"]\n",
    "ood_metrics_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1a7a41-ccbc-4d3f-8143-52f32621eaaa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a507dae-4200-4edf-a33f-c233c7b64360",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
