{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e533e968-3e3c-4184-95e5-459c8e2933c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nkotelevskii/github/uncertainty_from_proper_scoring_rules/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "stty: 'standard input': Inappropriate ioctl for device\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import re\n",
    "import numpy as np\n",
    "\n",
    "sys.path.insert(0, \"./src/\")\n",
    "\n",
    "import pandas as pd\n",
    "from src.table_utils import (\n",
    "    collect_scores_into_dict,\n",
    "    extract_same_different_dataframes,\n",
    "    ood_detection_pairs_,\n",
    "    aggregate_over_measures,\n",
    ")\n",
    "from IPython.display import display\n",
    "\n",
    "pd.set_option(\"display.max_rows\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb9b9384-7b35-46c0-bad8-3d3d2fc15efe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "61d76aa3-8bac-46c8-b6f5-41e9f5941070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the OOD ROC AUC table and keep only the rows whose evaluation dataset\n",
    "# differs from the training dataset.\n",
    "full_ood_rocauc = pd.read_csv(\"./tables/full_ood_rocauc.csv\", index_col=0).loc[\n",
    "    lambda scores: scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b372ce15-ad8d-4fb3-9329-4778d4b29e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discard every row whose metric name ends with \"Inner Inner\".\n",
    "full_ood_rocauc = full_ood_rocauc.loc[\n",
    "    lambda scores: ~scores[\"UQMetric\"].str.endswith(\"Inner Inner\")\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "44d3beff-622a-48ad-9143-643778876dd3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>Dataset</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>RocAucScore</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "      <th>base_rule</th>\n",
       "      <th>RiskType</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1087</th>\n",
       "      <td>MV Neglog</td>\n",
       "      <td>svhn</td>\n",
       "      <td>Logscore</td>\n",
       "      <td>0.944981</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>Neglog</td>\n",
       "      <td>MV</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>62</th>\n",
       "      <td>Total Spherical Outer</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>Spherical</td>\n",
       "      <td>0.771593</td>\n",
       "      <td>vgg</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>Spherical</td>\n",
       "      <td>Total</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>290</th>\n",
       "      <td>Bayes Spherical Outer</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>Spherical</td>\n",
       "      <td>0.914115</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>missed_class_cifar10</td>\n",
       "      <td>Spherical</td>\n",
       "      <td>Bayes</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>Total Neglog Outer</td>\n",
       "      <td>svhn</td>\n",
       "      <td>Logscore</td>\n",
       "      <td>0.156566</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar10</td>\n",
       "      <td>Neglog</td>\n",
       "      <td>Total</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>941</th>\n",
       "      <td>BiasBI Brier</td>\n",
       "      <td>blurred_cifar100</td>\n",
       "      <td>Spherical</td>\n",
       "      <td>0.913722</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar10</td>\n",
       "      <td>Brier</td>\n",
       "      <td>BiasBI</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>451</th>\n",
       "      <td>Excess Brier Outer Inner</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>Logscore</td>\n",
       "      <td>0.717323</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>Brier</td>\n",
       "      <td>Excess</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>994</th>\n",
       "      <td>BiasBI Maxprob</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>Logscore</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar100</td>\n",
       "      <td>Maxprob</td>\n",
       "      <td>BiasBI</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>Total Neglog Inner</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>Brier</td>\n",
       "      <td>0.806482</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar100</td>\n",
       "      <td>Neglog</td>\n",
       "      <td>Total</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>235</th>\n",
       "      <td>Bayes Maxprob Outer</td>\n",
       "      <td>blurred_cifar100</td>\n",
       "      <td>Logscore</td>\n",
       "      <td>0.743500</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>Maxprob</td>\n",
       "      <td>Bayes</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>483</th>\n",
       "      <td>Excess Maxprob Outer Inner</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>Brier</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar100</td>\n",
       "      <td>Maxprob</td>\n",
       "      <td>Excess</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        UQMetric           Dataset LossFunction  RocAucScore  \\\n",
       "1087                   MV Neglog              svhn     Logscore     0.944981   \n",
       "62         Total Spherical Outer           cifar10    Spherical     0.771593   \n",
       "290        Bayes Spherical Outer          cifar100    Spherical     0.914115   \n",
       "37            Total Neglog Outer              svhn     Logscore     0.156566   \n",
       "941                 BiasBI Brier  blurred_cifar100    Spherical     0.913722   \n",
       "451     Excess Brier Outer Inner           cifar10     Logscore     0.717323   \n",
       "994               BiasBI Maxprob          cifar100     Logscore     0.500000   \n",
       "105           Total Neglog Inner           cifar10        Brier     0.806482   \n",
       "235          Bayes Maxprob Outer  blurred_cifar100     Logscore     0.743500   \n",
       "483   Excess Maxprob Outer Inner          cifar100        Brier     0.500000   \n",
       "\n",
       "     architecture      training_dataset  base_rule RiskType  \n",
       "1087     resnet18               cifar10     Neglog       MV  \n",
       "62            vgg              cifar100  Spherical    Total  \n",
       "290      resnet18  missed_class_cifar10  Spherical    Bayes  \n",
       "37       resnet18         noisy_cifar10     Neglog    Total  \n",
       "941      resnet18         noisy_cifar10      Brier   BiasBI  \n",
       "451      resnet18              cifar100      Brier   Excess  \n",
       "994      resnet18        noisy_cifar100    Maxprob   BiasBI  \n",
       "105      resnet18        noisy_cifar100     Neglog    Total  \n",
       "235      resnet18              cifar100    Maxprob    Bayes  \n",
       "483      resnet18        noisy_cifar100    Maxprob   Excess  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_ood_rocauc.sample(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40adab12-7273-4e50-9886-8da51e017a45",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "525ae843-31b1-4235-8f68-5ae791ebd4ef",
   "metadata": {},
   "source": [
    "# How often is Excess better than Bayes in out-of-distribution detection tasks?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f54b9ab5-b889-476b-8a6a-0955c45f632e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['Total', 'Bayes', 'Excess', 'Bregman Information',\n",
       "       'Reverse Bregman Information',\n",
       "       'Expected Pairwise Bregman Information', 'Bias', 'MV', 'MVBI',\n",
       "       'BiasBI'], dtype=object)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_ood_rocauc.RiskType.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4f9be112-35d4-4338-ad6c-28b9719e436e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment the line below to exclude the Neglog base rule from the analysis.\n",
    "\n",
    "# full_ood_rocauc = full_ood_rocauc[full_ood_rocauc.base_rule != 'Neglog']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ce8fab00-9ce3-412e-ae61-9e2f8ed2ea87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RiskType value compared against Bayes in the cells below. Options:\n",
    "#   \"Bregman Information\"\n",
    "#   \"Reverse Bregman Information\"\n",
    "#   \"Expected Pairwise Bregman Information\"\n",
    "EXCESS_APPROXIMATION = \"Expected Pairwise Bregman Information\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6abcdf3a-73f2-4008-b93f-277e3c2414d6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6793b00-e789-4200-ab17-14689a7e7139",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d9b93bac-87c4-4e0f-9572-7af42a133a7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of the selected Excess approximation on OOD pairs; the metric name and\n",
    "# risk type columns are dropped, leaving the merge keys and the score.\n",
    "excess_ood_scores = full_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(EXCESS_APPROXIMATION)\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "32438d42-13ed-4fe7-88ce-f0406467550e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bayes rows on OOD pairs whose metric name ends with \"Inner\".\n",
    "bayes_inner_ood_scores = full_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(\"Bayes\")\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "    & scores[\"UQMetric\"].str.endswith(\"Inner\")\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])\n",
    "\n",
    "# Bayes rows on OOD pairs whose metric name ends with \"Outer\".\n",
    "bayes_outer_ood_scores = full_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(\"Bayes\")\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "    & scores[\"UQMetric\"].str.endswith(\"Outer\")\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "202d61f9-0ba9-4534-9bf8-531d0c314a39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge keys: every remaining column except the score itself.\n",
    "merge_columns = [\n",
    "    column for column in bayes_outer_ood_scores.columns if column != \"RocAucScore\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c96c257c-7177-4c86-86cf-8655c5b15c40",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Dataset', 'LossFunction', 'architecture', 'training_dataset', 'base_rule']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merge_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28abf2f9-0687-4f0b-a02c-3f5df3a23001",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7c1d54fc-a843-4cfa-9808-3d10c1df19b0",
   "metadata": {},
   "source": [
    "### Excess is better than Bayes Inner:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8945058b-7d30-4009-b8a8-ff9be3379a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each Excess score with the Bayes Inner score sharing the same merge keys.\n",
    "merged_tab_ = pd.merge(\n",
    "    excess_ood_scores,\n",
    "    bayes_inner_ood_scores,\n",
    "    on=merge_columns,\n",
    "    suffixes=(\"Excess\", \"Bayes_Inner\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "71478ed0-a382-46c0-8191-46009e1fccaf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.27956989247311825\n"
     ]
    }
   ],
   "source": [
    "# Share of paired rows in which Excess has the strictly higher ROC AUC.\n",
    "merged_tab_[\"compare_res\"] = merged_tab_[\"RocAucScoreExcess\"].gt(\n",
    "    merged_tab_[\"RocAucScoreBayes_Inner\"]\n",
    ")\n",
    "print(merged_tab_[\"compare_res\"].mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d7e9c34-0bde-44a4-9109-7ee0bf9ff6b5",
   "metadata": {},
   "source": [
    "### Excess is better than Bayes Outer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d79ec61b-fc92-4ce4-8bf3-95b7861a13e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each Excess score with the Bayes Outer score sharing the same merge keys.\n",
    "merged_tab_ = pd.merge(\n",
    "    excess_ood_scores,\n",
    "    bayes_outer_ood_scores,\n",
    "    on=merge_columns,\n",
    "    suffixes=(\"Excess\", \"Bayes_Outer\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "86af362b-5046-4154-ace5-6b84741f7e00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.34408602150537637\n"
     ]
    }
   ],
   "source": [
    "# Share of paired rows in which Excess has the strictly higher ROC AUC.\n",
    "merged_tab_[\"compare_res\"] = merged_tab_[\"RocAucScoreExcess\"].gt(\n",
    "    merged_tab_[\"RocAucScoreBayes_Outer\"]\n",
    ")\n",
    "print(merged_tab_[\"compare_res\"].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f307e32-66bd-4b11-99ed-d65c9f45fcd8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "031be319-b45b-46de-af5f-416093365884",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7d9398b7-d2bc-4de7-b84c-eb49081f9d95",
   "metadata": {},
   "source": [
    "## Only soft-OOD scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f0f321e0-d5fd-441c-aef2-daaf8b9e6af3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Soft-OOD pairs: training on cifar10 / cifar100 and evaluating on the\n",
    "# blurred_* dataset of the same name. Plain equality is used because the\n",
    "# names are literals, not regular-expression patterns.\n",
    "soft_ood_rocauc = full_ood_rocauc[\n",
    "    (\n",
    "        (full_ood_rocauc.training_dataset == \"cifar10\")\n",
    "        & (full_ood_rocauc.Dataset == \"blurred_cifar10\")\n",
    "    )\n",
    "    | (\n",
    "        (full_ood_rocauc.training_dataset == \"cifar100\")\n",
    "        & (full_ood_rocauc.Dataset == \"blurred_cifar100\")\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1abfff93-7590-4925-8f96-67e5cb386540",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of the selected Excess approximation within the soft-OOD subset; the\n",
    "# metric name and risk type columns are dropped.\n",
    "excess_ood_scores = soft_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(EXCESS_APPROXIMATION)\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9916fda9-29bd-4960-8817-c279a1b4e9db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Soft-OOD Bayes rows whose metric name ends with \"Inner\".\n",
    "bayes_inner_ood_scores = soft_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(\"Bayes\")\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "    & scores[\"UQMetric\"].str.endswith(\"Inner\")\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])\n",
    "\n",
    "# Soft-OOD Bayes rows whose metric name ends with \"Outer\".\n",
    "bayes_outer_ood_scores = soft_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(\"Bayes\")\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "    & scores[\"UQMetric\"].str.endswith(\"Outer\")\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b6e9548c-601a-4898-ad94-3107f5bbe5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge keys: every remaining column except the score itself.\n",
    "merge_columns = [\n",
    "    column for column in bayes_outer_ood_scores.columns if column != \"RocAucScore\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "edfe1d99-89f5-4c01-a542-cc69d710341c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Dataset', 'LossFunction', 'architecture', 'training_dataset', 'base_rule']"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merge_columns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba7fa053-749a-42a7-b8ee-ef29c4796621",
   "metadata": {},
   "source": [
    "### Excess is better than Bayes Inner:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bb1534b1-8045-4421-8fb2-222608649382",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each Excess score with the Bayes Inner score sharing the same merge keys.\n",
    "merged_tab_ = pd.merge(\n",
    "    excess_ood_scores,\n",
    "    bayes_inner_ood_scores,\n",
    "    on=merge_columns,\n",
    "    suffixes=(\"Excess\", \"Bayes_Inner\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a7be31cd-a085-4cb8-ac21-ee2a0c6e3544",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8333333333333334\n"
     ]
    }
   ],
   "source": [
    "# Share of paired rows in which Excess has the strictly higher ROC AUC.\n",
    "merged_tab_[\"compare_res\"] = merged_tab_[\"RocAucScoreExcess\"].gt(\n",
    "    merged_tab_[\"RocAucScoreBayes_Inner\"]\n",
    ")\n",
    "print(merged_tab_[\"compare_res\"].mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fe2af2d-4293-447a-a911-9354e7d0419e",
   "metadata": {},
   "source": [
    "### Excess is better than Bayes Outer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "12b6fb61-3a0b-48df-8c87-7d32aeba4804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each Excess score with the Bayes Outer score sharing the same merge keys.\n",
    "merged_tab_ = pd.merge(\n",
    "    excess_ood_scores,\n",
    "    bayes_outer_ood_scores,\n",
    "    on=merge_columns,\n",
    "    suffixes=(\"Excess\", \"Bayes_Outer\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "876d90ab-c917-4e10-a119-590e77feb2c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8833333333333333\n"
     ]
    }
   ],
   "source": [
    "# Share of paired rows in which Excess has the strictly higher ROC AUC.\n",
    "merged_tab_[\"compare_res\"] = merged_tab_[\"RocAucScoreExcess\"].gt(\n",
    "    merged_tab_[\"RocAucScoreBayes_Outer\"]\n",
    ")\n",
    "print(merged_tab_[\"compare_res\"].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02964c4f-a5f6-43d2-b2d7-f45b51f1a13e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc6165f8-923f-4580-9da4-6c248ab6757e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1f317cc8-efc5-4ae2-b249-9b5e9ba87678",
   "metadata": {},
   "source": [
    "## Only hard-OOD scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "78cd9ad1-8950-478b-b74a-ecc347d4ab5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-OOD pairs: every row that is NOT a soft-OOD pair, i.e. not\n",
    "# (cifar10 -> blurred_cifar10) and not (cifar100 -> blurred_cifar100).\n",
    "# Plain equality is used because the names are literals, not\n",
    "# regular-expression patterns.\n",
    "hard_ood_rocauc = full_ood_rocauc[\n",
    "    ~(\n",
    "        (\n",
    "            (full_ood_rocauc.training_dataset == \"cifar10\")\n",
    "            & (full_ood_rocauc.Dataset == \"blurred_cifar10\")\n",
    "        )\n",
    "        | (\n",
    "            (full_ood_rocauc.training_dataset == \"cifar100\")\n",
    "            & (full_ood_rocauc.Dataset == \"blurred_cifar100\")\n",
    "        )\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "fd241f76-26be-41aa-890b-6dfbb6f15ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of the selected Excess approximation within the hard-OOD subset; the\n",
    "# metric name and risk type columns are dropped.\n",
    "excess_ood_scores = hard_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(EXCESS_APPROXIMATION)\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "3c9866d7-31e3-48b9-8ba6-e6d280281d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-OOD Bayes rows whose metric name ends with \"Inner\".\n",
    "bayes_inner_ood_scores = hard_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(\"Bayes\")\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "    & scores[\"UQMetric\"].str.endswith(\"Inner\")\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])\n",
    "\n",
    "# Hard-OOD Bayes rows whose metric name ends with \"Outer\".\n",
    "bayes_outer_ood_scores = hard_ood_rocauc.loc[\n",
    "    lambda scores: scores[\"RiskType\"].eq(\"Bayes\")\n",
    "    & scores[\"Dataset\"].ne(scores[\"training_dataset\"])\n",
    "    & scores[\"UQMetric\"].str.endswith(\"Outer\")\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d7efc396-3207-450c-85e5-a34b87b8249e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Dataset', 'LossFunction', 'architecture', 'training_dataset', 'base_rule']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Merge keys: every remaining column except the score itself.\n",
    "merge_columns = [\n",
    "    column for column in bayes_outer_ood_scores.columns if column != \"RocAucScore\"\n",
    "]\n",
    "merge_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "448a5f93-070a-447c-9bcc-b4b3c9ce39d9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "29f9ed61-77f1-4826-8c87-e93dd55af936",
   "metadata": {},
   "source": [
    "### Excess is better than Bayes Inner:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c18a8433-6707-47ba-8f2a-11f8553d6328",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each Excess score with the Bayes Inner score sharing the same merge keys.\n",
    "merged_tab_ = pd.merge(\n",
    "    excess_ood_scores,\n",
    "    bayes_inner_ood_scores,\n",
    "    on=merge_columns,\n",
    "    suffixes=(\"Excess\", \"Bayes_Inner\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "c3f813e7-8805-4d41-befb-e9d604ddc2f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.19753086419753085\n"
     ]
    }
   ],
   "source": [
    "# Share of paired rows in which Excess has the strictly higher ROC AUC.\n",
    "merged_tab_[\"compare_res\"] = merged_tab_[\"RocAucScoreExcess\"].gt(\n",
    "    merged_tab_[\"RocAucScoreBayes_Inner\"]\n",
    ")\n",
    "print(merged_tab_[\"compare_res\"].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d31023e4-37fa-4041-a839-23e7cc41b338",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3906e270-ab48-4e6e-a3cb-21a5f5fb1304",
   "metadata": {},
   "source": [
    "### Excess is better than Bayes Outer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b20a41bc-dc3e-4b07-8688-e84994eb1dd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each Excess score with the Bayes Outer score sharing the same merge keys.\n",
    "merged_tab_ = pd.merge(\n",
    "    excess_ood_scores,\n",
    "    bayes_outer_ood_scores,\n",
    "    on=merge_columns,\n",
    "    suffixes=(\"Excess\", \"Bayes_Outer\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d9e8e154-1a25-493f-aa2e-d8adcc713383",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2641975308641975\n"
     ]
    }
   ],
   "source": [
    "# Share of paired rows in which Excess has the strictly higher ROC AUC.\n",
    "merged_tab_[\"compare_res\"] = merged_tab_[\"RocAucScoreExcess\"].gt(\n",
    "    merged_tab_[\"RocAucScoreBayes_Outer\"]\n",
    ")\n",
    "print(merged_tab_[\"compare_res\"].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82d52636-8713-4f16-a25e-5c84dee30a57",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d8e63e-33b8-4858-b1f0-a8ec5a6bdfd1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8901fc38-29ca-465e-b5cd-92224fb942f8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ca229038-902e-42a9-94aa-808a4ab9fdad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2645161290322581\n",
      "0.26021505376344084\n",
      "0.27956989247311825\n"
     ]
    }
   ],
   "source": [
    "# For each Excess approximation: the share of paired rows (all OOD pairs) in\n",
    "# which it has a strictly higher ROC AUC than Bayes Inner.\n",
    "results = {}\n",
    "\n",
    "# The Bayes Inner baseline does not depend on the approximation, so it is\n",
    "# built once instead of on every iteration.\n",
    "bayes_inner_ood_scores = full_ood_rocauc[\n",
    "    (full_ood_rocauc.RiskType == \"Bayes\")\n",
    "    & (full_ood_rocauc.Dataset != full_ood_rocauc.training_dataset)\n",
    "    & (full_ood_rocauc.UQMetric.str.endswith(\"Inner\"))\n",
    "].drop(columns=[\"UQMetric\", \"RiskType\"])\n",
    "\n",
    "# Merge keys: every remaining column except the score itself.\n",
    "merge_columns = [\n",
    "    column for column in bayes_inner_ood_scores.columns if column != \"RocAucScore\"\n",
    "]\n",
    "\n",
    "# A dedicated loop variable keeps the notebook-level EXCESS_APPROXIMATION\n",
    "# setting from being overwritten by this sweep.\n",
    "for excess_approximation in [\n",
    "    \"Bregman Information\",\n",
    "    \"Reverse Bregman Information\",\n",
    "    \"Expected Pairwise Bregman Information\",\n",
    "]:\n",
    "    excess_ood_scores = full_ood_rocauc[\n",
    "        (full_ood_rocauc.RiskType == excess_approximation)\n",
    "        & (full_ood_rocauc.Dataset != full_ood_rocauc.training_dataset)\n",
    "    ].drop(columns=[\"UQMetric\", \"RiskType\"])\n",
    "\n",
    "    merged_tab_ = excess_ood_scores.merge(\n",
    "        bayes_inner_ood_scores, on=merge_columns, suffixes=[\"Excess\", \"Bayes_Inner\"]\n",
    "    )\n",
    "    merged_tab_[\"compare_res\"] = (\n",
    "        merged_tab_[\"RocAucScoreExcess\"] > merged_tab_[\"RocAucScoreBayes_Inner\"]\n",
    "    )\n",
    "\n",
    "    win_rate = merged_tab_[\"compare_res\"].mean()\n",
    "    print(win_rate)\n",
    "    results[excess_approximation] = win_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b31fa389-4509-4053-bf38-477695efd0fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Bregman Information': 0.2645161290322581,\n",
       " 'Reverse Bregman Information': 0.26021505376344084,\n",
       " 'Expected Pairwise Bregman Information': 0.27956989247311825}"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76290f7e-77d7-45f7-a721-79ca4db42979",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
