{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c45146a1-92c7-4be5-aed9-09c93c0bd469",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "stty: 'standard input': Inappropriate ioctl for device\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from source.source.table_utils import (\n",
    "    collect_scores_into_dict,\n",
    "    extract_same_different_dataframes,\n",
    "    ood_detection_pairs_,\n",
    "    aggregate_over_measures,\n",
    ")\n",
    "from source.source.path_config import REPOSITORY_ROOT\n",
    "from source.metrics.constants import GName\n",
    "from source.losses.constants import LossName\n",
    "from source.datasets.constants import DatasetName\n",
    "from source.models.constants import ModelSource\n",
    "from IPython.display import display\n",
    "\n",
    "pd.options.display.max_rows = None  # never truncate displayed DataFrames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0314bd84-4e4f-4468-b64e-699c8e850c67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which experiment to summarise: in-distribution dataset and model source.\n",
    "model_source = ModelSource.OUR_MODELS.value\n",
    "ind_dataset_name = DatasetName.CIFAR10_NOISY_LABEL.value\n",
    "\n",
    "# For the OOD-detection variant of this analysis, the next cell would instead load\n",
    "# {ind_dataset_name}_{model_source}_full_ood_rocauc.pkl from the same directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8bf4dcbb-eac9-497e-bb11-cc11d12ad156",
   "metadata": {},
   "outputs": [],
   "source": [
    "### FOR MISCLASSIFICATION\n",
    "# The variable keeps the name `full_ood_rocauc` because later cells use it,\n",
    "# but here it holds the misclassification table (`*_full_mis_rocauc.pkl`).\n",
    "full_ood_rocauc = pd.read_pickle(\n",
    "    f\"{REPOSITORY_ROOT}/tables/central_tables/final/{ind_dataset_name}_{model_source}_full_mis_rocauc.pkl\"\n",
    ").assign(Dataset=ind_dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e2405a4a-b04f-45a6-8d90-00c00e5db338",
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_uq_measures = {metric.partition(' ')[2] for metric in full_ood_rocauc.UQMetric.unique()}  # UQMetric minus its first space-separated token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "48dff02d-ea7b-474f-bc37-60ac292def7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "is_central = full_ood_rocauc.UQMetric.str.contains('central', regex=False, na=False)\n",
    "\n",
    "# Central is not defined for the zero-one score.\n",
    "drop_zero_one_central = is_central & (full_ood_rocauc.base_rule == GName.ZERO_ONE_SCORE.value)\n",
    "\n",
    "# Central coincides with inner for Brier, so ExcessRisk rows mentioning both are redundant.\n",
    "drop_brier_central_inner = (\n",
    "    is_central\n",
    "    & full_ood_rocauc.UQMetric.str.contains('inner', regex=False, na=False)\n",
    "    & full_ood_rocauc.UQMetric.str.contains('ExcessRisk', regex=False, na=False)\n",
    "    & (full_ood_rocauc.base_rule == GName.BRIER_SCORE.value)\n",
    ")\n",
    "\n",
    "# Only models trained with the cross-entropy loss are kept.\n",
    "keep_loss = full_ood_rocauc.LossFunction == LossName.CROSS_ENTROPY.value\n",
    "\n",
    "full_ood_rocauc = full_ood_rocauc[~drop_zero_one_central & ~drop_brier_central_inner & keep_loss]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fe79c41d-1c76-40de-93ac-76307ec5d711",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['UQMetric', 'LossFunction', 'RocAucScores_array', 'architecture',\n",
       "       'training_dataset', 'base_rule', 'RiskType', 'Dataset'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_ood_rocauc.keys()  # for a DataFrame, keys() is its column Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b23285b6-5b71-464a-924f-2dce720401d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# full_ood_rocauc.UQMetric.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ed4dfd5-7825-40f9-b551-19919a3ed832",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "29ee3dad-3933-46c4-988d-419787c9bdca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scoring rule that matches each training loss.\n",
    "mapping_loss_to_g = {\n",
    "    LossName.CROSS_ENTROPY.value: GName.LOG_SCORE.value,\n",
    "    LossName.BRIER_SCORE.value: GName.BRIER_SCORE.value,\n",
    "    LossName.SPHERICAL_SCORE.value: GName.SPHERICAL_SCORE.value,\n",
    "}\n",
    "\n",
    "# Rows whose base rule is the scoring rule matching the training loss.\n",
    "# One vectorized mask (no concat in a loop); losses without a mapping map to\n",
    "# NaN and simply never match instead of raising KeyError.\n",
    "df_coinciding = full_ood_rocauc[\n",
    "    full_ood_rocauc.LossFunction.map(mapping_loss_to_g) == full_ood_rocauc.base_rule\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "78d96942-b25c-4bb5-815c-3de41a89e53d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index labels of rows not in df_coinciding, in their original order.\n",
    "# Index.isin is a hashed membership test, so this is linear rather than O(n*m).\n",
    "diff_elements = full_ood_rocauc.index[~full_ood_rocauc.index.isin(df_coinciding.index)].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6956cfb8-e885-4cf8-a2db-b1baba8a4afa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b80aa5c6-62b4-4252-903b-8899305d457e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complement of df_coinciding: base rule differs from the training loss's scoring rule.\n",
    "df_diff = full_ood_rocauc.loc[diff_elements, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbd109aa-ad4e-42c3-ad0d-c2be7ffc4137",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "020d4a39-f4eb-491c-b804-2a43581a71b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2f7c11fe-d152-4568-b342-b3ad8516ff23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_riskwise(dataframe, uq_measures=None):\n",
    "    \"\"\"Average ROC-AUC scores per (dataset, uncertainty measure).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataframe : pd.DataFrame\n",
    "        Must contain ``Dataset``, ``UQMetric`` and ``RocAucScores_array`` columns.\n",
    "        The measure of a row is its ``UQMetric`` with the first space-separated\n",
    "        token removed (the same rule used to build ``unique_uq_measures``).\n",
    "    uq_measures : iterable of str, optional\n",
    "        Measures to aggregate; defaults to the notebook-level ``unique_uq_measures``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{dataset: {measure: mean}}`` where the mean is taken over every element\n",
    "        of the stacked ``RocAucScores_array`` entries of the matching rows.\n",
    "        Energy measures with no matching rows are omitted.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a non-energy measure has no rows for some dataset.\n",
    "    \"\"\"\n",
    "    if uq_measures is None:\n",
    "        uq_measures = unique_uq_measures\n",
    "\n",
    "    # Match measures exactly on the stripped suffix: ``str.endswith`` would also\n",
    "    # pool any longer metric name that happens to share the same ending.\n",
    "    row_measure = dataframe.UQMetric.str.split(' ', n=1).str[1].fillna('')\n",
    "\n",
    "    measure_dict = {}\n",
    "    # Datasets come from the frame being summarised, not from a global table.\n",
    "    for dataset_ in dataframe.Dataset.unique():\n",
    "        measure_dict[dataset_] = {}\n",
    "        in_dataset = dataframe.Dataset == dataset_\n",
    "        for uq_measure in uq_measures:\n",
    "            vals = dataframe[in_dataset & (row_measure == uq_measure)]\n",
    "\n",
    "            if len(vals) == 0:\n",
    "                # Energy measures may be absent from a subset (e.g. df_diff); skip them.\n",
    "                if 'energy' in uq_measure:\n",
    "                    continue\n",
    "                raise ValueError(\n",
    "                    f\"No rows for measure {uq_measure!r} in dataset {dataset_!r}\"\n",
    "                )\n",
    "\n",
    "            measure_dict[dataset_][uq_measure] = np.vstack(vals.RocAucScores_array.values).mean()\n",
    "    return measure_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7c223c2f-ea3b-4883-9a69-515fee39cd0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cifar10_noisy_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk central</th>\n",
       "      <td>0.793217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk inner</th>\n",
       "      <td>0.796328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>0.765180</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central central</th>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central inner</th>\n",
       "      <td>0.666374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>0.720505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner central</th>\n",
       "      <td>0.655247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner inner</th>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner outer</th>\n",
       "      <td>0.714054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer central</th>\n",
       "      <td>0.710383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer inner</th>\n",
       "      <td>0.735840</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>0.727618</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central central</th>\n",
       "      <td>0.793217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central inner</th>\n",
       "      <td>0.791731</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central outer</th>\n",
       "      <td>0.785387</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner central</th>\n",
       "      <td>0.784687</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner inner</th>\n",
       "      <td>0.796328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner outer</th>\n",
       "      <td>0.788345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer central</th>\n",
       "      <td>0.784687</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer inner</th>\n",
       "      <td>0.796328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer outer</th>\n",
       "      <td>0.788345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy inner</th>\n",
       "      <td>0.748162</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy outer</th>\n",
       "      <td>0.746138</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            cifar10_noisy_label\n",
       "BayesRisk central                      0.793217\n",
       "BayesRisk inner                        0.796328\n",
       "BayesRisk outer                        0.765180\n",
       "ExcessRisk central central             0.500000\n",
       "ExcessRisk central inner               0.666374\n",
       "ExcessRisk central outer               0.720505\n",
       "ExcessRisk inner central               0.655247\n",
       "ExcessRisk inner inner                 0.500000\n",
       "ExcessRisk inner outer                 0.714054\n",
       "ExcessRisk outer central               0.710383\n",
       "ExcessRisk outer inner                 0.735840\n",
       "ExcessRisk outer outer                 0.727618\n",
       "TotalRisk central central              0.793217\n",
       "TotalRisk central inner                0.791731\n",
       "TotalRisk central outer                0.785387\n",
       "TotalRisk inner central                0.784687\n",
       "TotalRisk inner inner                  0.796328\n",
       "TotalRisk inner outer                  0.788345\n",
       "TotalRisk outer central                0.784687\n",
       "TotalRisk outer inner                  0.796328\n",
       "TotalRisk outer outer                  0.788345\n",
       "energy inner                           0.748162\n",
       "energy outer                           0.746138"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean score per uncertainty measure over all rows (one column per dataset).\n",
    "full_stratified = (\n",
    "    pd.DataFrame.from_dict(mean_riskwise(full_ood_rocauc.copy()))\n",
    "    .sort_index()\n",
    ")\n",
    "full_stratified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1db0dca9-44d2-4240-ae0f-1bc269e478ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cifar10_noisy_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk central</th>\n",
       "      <td>0.804374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk inner</th>\n",
       "      <td>0.783553</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>0.782573</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central central</th>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central inner</th>\n",
       "      <td>0.612817</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>0.699839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner central</th>\n",
       "      <td>0.592383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner inner</th>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner outer</th>\n",
       "      <td>0.676975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer central</th>\n",
       "      <td>0.668474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer inner</th>\n",
       "      <td>0.696427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>0.683930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central central</th>\n",
       "      <td>0.804374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central inner</th>\n",
       "      <td>0.802061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central outer</th>\n",
       "      <td>0.793767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner central</th>\n",
       "      <td>0.762649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner inner</th>\n",
       "      <td>0.783553</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner outer</th>\n",
       "      <td>0.757266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer central</th>\n",
       "      <td>0.762649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer inner</th>\n",
       "      <td>0.783553</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer outer</th>\n",
       "      <td>0.757266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy inner</th>\n",
       "      <td>0.748162</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy outer</th>\n",
       "      <td>0.746138</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            cifar10_noisy_label\n",
       "BayesRisk central                      0.804374\n",
       "BayesRisk inner                        0.783553\n",
       "BayesRisk outer                        0.782573\n",
       "ExcessRisk central central             0.500000\n",
       "ExcessRisk central inner               0.612817\n",
       "ExcessRisk central outer               0.699839\n",
       "ExcessRisk inner central               0.592383\n",
       "ExcessRisk inner inner                 0.500000\n",
       "ExcessRisk inner outer                 0.676975\n",
       "ExcessRisk outer central               0.668474\n",
       "ExcessRisk outer inner                 0.696427\n",
       "ExcessRisk outer outer                 0.683930\n",
       "TotalRisk central central              0.804374\n",
       "TotalRisk central inner                0.802061\n",
       "TotalRisk central outer                0.793767\n",
       "TotalRisk inner central                0.762649\n",
       "TotalRisk inner inner                  0.783553\n",
       "TotalRisk inner outer                  0.757266\n",
       "TotalRisk outer central                0.762649\n",
       "TotalRisk outer inner                  0.783553\n",
       "TotalRisk outer outer                  0.757266\n",
       "energy inner                           0.748162\n",
       "energy outer                           0.746138"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same summary restricted to rows whose base rule matches the training loss.\n",
    "coinciding_stratified = (\n",
    "    pd.DataFrame.from_dict(mean_riskwise(df_coinciding.copy()))\n",
    "    .sort_index()\n",
    ")\n",
    "coinciding_stratified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9510d6d2-8287-45b1-9b4b-a742accee832",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cifar10_noisy_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk central</th>\n",
       "      <td>0.787639</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk inner</th>\n",
       "      <td>0.800586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>0.759383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central central</th>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central inner</th>\n",
       "      <td>0.719932</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>0.730838</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner central</th>\n",
       "      <td>0.718112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner inner</th>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner outer</th>\n",
       "      <td>0.726413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer central</th>\n",
       "      <td>0.731337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer inner</th>\n",
       "      <td>0.748978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>0.742180</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central central</th>\n",
       "      <td>0.787639</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central inner</th>\n",
       "      <td>0.786566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central outer</th>\n",
       "      <td>0.781197</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner central</th>\n",
       "      <td>0.795707</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner inner</th>\n",
       "      <td>0.800586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner outer</th>\n",
       "      <td>0.798704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer central</th>\n",
       "      <td>0.795707</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer inner</th>\n",
       "      <td>0.800586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer outer</th>\n",
       "      <td>0.798704</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            cifar10_noisy_label\n",
       "BayesRisk central                      0.787639\n",
       "BayesRisk inner                        0.800586\n",
       "BayesRisk outer                        0.759383\n",
       "ExcessRisk central central             0.500000\n",
       "ExcessRisk central inner               0.719932\n",
       "ExcessRisk central outer               0.730838\n",
       "ExcessRisk inner central               0.718112\n",
       "ExcessRisk inner inner                 0.500000\n",
       "ExcessRisk inner outer                 0.726413\n",
       "ExcessRisk outer central               0.731337\n",
       "ExcessRisk outer inner                 0.748978\n",
       "ExcessRisk outer outer                 0.742180\n",
       "TotalRisk central central              0.787639\n",
       "TotalRisk central inner                0.786566\n",
       "TotalRisk central outer                0.781197\n",
       "TotalRisk inner central                0.795707\n",
       "TotalRisk inner inner                  0.800586\n",
       "TotalRisk inner outer                  0.798704\n",
       "TotalRisk outer central                0.795707\n",
       "TotalRisk outer inner                  0.800586\n",
       "TotalRisk outer outer                  0.798704"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same summary restricted to rows whose base rule differs from the training loss.\n",
    "different_stratified = (\n",
    "    pd.DataFrame.from_dict(mean_riskwise(df_diff.copy()))\n",
    "    .sort_index()\n",
    ")\n",
    "different_stratified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1e78a8ef-73f2-45da-af1d-7c977a9c1f7b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cifar10_noisy_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk central</th>\n",
       "      <td>1.406561</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk inner</th>\n",
       "      <td>-1.604175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>2.272977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central central</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central inner</th>\n",
       "      <td>-8.037112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>-2.868182</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner central</th>\n",
       "      <td>-9.594064</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner inner</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner outer</th>\n",
       "      <td>-5.192693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer central</th>\n",
       "      <td>-5.899430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer inner</th>\n",
       "      <td>-5.356210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>-6.004166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central central</th>\n",
       "      <td>1.406561</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central inner</th>\n",
       "      <td>1.304781</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central outer</th>\n",
       "      <td>1.067008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner central</th>\n",
       "      <td>-2.808557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner inner</th>\n",
       "      <td>-1.604175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner outer</th>\n",
       "      <td>-3.942312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer central</th>\n",
       "      <td>-2.808561</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer inner</th>\n",
       "      <td>-1.604178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer outer</th>\n",
       "      <td>-3.942316</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy inner</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy outer</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            cifar10_noisy_label\n",
       "BayesRisk central                      1.406561\n",
       "BayesRisk inner                       -1.604175\n",
       "BayesRisk outer                        2.272977\n",
       "ExcessRisk central central             0.000000\n",
       "ExcessRisk central inner              -8.037112\n",
       "ExcessRisk central outer              -2.868182\n",
       "ExcessRisk inner central              -9.594064\n",
       "ExcessRisk inner inner                 0.000000\n",
       "ExcessRisk inner outer                -5.192693\n",
       "ExcessRisk outer central              -5.899430\n",
       "ExcessRisk outer inner                -5.356210\n",
       "ExcessRisk outer outer                -6.004166\n",
       "TotalRisk central central              1.406561\n",
       "TotalRisk central inner                1.304781\n",
       "TotalRisk central outer                1.067008\n",
       "TotalRisk inner central               -2.808557\n",
       "TotalRisk inner inner                 -1.604175\n",
       "TotalRisk inner outer                 -3.942312\n",
       "TotalRisk outer central               -2.808561\n",
       "TotalRisk outer inner                 -1.604178\n",
       "TotalRisk outer outer                 -3.942316\n",
       "energy inner                           0.000000\n",
       "energy outer                           0.000000"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(coinciding_stratified - full_stratified).mul(100).div(full_stratified)  # % change vs. all rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "85ccd51c-3f65-40bf-9b08-5db2c6c2895c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cifar10_noisy_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk central</th>\n",
       "      <td>-0.703280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk inner</th>\n",
       "      <td>0.534725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>-0.757659</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central central</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central inner</th>\n",
       "      <td>8.037112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>1.434091</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner central</th>\n",
       "      <td>9.594064</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner inner</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner outer</th>\n",
       "      <td>1.730898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer central</th>\n",
       "      <td>2.949715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer inner</th>\n",
       "      <td>1.785403</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>2.001389</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central central</th>\n",
       "      <td>-0.703280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central inner</th>\n",
       "      <td>-0.652390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central outer</th>\n",
       "      <td>-0.533504</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner central</th>\n",
       "      <td>1.404278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner inner</th>\n",
       "      <td>0.534725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner outer</th>\n",
       "      <td>1.314104</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer central</th>\n",
       "      <td>1.404280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer inner</th>\n",
       "      <td>0.534726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer outer</th>\n",
       "      <td>1.314105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy inner</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy outer</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            cifar10_noisy_label\n",
       "BayesRisk central                     -0.703280\n",
       "BayesRisk inner                        0.534725\n",
       "BayesRisk outer                       -0.757659\n",
       "ExcessRisk central central             0.000000\n",
       "ExcessRisk central inner               8.037112\n",
       "ExcessRisk central outer               1.434091\n",
       "ExcessRisk inner central               9.594064\n",
       "ExcessRisk inner inner                 0.000000\n",
       "ExcessRisk inner outer                 1.730898\n",
       "ExcessRisk outer central               2.949715\n",
       "ExcessRisk outer inner                 1.785403\n",
       "ExcessRisk outer outer                 2.001389\n",
       "TotalRisk central central             -0.703280\n",
       "TotalRisk central inner               -0.652390\n",
       "TotalRisk central outer               -0.533504\n",
       "TotalRisk inner central                1.404278\n",
       "TotalRisk inner inner                  0.534725\n",
       "TotalRisk inner outer                  1.314104\n",
       "TotalRisk outer central                1.404280\n",
       "TotalRisk outer inner                  0.534726\n",
       "TotalRisk outer outer                  1.314105\n",
       "energy inner                                NaN\n",
       "energy outer                                NaN"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "100 * (different_stratified - full_stratified) / full_stratified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07b8b0c1-d7c1-4174-bb19-d11575545dd7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b89e2eed-6e2f-46cf-a7ce-1b9db18111c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cifar10_noisy_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk central</th>\n",
       "      <td>2.124785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk inner</th>\n",
       "      <td>-2.127523</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>3.053773</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central central</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central inner</th>\n",
       "      <td>-14.878428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>-4.241447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner central</th>\n",
       "      <td>-17.508364</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner inner</th>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk inner outer</th>\n",
       "      <td>-6.805790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer central</th>\n",
       "      <td>-8.595599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer inner</th>\n",
       "      <td>-7.016344</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>-7.848476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central central</th>\n",
       "      <td>2.124785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central inner</th>\n",
       "      <td>1.970023</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk central outer</th>\n",
       "      <td>1.609096</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner central</th>\n",
       "      <td>-4.154494</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner inner</th>\n",
       "      <td>-2.127523</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk inner outer</th>\n",
       "      <td>-5.188237</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer central</th>\n",
       "      <td>-4.154500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer inner</th>\n",
       "      <td>-2.127528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TotalRisk outer outer</th>\n",
       "      <td>-5.188242</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy inner</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>energy outer</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            cifar10_noisy_label\n",
       "BayesRisk central                      2.124785\n",
       "BayesRisk inner                       -2.127523\n",
       "BayesRisk outer                        3.053773\n",
       "ExcessRisk central central             0.000000\n",
       "ExcessRisk central inner             -14.878428\n",
       "ExcessRisk central outer              -4.241447\n",
       "ExcessRisk inner central             -17.508364\n",
       "ExcessRisk inner inner                 0.000000\n",
       "ExcessRisk inner outer                -6.805790\n",
       "ExcessRisk outer central              -8.595599\n",
       "ExcessRisk outer inner                -7.016344\n",
       "ExcessRisk outer outer                -7.848476\n",
       "TotalRisk central central              2.124785\n",
       "TotalRisk central inner                1.970023\n",
       "TotalRisk central outer                1.609096\n",
       "TotalRisk inner central               -4.154494\n",
       "TotalRisk inner inner                 -2.127523\n",
       "TotalRisk inner outer                 -5.188237\n",
       "TotalRisk outer central               -4.154500\n",
       "TotalRisk outer inner                 -2.127528\n",
       "TotalRisk outer outer                 -5.188242\n",
       "energy inner                                NaN\n",
       "energy outer                                NaN"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "relative_tab =  100 * (coinciding_stratified - different_stratified) / different_stratified\n",
    "\n",
    "relative_tab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "64e92b14-1148-4fe5-8ad4-19a8027b8d17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# relative_tab = relative_tab.drop(columns=[DatasetName.CIFAR10_NOISY_LABEL.value, DatasetName.CIFAR10.value])\n",
    "# relative_tab = relative_tab.drop(columns=[DatasetName.TINY_IMAGENET.value])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "07181f03-f776-4963-b89c-126d7c62824d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{ll}\n",
      "\\toprule\n",
      " & cifar10_noisy_label \\\\\n",
      "\\midrule\n",
      "BayesRisk central & 2.125 \\\\\n",
      "BayesRisk inner & -2.128 \\\\\n",
      "BayesRisk outer & 3.054 \\\\\n",
      "ExcessRisk central central & 0.0 \\\\\n",
      "ExcessRisk central inner & -14.878 \\\\\n",
      "ExcessRisk central outer & -4.241 \\\\\n",
      "ExcessRisk inner central & -17.508 \\\\\n",
      "ExcessRisk inner inner & 0.0 \\\\\n",
      "ExcessRisk inner outer & -6.806 \\\\\n",
      "ExcessRisk outer central & -8.596 \\\\\n",
      "ExcessRisk outer inner & -7.016 \\\\\n",
      "ExcessRisk outer outer & -7.848 \\\\\n",
      "TotalRisk central central & 2.125 \\\\\n",
      "TotalRisk central inner & 1.97 \\\\\n",
      "TotalRisk central outer & 1.609 \\\\\n",
      "TotalRisk inner central & -4.154 \\\\\n",
      "TotalRisk inner inner & -2.128 \\\\\n",
      "TotalRisk inner outer & -5.188 \\\\\n",
      "TotalRisk outer central & -4.155 \\\\\n",
      "TotalRisk outer inner & -2.128 \\\\\n",
      "TotalRisk outer outer & -5.188 \\\\\n",
      "energy inner & nan \\\\\n",
      "energy outer & nan \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Merge the mean and std tables based on their index to create a combined LaTeX table.\n",
    "merged_table = relative_tab.round(3).copy()\n",
    "\n",
    "# For each numeric column, combine the mean and std in the format: mean \\pm std\n",
    "for col in relative_tab.columns:\n",
    "    merged_table[col] = relative_tab.round(3)[col].astype(str) #+ \" $\\\\pm$ \" + full_std_tab[col].round(3).astype(str)\n",
    "\n",
    "# Create LaTeX format\n",
    "latex_table = merged_table.to_latex(index=True, escape=False, float_format=\"%.2f\")\n",
    "\n",
    "# Output the resulting LaTeX table for the user\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af76f24-2889-4a22-a0e8-e9f17922916a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "eee144e0-b66f-4371-af28-d3b49e63c2e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{ll}\n",
      "\\toprule\n",
      " & cifar10_noisy_label \\\\\n",
      "\\midrule\n",
      "\\(\\Rtildebayes^{(3)}\\) & 2.125 \\\\\n",
      "\\(\\Rtildebayes^{(2)}\\) & -2.128 \\\\\n",
      "\\(\\Rtildebayes^{(1)}\\) & 3.054 \\\\\n",
      "\\(\\Rtildeexc^{(3, 3)}\\) & 0.0 \\\\\n",
      "\\(\\Rtildeexc^{(3, 2)}\\) & -14.878 \\\\\n",
      "\\(\\Rtildeexc^{(3, 1)}\\) & -4.241 \\\\\n",
      "\\(\\Rtildeexc^{(2, 3)}\\) & -17.508 \\\\\n",
      "\\(\\Rtildeexc^{(2, 2)}\\) & 0.0 \\\\\n",
      "\\(\\Rtildeexc^{(2, 1)}\\) & -6.806 \\\\\n",
      "\\(\\Rtildeexc^{(1, 3)}\\) & -8.596 \\\\\n",
      "\\(\\Rtildeexc^{(1, 2)}\\) & -7.016 \\\\\n",
      "\\(\\Rtildeexc^{(1, 1)}\\) & -7.848 \\\\\n",
      "\\(\\Rtildetot^{(3, 3)}\\) & 2.125 \\\\\n",
      "\\(\\Rtildetot^{(3, 2)}\\) & 1.97 \\\\\n",
      "\\(\\Rtildetot^{(3, 1)}\\) & 1.609 \\\\\n",
      "\\(\\Rtildetot^{(2, 3)}\\) & -4.154 \\\\\n",
      "\\(\\Rtildetot^{(2, 2)}\\) & -2.128 \\\\\n",
      "\\(\\Rtildetot^{(2, 1)}\\) & -5.188 \\\\\n",
      "\\(\\Rtildetot^{(1, 3)}\\) & -4.155 \\\\\n",
      "\\(\\Rtildetot^{(1, 2)}\\) & -2.128 \\\\\n",
      "\\(\\Rtildetot^{(1, 1)}\\) & -5.188 \\\\\n",
      "\\( E(x;\\E_{\\theta}f_{\\theta}) \\) & nan \\\\\n",
      "\\( \\E_{\\theta} E(x;f_{\\theta}) \\) & nan \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "replace_dictionary = {\n",
    "    \"BayesRisk central\": r\"\\(\\Rtildebayes^{(3)}\\)\",\n",
    "    \"BayesRisk inner\": r\"\\(\\Rtildebayes^{(2)}\\)\",\n",
    "    \"BayesRisk outer\": r\"\\(\\Rtildebayes^{(1)}\\)\",\n",
    "    \"ExcessRisk central central\": r\"\\(\\Rtildeexc^{(3, 3)}\\)\",\n",
    "    \"ExcessRisk central inner\": r\"\\(\\Rtildeexc^{(3, 2)}\\)\",\n",
    "    \"ExcessRisk central outer\": r\"\\(\\Rtildeexc^{(3, 1)}\\)\",\n",
    "    \"ExcessRisk inner central\": r\"\\(\\Rtildeexc^{(2, 3)}\\)\",\n",
    "    \"ExcessRisk inner inner\": r\"\\(\\Rtildeexc^{(2, 2)}\\)\",\n",
    "    \"ExcessRisk inner outer\": r\"\\(\\Rtildeexc^{(2, 1)}\\)\",\n",
    "    \"ExcessRisk outer central\": r\"\\(\\Rtildeexc^{(1, 3)}\\)\",\n",
    "    \"ExcessRisk outer inner\": r\"\\(\\Rtildeexc^{(1, 2)}\\)\",\n",
    "    \"ExcessRisk outer outer\": r\"\\(\\Rtildeexc^{(1, 1)}\\)\",\n",
    "    \"TotalRisk central central\": r\"\\(\\Rtildetot^{(3, 3)}\\)\",\n",
    "    \"TotalRisk central inner\": r\"\\(\\Rtildetot^{(3, 2)}\\)\",\n",
    "    \"TotalRisk central outer\": r\"\\(\\Rtildetot^{(3, 1)}\\)\",\n",
    "    \"TotalRisk inner central\": r\"\\(\\Rtildetot^{(2, 3)}\\)\",\n",
    "    \"TotalRisk inner inner\": r\"\\(\\Rtildetot^{(2, 2)}\\)\",\n",
    "    \"TotalRisk inner outer\": r\"\\(\\Rtildetot^{(2, 1)}\\)\",\n",
    "    \"TotalRisk outer central\": r\"\\(\\Rtildetot^{(1, 3)}\\)\",\n",
    "    \"TotalRisk outer inner\": r\"\\(\\Rtildetot^{(1, 2)}\\)\",\n",
    "    \"TotalRisk outer outer\": r\"\\(\\Rtildetot^{(1, 1)}\\)\",\n",
    "    \"energy inner\": r\"\\( E(x;\\E_{\\theta}f_{\\theta}) \\)\",\n",
    "    \"energy outer\": r\"\\( \\E_{\\theta} E(x;f_{\\theta}) \\)\",\n",
    "    \n",
    "} \n",
    "for key in replace_dictionary.keys():\n",
    "    latex_table = latex_table.replace(key, replace_dictionary[key])\n",
    "\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03953d7-5411-409d-a8f3-eaca12d5ec04",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a05c51c-a72f-4894-be51-67f37dea4ea7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd19c528-a7b1-4146-a7c8-f081d1c873f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "436e099f-ccac-4e27-98ad-7bca4a5ce36e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9935c833-98f3-4818-976d-dc7e62fc69f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>tiny_imagenet</th>\n",
       "      <th>imagenet_a</th>\n",
       "      <th>imagenet_r</th>\n",
       "      <th>imagenet_o</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BayesRisk outer</th>\n",
       "      <td>50.0</td>\n",
       "      <td>83.22</td>\n",
       "      <td>82.23</td>\n",
       "      <td>72.21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk outer outer</th>\n",
       "      <td>50.0</td>\n",
       "      <td>77.11</td>\n",
       "      <td>76.52</td>\n",
       "      <td>74.46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ExcessRisk central outer</th>\n",
       "      <td>50.0</td>\n",
       "      <td>77.56</td>\n",
       "      <td>77.02</td>\n",
       "      <td>74.46</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          tiny_imagenet  imagenet_a  imagenet_r  imagenet_o\n",
       "BayesRisk outer                    50.0       83.22       82.23       72.21\n",
       "ExcessRisk outer outer             50.0       77.11       76.52       74.46\n",
       "ExcessRisk central outer           50.0       77.56       77.02       74.46"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cross Entropy\n",
    "(100 * coinciding_stratified.loc[[\"BayesRisk outer\", \"ExcessRisk outer outer\", \"ExcessRisk central outer\"]]).round(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41ba4c8-2376-43b9-9be3-6c58b41d13ed",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f66c45-1d0a-42bc-ae73-d5ff2dff9512",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad742004-8b43-4c3f-b74d-78258750cbd1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
