{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbe7db16-970e-4cdb-b721-08b241a2c60d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nkotelevskii/github/uncertainty_from_proper_scoring_rules/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "stty: 'standard input': Inappropriate ioctl for device\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import re\n",
    "import os\n",
    "\n",
    "sys.path.insert(0, \"src/\")\n",
    "sys.path.insert(1, \"external_repos/pytorch_cifar100/\")\n",
    "sys.path.insert(1, \"external_repos/pytorch_cifar10/\")\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm.auto import tqdm\n",
    "from src.data_utils import load_model_checkpoint, load_dict, make_load_path\n",
    "from src.postprocessing_utils import (\n",
    "    get_metrics_results,\n",
    "    uq_funcs_with_names,\n",
    "    get_uncertainty_scores,\n",
    "    get_predicted_labels,\n",
    "    make_aggregation,\n",
    "    get_raw_scores_dataframe,\n",
    "    ravel_df,\n",
    "    create_gt_embeddings,\n",
    "    get_sampled_combinations_uncertainty_scores,\n",
    ")\n",
    "from vectorizer_uncertainty_scores import posterior_predictive\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from itertools import combinations\n",
    "from IPython.display import display\n",
    "\n",
    "pd.set_option(\"display.max_rows\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "585709be-b266-460c-88e2-8615f382c9c8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77063d35-835e-4fee-962f-361cb9acbf28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-set names iterated by the main loop further down in this notebook.\n",
    "training_dataset_names = [\n",
    "    \"cifar10\",\n",
    "    \"cifar100\",\n",
    "    \"noisy_cifar100\",\n",
    "    \"missed_class_cifar10\",\n",
    "    \"noisy_cifar10\",\n",
    "]\n",
    "# NOTE(review): not referenced in the cells visible here; presumably consumed by\n",
    "# other cells or kept from a sibling notebook - confirm before removing.\n",
    "temperature = 1.0\n",
    "# Model ids 0..19; len(model_ids) is used below to slice the stacked targets.\n",
    "model_ids = np.arange(20)\n",
    "list_extraction_datasets = [\n",
    "    \"cifar10\",\n",
    "    \"cifar100\",\n",
    "    \"svhn\",\n",
    "    \"blurred_cifar100\",\n",
    "    \"blurred_cifar10\",\n",
    "]\n",
    "# Shallow copy, so the two lists can be edited independently.\n",
    "list_ood_datasets = [el for el in list_extraction_datasets]\n",
    "loss_function_names = [\"brier_score\", \"cross_entropy\", \"spherical_score\"]\n",
    "use_different_approximations = False\n",
    "gt_prob_approx = \"same\"\n",
    "\n",
    "# Accumulators; only full_mis_rocauc_dataframe is populated in the visible cells.\n",
    "full_dataframe = None\n",
    "full_ood_rocauc_dataframe = None\n",
    "full_mis_rocauc_dataframe = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe6988f6-7e44-4e6b-b835-0a693b0b9745",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a833f5-0a39-4c36-a884-716b11c7305d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cbfeb3f-a449-40bd-a43e-03abaf835ce8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "236defb0-9508-43c1-ab89-bb58039de8f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Five disjoint groups of four model ids (covering ids 0..19). Each tuple is turned\n",
    "# into a list and used by get_predicted_labels to index the stored embeddings,\n",
    "# giving one ensemble per tuple.\n",
    "ENSEMBLE_COMBINATIONS = [\n",
    "    (0, 1, 2, 3),\n",
    "    (4, 5, 6, 7),\n",
    "    (8, 9, 10, 11),\n",
    "    (12, 13, 14, 15),\n",
    "    (16, 17, 18, 19),\n",
    "]\n",
    "\n",
    "\n",
    "def get_predicted_labels(\n",
    "    embeddings_per_dataset: dict,\n",
    "    training_dataset_name: str,\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute the predicted class labels of every ensemble in ENSEMBLE_COMBINATIONS.\n",
    "\n",
    "    For each loss key of `embeddings_per_dataset`, the entry stored under\n",
    "    `training_dataset_name` is indexed with the model ids of one combination and\n",
    "    passed through `posterior_predictive`; the argmax over the last axis of the\n",
    "    first element of its result is taken as the prediction.\n",
    "\n",
    "    Returns a dict mapping each loss key to a list holding one label array per\n",
    "    combination, in ENSEMBLE_COMBINATIONS order.\n",
    "\n",
    "    NOTE(review): this definition shadows the `get_predicted_labels` imported from\n",
    "    `src.postprocessing_utils` in the first cell.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        loss_name: [\n",
    "            np.argmax(\n",
    "                posterior_predictive(\n",
    "                    embeddings_per_dataset[loss_name][training_dataset_name][\n",
    "                        list(member_ids)\n",
    "                    ]\n",
    "                )[0],\n",
    "                axis=-1,\n",
    "            )\n",
    "            for member_ids in ENSEMBLE_COMBINATIONS\n",
    "        ]\n",
    "        for loss_name in embeddings_per_dataset\n",
    "    }\n",
    "\n",
    "\n",
    "def get_missclassification_dataframe(\n",
    "    ind_dataset: str,\n",
    "    uq_results: dict,\n",
    "    true_labels: np.ndarray,\n",
    "    pred_labels: dict,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Build a table of misclassification-detection ROC AUC scores.\n",
    "\n",
    "    For every uncertainty measure named in `uq_funcs_with_names` and every loss key\n",
    "    of `uq_results[measure]`, each score array stored under `ind_dataset` is\n",
    "    evaluated against the binary target 'prediction differs from the true label'.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ind_dataset : key selecting the in-distribution dataset inside `uq_results`.\n",
    "    uq_results : nested mapping indexed as uq_results[measure][loss][dataset][i].\n",
    "    true_labels : ground-truth labels, compared element-wise with the predictions.\n",
    "    pred_labels : mapping loss -> sequence of predicted-label arrays, indexed as\n",
    "        pred_labels[loss][i] (the structure returned by `get_predicted_labels`).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    DataFrame with one row per (measure, loss) and the columns UQMetric,\n",
    "    LossFunction and RocAucScores_array (the list of per-array ROC AUC values).\n",
    "\n",
    "    NOTE(review): roc_auc_score is undefined when y_true holds a single class\n",
    "    (no, or only, misclassifications); depending on the scikit-learn version it\n",
    "    raises or warns in that case - confirm this cannot happen for the data used.\n",
    "    \"\"\"\n",
    "    roc_auc_dict = {}\n",
    "\n",
    "    for uq_name, _ in uq_funcs_with_names:\n",
    "        scores_by_loss = {}\n",
    "        for loss_ in uq_results[uq_name].keys():\n",
    "            ensemble_scores = uq_results[uq_name][loss_][ind_dataset]\n",
    "            scores_by_loss[loss_] = [\n",
    "                roc_auc_score(\n",
    "                    # 1 marks a misclassified sample, 0 a correctly classified one.\n",
    "                    y_true=(true_labels != pred_labels[loss_][it_]).astype(np.int32),\n",
    "                    y_score=ensemble_scores[it_],\n",
    "                )\n",
    "                for it_ in range(len(ensemble_scores))\n",
    "            ]\n",
    "        roc_auc_dict[uq_name] = scores_by_loss\n",
    "\n",
    "    # Flatten the nested dict into (measure, loss, scores) rows.\n",
    "    data_list_misclassification = [\n",
    "        (metric_name, loss_function_name, values)\n",
    "        for metric_name, loss_function in roc_auc_dict.items()\n",
    "        for loss_function_name, values in loss_function.items()\n",
    "    ]\n",
    "\n",
    "    df_misclassification = pd.DataFrame(\n",
    "        data_list_misclassification,\n",
    "        columns=[\n",
    "            \"UQMetric\",\n",
    "            \"LossFunction\",\n",
    "            \"RocAucScores_array\",\n",
    "        ],\n",
    "    )\n",
    "    return df_misclassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5012c3cb-2c2e-48f8-b5c5-40ce487b296b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "808770c1-4179-4600-8795-4ba34d7dedb5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "11a84049-9ae9-4a0b-8f72-ed6437be1658",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_dataframe = None\n",
    "# Collect the per-(dataset, architecture) frames locally and concatenate once at\n",
    "# the end, so re-running this cell rebuilds the table instead of appending\n",
    "# duplicate rows to an already filled `full_mis_rocauc_dataframe`.\n",
    "misclassification_frames = []\n",
    "\n",
    "for training_dataset_name in training_dataset_names:\n",
    "    if training_dataset_name in [\n",
    "        \"missed_class_cifar10\",\n",
    "        \"noisy_cifar10\",\n",
    "        \"noisy_cifar100\",\n",
    "    ]:\n",
    "        # Modified training sets are processed for resnet18 only and are looked up\n",
    "        # under their last '_'-separated token (e.g. 'noisy_cifar10' -> 'cifar10').\n",
    "        architectures = [\"resnet18\"]\n",
    "        training_dataset_name_aux = training_dataset_name.split(\"_\")[-1]\n",
    "    else:\n",
    "        architectures = [\"resnet18\", \"vgg\"]\n",
    "        training_dataset_name_aux = training_dataset_name\n",
    "    for architecture in architectures:\n",
    "        # Only the leading components of the checkpoint path are kept (the last\n",
    "        # three are dropped), so the loss and model-id arguments are placeholders.\n",
    "        folder_path = make_load_path(\n",
    "            architecture=architecture,\n",
    "            dataset_name=training_dataset_name,\n",
    "            loss_function_name=\"NaN\",\n",
    "            model_id=\"NaN\",\n",
    "        )\n",
    "        extracted_embeddings_file_path = os.path.join(\n",
    "            *folder_path.split(\"/\")[:-3],\n",
    "            \"extracted_information_for_notebook_combinations.pkl\",\n",
    "        )\n",
    "\n",
    "        res_dict = load_dict(extracted_embeddings_file_path)\n",
    "        uq_results = res_dict[\"uq_results\"]\n",
    "        embeddings_per_dataset = res_dict[\"embeddings_per_dataset\"]\n",
    "        targets_per_dataset = res_dict[\"targets_per_dataset\"]\n",
    "\n",
    "        # Keep the first 1/len(model_ids) of the stored targets.\n",
    "        # NOTE(review): presumably the targets are repeated once per model - confirm.\n",
    "        dataset_targets = targets_per_dataset[training_dataset_name_aux]\n",
    "        max_ind = dataset_targets.shape[0] // len(model_ids)\n",
    "        true_labels = dataset_targets[:max_ind]\n",
    "\n",
    "        pred_labels = get_predicted_labels(\n",
    "            embeddings_per_dataset=embeddings_per_dataset,\n",
    "            training_dataset_name=training_dataset_name_aux,\n",
    "        )\n",
    "\n",
    "        df_misclassification = get_missclassification_dataframe(\n",
    "            ind_dataset=training_dataset_name_aux,\n",
    "            uq_results=uq_results,\n",
    "            true_labels=true_labels,\n",
    "            pred_labels=pred_labels,\n",
    "        )\n",
    "        df_misclassification[\"architecture\"] = architecture\n",
    "        df_misclassification[\"training_dataset\"] = training_dataset_name\n",
    "        misclassification_frames.append(df_misclassification)\n",
    "\n",
    "# Row labels are kept as produced by each frame (no ignore_index), matching the\n",
    "# index that is later written to the CSV.\n",
    "full_mis_rocauc_dataframe = (\n",
    "    pd.concat(misclassification_frames) if misclassification_frames else None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6115b259-2b16-4bb4-a7b9-a02f06b24227",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9ce1ea99-fb18-4e90-a376-711bedab2ed1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1575, 5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_mis_rocauc_dataframe.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb70f28b-29b5-4656-836c-49cbb996edab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise each row's list of per-ensemble ROC AUC values.\n",
    "# ndarray.std() defaults to ddof=0, i.e. the population standard deviation.\n",
    "full_mis_rocauc_dataframe[\"RocAucScoresMean\"] = [\n",
    "    np.array(scores).mean()\n",
    "    for scores in full_mis_rocauc_dataframe[\"RocAucScores_array\"]\n",
    "]\n",
    "full_mis_rocauc_dataframe[\"RocAucScoresStd\"] = [\n",
    "    np.array(scores).std()\n",
    "    for scores in full_mis_rocauc_dataframe[\"RocAucScores_array\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c65936c9-0ea3-4b35-96b4-ead34dc28569",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>RocAucScores_array</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "      <th>RocAucScoresMean</th>\n",
       "      <th>RocAucScoresStd</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>Bayes Maxprob Inner</td>\n",
       "      <td>brier_score</td>\n",
       "      <td>[0.8292068551460816, 0.8203732300207724, 0.740...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar10</td>\n",
       "      <td>0.797500</td>\n",
       "      <td>0.031445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>188</th>\n",
       "      <td>BiasBI Brier</td>\n",
       "      <td>spherical_score</td>\n",
       "      <td>[0.6697303310739664, 0.7665530918623402, 0.618...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar10</td>\n",
       "      <td>0.674174</td>\n",
       "      <td>0.062675</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Reverse Bregman Information Brier</td>\n",
       "      <td>brier_score</td>\n",
       "      <td>[0.8423573351457968, 0.8836318662519729, 0.943...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>missed_class_cifar10</td>\n",
       "      <td>0.910478</td>\n",
       "      <td>0.041025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>124</th>\n",
       "      <td>Bregman Information Logscore</td>\n",
       "      <td>cross_entropy</td>\n",
       "      <td>[0.7724285880541444, 0.7467192091586558, 0.754...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar100</td>\n",
       "      <td>0.759900</td>\n",
       "      <td>0.010288</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>111</th>\n",
       "      <td>Excess Maxprob Inner Inner</td>\n",
       "      <td>brier_score</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5, 0.5]</td>\n",
       "      <td>vgg</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>190</th>\n",
       "      <td>Bias Maxprob</td>\n",
       "      <td>cross_entropy</td>\n",
       "      <td>[0.1438103775226446, 0.1455196565955138, 0.142...</td>\n",
       "      <td>vgg</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>0.144981</td>\n",
       "      <td>0.002077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>Excess Logscore Inner Outer</td>\n",
       "      <td>brier_score</td>\n",
       "      <td>[0.7269731511109783, 0.7134475324575996, 0.706...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>noisy_cifar100</td>\n",
       "      <td>0.716347</td>\n",
       "      <td>0.007406</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>197</th>\n",
       "      <td>MVBI Maxprob</td>\n",
       "      <td>spherical_score</td>\n",
       "      <td>[0.8069556194667709, 0.8210779266466746, 0.814...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>0.814200</td>\n",
       "      <td>0.010249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>Bayes Neglog Inner</td>\n",
       "      <td>brier_score</td>\n",
       "      <td>[0.9271426693388815, 0.9275244647036068, 0.927...</td>\n",
       "      <td>vgg</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>0.928091</td>\n",
       "      <td>0.001107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>211</th>\n",
       "      <td>BiasBI Spherical</td>\n",
       "      <td>cross_entropy</td>\n",
       "      <td>[0.8578697759800266, 0.8641782603994588, 0.861...</td>\n",
       "      <td>vgg</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>0.861691</td>\n",
       "      <td>0.002612</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              UQMetric     LossFunction  \\\n",
       "42                 Bayes Maxprob Inner      brier_score   \n",
       "188                       BiasBI Brier  spherical_score   \n",
       "135  Reverse Bregman Information Brier      brier_score   \n",
       "124       Bregman Information Logscore    cross_entropy   \n",
       "111         Excess Maxprob Inner Inner      brier_score   \n",
       "190                       Bias Maxprob    cross_entropy   \n",
       "66         Excess Logscore Inner Outer      brier_score   \n",
       "197                       MVBI Maxprob  spherical_score   \n",
       "48                  Bayes Neglog Inner      brier_score   \n",
       "211                   BiasBI Spherical    cross_entropy   \n",
       "\n",
       "                                    RocAucScores_array architecture  \\\n",
       "42   [0.8292068551460816, 0.8203732300207724, 0.740...     resnet18   \n",
       "188  [0.6697303310739664, 0.7665530918623402, 0.618...     resnet18   \n",
       "135  [0.8423573351457968, 0.8836318662519729, 0.943...     resnet18   \n",
       "124  [0.7724285880541444, 0.7467192091586558, 0.754...     resnet18   \n",
       "111                          [0.5, 0.5, 0.5, 0.5, 0.5]          vgg   \n",
       "190  [0.1438103775226446, 0.1455196565955138, 0.142...          vgg   \n",
       "66   [0.7269731511109783, 0.7134475324575996, 0.706...     resnet18   \n",
       "197  [0.8069556194667709, 0.8210779266466746, 0.814...     resnet18   \n",
       "48   [0.9271426693388815, 0.9275244647036068, 0.927...          vgg   \n",
       "211  [0.8578697759800266, 0.8641782603994588, 0.861...          vgg   \n",
       "\n",
       "         training_dataset  RocAucScoresMean  RocAucScoresStd  \n",
       "42          noisy_cifar10          0.797500         0.031445  \n",
       "188         noisy_cifar10          0.674174         0.062675  \n",
       "135  missed_class_cifar10          0.910478         0.041025  \n",
       "124        noisy_cifar100          0.759900         0.010288  \n",
       "111               cifar10          0.500000         0.000000  \n",
       "190               cifar10          0.144981         0.002077  \n",
       "66         noisy_cifar100          0.716347         0.007406  \n",
       "197               cifar10          0.814200         0.010249  \n",
       "48                cifar10          0.928091         0.001107  \n",
       "211              cifar100          0.861691         0.002612  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_mis_rocauc_dataframe.sample(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7be1cb0-6131-4a80-bbe3-de0a84b1fb6d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "44990d36-5396-4c0c-afe8-98f0b982a83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps the raw loss-function identifiers to the short names that also appear in\n",
    "# the UQMetric strings (Logscore / Brier / Spherical).\n",
    "base_score_dict = {\n",
    "    \"cross_entropy\": \"Logscore\",\n",
    "    \"brier_score\": \"Brier\",\n",
    "    \"spherical_score\": \"Spherical\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "96e73977-79a2-4e6d-889e-f716a956f11e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base scoring rule named inside the UQMetric string.\n",
    "pattern_baserule = r\"(Logscore|Brier|Neglog|Maxprob|Spherical)\"\n",
    "# Alternatives are tried left to right at a given position, so the longer names\n",
    "# MVBI and BiasBI must stay ahead of their prefixes MV and Bias.\n",
    "pattern_risk = r\"(Total|Bayes|Excess|Reverse Bregman Information|Bregman Information|Expected Pairwise Bregman Information|MVBI|MV|BiasBI|Bias)\"\n",
    "\n",
    "# str.extract keeps the first match of the capture group (NaN where none matches).\n",
    "full_mis_rocauc_dataframe[\"base_rule\"] = full_mis_rocauc_dataframe[\n",
    "    \"UQMetric\"\n",
    "].str.extract(pattern_baserule)\n",
    "full_mis_rocauc_dataframe[\"RiskType\"] = full_mis_rocauc_dataframe[\n",
    "    \"UQMetric\"\n",
    "].str.extract(pattern_risk)\n",
    "# Rename the raw loss identifiers to the same vocabulary as base_rule.\n",
    "full_mis_rocauc_dataframe[\"LossFunction\"] = full_mis_rocauc_dataframe[\n",
    "    \"LossFunction\"\n",
    "].replace(base_score_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d03b97b4-8889-4966-8ddc-5f82233076e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_mis_rocauc_dataframe.to_csv(\"./tables/full_mis_rocauc_with_std.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a2b010d-1256-4e78-bed0-9c4ccb62db29",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e0a18ef-fe4f-4434-bb21-470252828bd6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "cfacea76-0f50-4beb-aac1-dffcd5f6648b",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_mis_rocauc = pd.read_csv(\"./tables/full_mis_rocauc_with_std.csv\", index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "14253d9f-54ef-4be9-991a-b8d2a886f4a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['cifar10', 'cifar100', 'noisy_cifar100', 'missed_class_cifar10',\n",
       "       'noisy_cifar10'], dtype=object)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_mis_rocauc.training_dataset.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce54a302-dc44-49e0-ae5c-7655933eda68",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e71883eb-1fdc-453c-8ebe-e4e464f938ba",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "7e86c719-93c2-4361-a895-8738a098613a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop every measure whose UQMetric name ends with 'Inner Inner'.\n",
    "# NOTE(review): the sampled 'Excess Maxprob Inner Inner' row shown earlier has a\n",
    "# constant ROC AUC of 0.5, which is presumably why these are excluded - confirm.\n",
    "full_mis_rocauc = full_mis_rocauc[~full_mis_rocauc.UQMetric.str.endswith(\"Inner Inner\")]\n",
    "# full_mis_rocauc = full_mis_rocauc[full_mis_rocauc.base_rule != 'Neglog']\n",
    "\n",
    "# full_mis_rocauc.loc[(full_mis_rocauc.RiskType == \"Bayes\") & full_mis_rocauc.UQMetric.str.endswith(\"Outer\"), \"RiskType\"] = 'Bayes Outer'\n",
    "# full_mis_rocauc.loc[(full_mis_rocauc.RiskType == \"Bayes\") & full_mis_rocauc.UQMetric.str.endswith(\"Inner\"), \"RiskType\"] = 'Bayes Inner'\n",
    "\n",
    "# full_mis_rocauc.loc[(full_mis_rocauc.RiskType == \"Total\") & full_mis_rocauc.UQMetric.str.endswith(\"Outer\"), \"RiskType\"] = 'Total Outer'\n",
    "# full_mis_rocauc.loc[(full_mis_rocauc.RiskType == \"Total\") & full_mis_rocauc.UQMetric.str.endswith(\"Inner\"), \"RiskType\"] = 'Total Inner'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9773c93b-6aa3-4766-a27e-6a88e123c7ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import re\n",
    "import numpy as np\n",
    "\n",
    "sys.path.insert(0, \"src/\")\n",
    "\n",
    "import pandas as pd\n",
    "from src.table_utils import (\n",
    "    extract_same_different_dataframes,\n",
    "    aggregate_over_measures,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "04e7804c-8600-4e5c-8b59-05666a9c5823",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_scores_into_dict_miss_with_std(\n",
    "    dataframes_list_,\n",
    "):\n",
    "    \"\"\"\n",
    "    Collect per-dataset, per-RiskType averages of the misclassification ROC AUC\n",
    "    summaries into two column-oriented dicts (one for means, one for stds).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataframes_list_ : non-empty sequence of DataFrames with the columns\n",
    "        training_dataset, RiskType, LossFunction, RocAucScoresMean and\n",
    "        RocAucScoresStd. The RiskType values of the FIRST frame define the\n",
    "        output columns; RiskTypes that only occur in later frames are ignored.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (scores_dict_, std_dict_) : two defaultdict(list) objects. Each holds one list\n",
    "        per RiskType plus the lists 'InD' (training dataset name) and 'ScoringRule'\n",
    "        (array of unique LossFunction values). One entry is appended to every list\n",
    "        per (frame, training dataset) pair that contains at least one known\n",
    "        RiskType, so all lists have equal length.\n",
    "\n",
    "    A RiskType that is absent from a given (frame, dataset) group is recorded as\n",
    "    NaN; previously it was skipped, which left the lists with unequal lengths and\n",
    "    made the later DataFrame construction fail.\n",
    "    \"\"\"\n",
    "    # dropna: a NaN RiskType would otherwise seed a column that is never filled.\n",
    "    risk_types = list(dataframes_list_[0].RiskType.dropna().unique())\n",
    "    scores_dict_ = defaultdict(list, {val: [] for val in risk_types})\n",
    "    std_dict_ = defaultdict(list, {val: [] for val in risk_types})\n",
    "    missing = float(\"nan\")\n",
    "\n",
    "    for dataframe_ in dataframes_list_:\n",
    "        for ind in dataframe_.training_dataset.unique():\n",
    "            df_aux_ = dataframe_[(dataframe_[\"training_dataset\"] == ind)]\n",
    "\n",
    "            # One groupby yields both per-RiskType averages.\n",
    "            per_risk_type = df_aux_.groupby(by=\"RiskType\")[\n",
    "                [\"RocAucScoresMean\", \"RocAucScoresStd\"]\n",
    "            ].mean()\n",
    "            mean_rocauc_dict = per_risk_type[\"RocAucScoresMean\"].to_dict()\n",
    "            std_rocauc_dict = per_risk_type[\"RocAucScoresStd\"].to_dict()\n",
    "\n",
    "            # Skip groups that contribute nothing to the tracked RiskTypes.\n",
    "            if not any(k in mean_rocauc_dict for k in risk_types):\n",
    "                continue\n",
    "\n",
    "            for k in risk_types:\n",
    "                scores_dict_[k].append(mean_rocauc_dict.get(k, missing))\n",
    "                std_dict_[k].append(std_rocauc_dict.get(k, missing))\n",
    "\n",
    "            scores_dict_[\"InD\"].append(ind)\n",
    "            scores_dict_[\"ScoringRule\"].append(df_aux_[\"LossFunction\"].unique())\n",
    "\n",
    "            std_dict_[\"InD\"].append(ind)\n",
    "            std_dict_[\"ScoringRule\"].append(df_aux_[\"LossFunction\"].unique())\n",
    "    return scores_dict_, std_dict_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "741530e7-d8a7-4893-a5a4-b82f3eb527d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped_df = extract_same_different_dataframes(\n",
    "    dataframe_=full_mis_rocauc,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "05678662-12ca-4997-86c8-b4461da54606",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def _summarise_miss_scores(dataframes_list_):\n",
    "    \"\"\"\n",
    "    Run the collect -> DataFrame -> aggregate-by-InD pipeline for one group of frames.\n",
    "\n",
    "    Returns a 6-tuple: (mean dict, std dict, mean DataFrame, std DataFrame,\n",
    "    mean DataFrame aggregated over 'InD', std DataFrame aggregated over 'InD').\n",
    "    \"\"\"\n",
    "    dict_mean, dict_std = collect_scores_into_dict_miss_with_std(\n",
    "        dataframes_list_=dataframes_list_,\n",
    "    )\n",
    "    df_mean = pd.DataFrame.from_dict(dict_mean)\n",
    "    df_std = pd.DataFrame.from_dict(dict_std)\n",
    "\n",
    "    agg_df_mean = aggregate_over_measures(\n",
    "        dataframe_=df_mean,\n",
    "        agg_func_=\"mean\",\n",
    "        by_=[\"InD\"],\n",
    "    )\n",
    "    agg_df_std = aggregate_over_measures(\n",
    "        dataframe_=df_std,\n",
    "        agg_func_=\"mean\",\n",
    "        by_=[\"InD\"],\n",
    "    )\n",
    "    return dict_mean, dict_std, df_mean, df_std, agg_df_mean, agg_df_std\n",
    "\n",
    "\n",
    "# Measures evaluated on models trained with the matching scoring rule.\n",
    "(\n",
    "    same_dict_mean,\n",
    "    same_dict_std,\n",
    "    same_df_mean,\n",
    "    same_df_std,\n",
    "    same_agg_df_mean,\n",
    "    same_agg_df_std,\n",
    ") = _summarise_miss_scores(\n",
    "    [\n",
    "        grouped_df.logscore_logscore,\n",
    "        grouped_df.brier_brier,\n",
    "        grouped_df.spherical_spherical,\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Measures evaluated on models trained with a different scoring rule.\n",
    "(\n",
    "    different_dict_mean,\n",
    "    different_dict_std,\n",
    "    different_df_mean,\n",
    "    different_df_std,\n",
    "    different_agg_df_mean,\n",
    "    different_agg_df_std,\n",
    ") = _summarise_miss_scores(\n",
    "    [\n",
    "        grouped_df.logscore_not_logscore,\n",
    "        grouped_df.brier_not_brier,\n",
    "        grouped_df.spherical_not_spherical,\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Everything together, without splitting by scoring rule.\n",
    "(\n",
    "    all_dict_mean,\n",
    "    all_dict_std,\n",
    "    all_df_mean,\n",
    "    all_df_std,\n",
    "    all_agg_df_mean,\n",
    "    all_agg_df_std,\n",
    ") = _summarise_miss_scores(\n",
    "    [\n",
    "        full_mis_rocauc,\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "dd55cd1c-746c-441f-9b53-60e13315b5bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Total</th>\n",
       "      <th>Bayes</th>\n",
       "      <th>Excess</th>\n",
       "      <th>Bregman Information</th>\n",
       "      <th>Reverse Bregman Information</th>\n",
       "      <th>Expected Pairwise Bregman Information</th>\n",
       "      <th>Bias</th>\n",
       "      <th>MV</th>\n",
       "      <th>MVBI</th>\n",
       "      <th>BiasBI</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InD</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>0.935947</td>\n",
       "      <td>0.934619</td>\n",
       "      <td>0.897459</td>\n",
       "      <td>0.897161</td>\n",
       "      <td>0.896719</td>\n",
       "      <td>0.898497</td>\n",
       "      <td>0.660546</td>\n",
       "      <td>0.839788</td>\n",
       "      <td>0.902130</td>\n",
       "      <td>0.761394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cifar100</th>\n",
       "      <td>0.841424</td>\n",
       "      <td>0.854797</td>\n",
       "      <td>0.729444</td>\n",
       "      <td>0.738183</td>\n",
       "      <td>0.720534</td>\n",
       "      <td>0.729614</td>\n",
       "      <td>0.510126</td>\n",
       "      <td>0.676550</td>\n",
       "      <td>0.736675</td>\n",
       "      <td>0.587066</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>missed_class_cifar10</th>\n",
       "      <td>0.911395</td>\n",
       "      <td>0.899189</td>\n",
       "      <td>0.880219</td>\n",
       "      <td>0.877316</td>\n",
       "      <td>0.881235</td>\n",
       "      <td>0.882104</td>\n",
       "      <td>0.656732</td>\n",
       "      <td>0.824091</td>\n",
       "      <td>0.884527</td>\n",
       "      <td>0.754405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>noisy_cifar10</th>\n",
       "      <td>0.777694</td>\n",
       "      <td>0.784407</td>\n",
       "      <td>0.688842</td>\n",
       "      <td>0.690688</td>\n",
       "      <td>0.685141</td>\n",
       "      <td>0.690698</td>\n",
       "      <td>0.511371</td>\n",
       "      <td>0.639119</td>\n",
       "      <td>0.696322</td>\n",
       "      <td>0.595381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>noisy_cifar100</th>\n",
       "      <td>0.802213</td>\n",
       "      <td>0.818593</td>\n",
       "      <td>0.638102</td>\n",
       "      <td>0.643024</td>\n",
       "      <td>0.633539</td>\n",
       "      <td>0.637743</td>\n",
       "      <td>0.460808</td>\n",
       "      <td>0.595676</td>\n",
       "      <td>0.649527</td>\n",
       "      <td>0.513263</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                         Total     Bayes    Excess Bregman Information  \\\n",
       "                          mean      mean      mean                mean   \n",
       "InD                                                                      \n",
       "cifar10               0.935947  0.934619  0.897459            0.897161   \n",
       "cifar100              0.841424  0.854797  0.729444            0.738183   \n",
       "missed_class_cifar10  0.911395  0.899189  0.880219            0.877316   \n",
       "noisy_cifar10         0.777694  0.784407  0.688842            0.690688   \n",
       "noisy_cifar100        0.802213  0.818593  0.638102            0.643024   \n",
       "\n",
       "                     Reverse Bregman Information  \\\n",
       "                                            mean   \n",
       "InD                                                \n",
       "cifar10                                 0.896719   \n",
       "cifar100                                0.720534   \n",
       "missed_class_cifar10                    0.881235   \n",
       "noisy_cifar10                           0.685141   \n",
       "noisy_cifar100                          0.633539   \n",
       "\n",
       "                     Expected Pairwise Bregman Information      Bias  \\\n",
       "                                                      mean      mean   \n",
       "InD                                                                    \n",
       "cifar10                                           0.898497  0.660546   \n",
       "cifar100                                          0.729614  0.510126   \n",
       "missed_class_cifar10                              0.882104  0.656732   \n",
       "noisy_cifar10                                     0.690698  0.511371   \n",
       "noisy_cifar100                                    0.637743  0.460808   \n",
       "\n",
       "                            MV      MVBI    BiasBI  \n",
       "                          mean      mean      mean  \n",
       "InD                                                 \n",
       "cifar10               0.839788  0.902130  0.761394  \n",
       "cifar100              0.676550  0.736675  0.587066  \n",
       "missed_class_cifar10  0.824091  0.884527  0.754405  \n",
       "noisy_cifar10         0.639119  0.696322  0.595381  \n",
       "noisy_cifar100        0.595676  0.649527  0.513263  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Total</th>\n",
       "      <th>Bayes</th>\n",
       "      <th>Excess</th>\n",
       "      <th>Bregman Information</th>\n",
       "      <th>Reverse Bregman Information</th>\n",
       "      <th>Expected Pairwise Bregman Information</th>\n",
       "      <th>Bias</th>\n",
       "      <th>MV</th>\n",
       "      <th>MVBI</th>\n",
       "      <th>BiasBI</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InD</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>cifar10</th>\n",
       "      <td>0.002214</td>\n",
       "      <td>0.002226</td>\n",
       "      <td>0.004400</td>\n",
       "      <td>0.004495</td>\n",
       "      <td>0.004349</td>\n",
       "      <td>0.004358</td>\n",
       "      <td>0.004504</td>\n",
       "      <td>0.002690</td>\n",
       "      <td>0.004380</td>\n",
       "      <td>0.003959</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cifar100</th>\n",
       "      <td>0.002990</td>\n",
       "      <td>0.002576</td>\n",
       "      <td>0.004303</td>\n",
       "      <td>0.004256</td>\n",
       "      <td>0.004344</td>\n",
       "      <td>0.004310</td>\n",
       "      <td>0.003178</td>\n",
       "      <td>0.004350</td>\n",
       "      <td>0.004996</td>\n",
       "      <td>0.003697</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>missed_class_cifar10</th>\n",
       "      <td>0.040312</td>\n",
       "      <td>0.060736</td>\n",
       "      <td>0.029749</td>\n",
       "      <td>0.032804</td>\n",
       "      <td>0.027790</td>\n",
       "      <td>0.028653</td>\n",
       "      <td>0.019837</td>\n",
       "      <td>0.027850</td>\n",
       "      <td>0.030412</td>\n",
       "      <td>0.030033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>noisy_cifar10</th>\n",
       "      <td>0.035157</td>\n",
       "      <td>0.033661</td>\n",
       "      <td>0.047991</td>\n",
       "      <td>0.048896</td>\n",
       "      <td>0.047608</td>\n",
       "      <td>0.047469</td>\n",
       "      <td>0.044198</td>\n",
       "      <td>0.041879</td>\n",
       "      <td>0.049477</td>\n",
       "      <td>0.048107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>noisy_cifar100</th>\n",
       "      <td>0.004808</td>\n",
       "      <td>0.004709</td>\n",
       "      <td>0.007178</td>\n",
       "      <td>0.007099</td>\n",
       "      <td>0.007276</td>\n",
       "      <td>0.007158</td>\n",
       "      <td>0.005123</td>\n",
       "      <td>0.005514</td>\n",
       "      <td>0.007216</td>\n",
       "      <td>0.005997</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                         Total     Bayes    Excess Bregman Information  \\\n",
       "                          mean      mean      mean                mean   \n",
       "InD                                                                      \n",
       "cifar10               0.002214  0.002226  0.004400            0.004495   \n",
       "cifar100              0.002990  0.002576  0.004303            0.004256   \n",
       "missed_class_cifar10  0.040312  0.060736  0.029749            0.032804   \n",
       "noisy_cifar10         0.035157  0.033661  0.047991            0.048896   \n",
       "noisy_cifar100        0.004808  0.004709  0.007178            0.007099   \n",
       "\n",
       "                     Reverse Bregman Information  \\\n",
       "                                            mean   \n",
       "InD                                                \n",
       "cifar10                                 0.004349   \n",
       "cifar100                                0.004344   \n",
       "missed_class_cifar10                    0.027790   \n",
       "noisy_cifar10                           0.047608   \n",
       "noisy_cifar100                          0.007276   \n",
       "\n",
       "                     Expected Pairwise Bregman Information      Bias  \\\n",
       "                                                      mean      mean   \n",
       "InD                                                                    \n",
       "cifar10                                           0.004358  0.004504   \n",
       "cifar100                                          0.004310  0.003178   \n",
       "missed_class_cifar10                              0.028653  0.019837   \n",
       "noisy_cifar10                                     0.047469  0.044198   \n",
       "noisy_cifar100                                    0.007158  0.005123   \n",
       "\n",
       "                            MV      MVBI    BiasBI  \n",
       "                          mean      mean      mean  \n",
       "InD                                                 \n",
       "cifar10               0.002690  0.004380  0.003959  \n",
       "cifar100              0.004350  0.004996  0.003697  \n",
       "missed_class_cifar10  0.027850  0.030412  0.030033  \n",
       "noisy_cifar10         0.041879  0.049477  0.048107  \n",
       "noisy_cifar100        0.005514  0.007216  0.005997  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(all_agg_df_mean)\n",
    "display(all_agg_df_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20dfba8c-3693-409e-8782-9fc77c42d9d9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92d1f79f-d33d-4e59-826d-2e37fdc9f377",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e125a7-c7cf-47cc-96ba-7911616493f3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "28c5f176-2f3f-4efd-99d4-b89686df17f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def enhance_latex_table(input_latex):\n",
    "    lines = input_latex.split(\"\\n\")\n",
    "    enhanced_lines = []\n",
    "\n",
    "    for i, line in enumerate(lines):\n",
    "        if \"\\\\toprule\" in line:\n",
    "            # Add multicolumn headers\n",
    "            enhanced_lines.append(line)\n",
    "            enhanced_lines.append(\n",
    "                r\"\\multicolumn{2}{c}{Dataset} & \\multicolumn{5}{c}{Metrics} \\\\\"\n",
    "            )\n",
    "            enhanced_lines.append(r\"\\cmidrule(lr){1-2} \\cmidrule(lr){3-7}\")\n",
    "            continue\n",
    "\n",
    "        # Add row coloring\n",
    "        if \"\\\\midrule\" in line:\n",
    "            enhanced_lines.append(line)\n",
    "            enhanced_lines.append(r\"\\rowcolor{gray!10}\")\n",
    "        elif \"\\\\bottomrule\" in line:\n",
    "            enhanced_lines.append(r\"\\end{tabular}\")\n",
    "        else:\n",
    "            enhanced_lines.append(line)\n",
    "\n",
    "    return \"\\n\".join(enhanced_lines)\n",
    "\n",
    "\n",
    "def get_nice_df(df_):\n",
    "    df_.index = pd.Index(\n",
    "        data=[\n",
    "            \"CIFAR10\",\n",
    "            \"CIFAR100\",\n",
    "            \"Missed class CIFAR10\",\n",
    "            \"Noisy CIFAR10\",\n",
    "            \"Noisy CIFAR100\",\n",
    "        ],\n",
    "        name=\"InD\",\n",
    "    )\n",
    "    # df_.columns = [\n",
    "    #             # 'Bayes',\n",
    "    #             # 'Excess',\n",
    "    #             # 'Total',\n",
    "    #             'Bayes(O)',\n",
    "    #             'Bayes(I)',\n",
    "    #             'Total(O)',\n",
    "    #             'Total(I)',\n",
    "    #             'BI',\n",
    "    #             'RBI',\n",
    "    #             'EPBI',\n",
    "    #             # 'Bias',\n",
    "    #             # 'MV',\n",
    "    #             # 'MVBI',\n",
    "    #             # 'BiasBI',\n",
    "    # ]\n",
    "    df_ = df_[\n",
    "        [\n",
    "            \"Bayes\",\n",
    "            \"Excess\",\n",
    "            \"Total\",\n",
    "            \"Bregman Information\",\n",
    "            \"Reverse Bregman Information\",\n",
    "            \"Expected Pairwise Bregman Information\",\n",
    "        ]\n",
    "    ]\n",
    "    df_ = (100 * df_).round(2)\n",
    "\n",
    "    display(df_)\n",
    "\n",
    "    return df_, df_.to_latex(float_format=\"%.2f\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "1d256bd7-cd07-445a-aa44-6f846260b866",
   "metadata": {},
   "outputs": [],
   "source": [
    "# measures = [\n",
    "#     'Bayes Outer',\n",
    "#     'Bayes Inner',\n",
    "#     'Total Outer',\n",
    "#     'Total Inner',\n",
    "#     'Bregman Information',\n",
    "#     'Reverse Bregman Information',\n",
    "#     'Expected Pairwise Bregman Information']\n",
    "\n",
    "measures = [\n",
    "    \"Bayes\",\n",
    "    \"Excess\",\n",
    "    \"Total\",\n",
    "    \"Bregman Information\",\n",
    "    \"Reverse Bregman Information\",\n",
    "    \"Expected Pairwise Bregman Information\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54eb2fb8-838f-465d-ba8d-0998fc7689b5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "2fb1b47c-3875-4dcd-bb71-2bdbd5b0bf39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Bayes</th>\n",
       "      <th>Excess</th>\n",
       "      <th>Total</th>\n",
       "      <th>Bregman Information</th>\n",
       "      <th>Reverse Bregman Information</th>\n",
       "      <th>Expected Pairwise Bregman Information</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InD</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>CIFAR10</th>\n",
       "      <td>93.63</td>\n",
       "      <td>93.22</td>\n",
       "      <td>93.88</td>\n",
       "      <td>93.29</td>\n",
       "      <td>93.14</td>\n",
       "      <td>93.22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CIFAR100</th>\n",
       "      <td>86.32</td>\n",
       "      <td>81.71</td>\n",
       "      <td>86.69</td>\n",
       "      <td>82.45</td>\n",
       "      <td>81.02</td>\n",
       "      <td>81.68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Missed class CIFAR10</th>\n",
       "      <td>90.15</td>\n",
       "      <td>91.09</td>\n",
       "      <td>91.31</td>\n",
       "      <td>91.03</td>\n",
       "      <td>91.09</td>\n",
       "      <td>91.15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR10</th>\n",
       "      <td>79.60</td>\n",
       "      <td>73.36</td>\n",
       "      <td>79.24</td>\n",
       "      <td>73.76</td>\n",
       "      <td>73.02</td>\n",
       "      <td>73.30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR100</th>\n",
       "      <td>82.81</td>\n",
       "      <td>72.35</td>\n",
       "      <td>82.67</td>\n",
       "      <td>73.10</td>\n",
       "      <td>71.66</td>\n",
       "      <td>72.29</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      Bayes Excess  Total Bregman Information  \\\n",
       "                       mean   mean   mean                mean   \n",
       "InD                                                             \n",
       "CIFAR10               93.63  93.22  93.88               93.29   \n",
       "CIFAR100              86.32  81.71  86.69               82.45   \n",
       "Missed class CIFAR10  90.15  91.09  91.31               91.03   \n",
       "Noisy CIFAR10         79.60  73.36  79.24               73.76   \n",
       "Noisy CIFAR100        82.81  72.35  82.67               73.10   \n",
       "\n",
       "                     Reverse Bregman Information  \\\n",
       "                                            mean   \n",
       "InD                                                \n",
       "CIFAR10                                    93.14   \n",
       "CIFAR100                                   81.02   \n",
       "Missed class CIFAR10                       91.09   \n",
       "Noisy CIFAR10                              73.02   \n",
       "Noisy CIFAR100                             71.66   \n",
       "\n",
       "                     Expected Pairwise Bregman Information  \n",
       "                                                      mean  \n",
       "InD                                                         \n",
       "CIFAR10                                              93.22  \n",
       "CIFAR100                                             81.68  \n",
       "Missed class CIFAR10                                 91.15  \n",
       "Noisy CIFAR10                                        73.30  \n",
       "Noisy CIFAR100                                       72.29  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      "\\multicolumn{2}{c}{Dataset} & \\multicolumn{5}{c}{Metrics} \\\\\n",
      "\\cmidrule(lr){1-2} \\cmidrule(lr){3-7}\n",
      " & Bayes & Excess & Total & Bregman Information & Reverse Bregman Information & Expected Pairwise Bregman Information \\\\\n",
      " & mean & mean & mean & mean & mean & mean \\\\\n",
      "InD &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "\\rowcolor{gray!10}\n",
      "CIFAR10 & 93.63 & 93.22 & 93.88 & 93.29 & 93.14 & 93.22 \\\\\n",
      "CIFAR100 & 86.32 & 81.71 & 86.69 & 82.45 & 81.02 & 81.68 \\\\\n",
      "Missed class CIFAR10 & 90.15 & 91.09 & 91.31 & 91.03 & 91.09 & 91.15 \\\\\n",
      "Noisy CIFAR10 & 79.60 & 73.36 & 79.24 & 73.76 & 73.02 & 73.30 \\\\\n",
      "Noisy CIFAR100 & 82.81 & 72.35 & 82.67 & 73.10 & 71.66 & 72.29 \\\\\n",
      "\\end{tabular}\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "nice_same = get_nice_df(same_agg_df_mean.copy()[measures])\n",
    "print(enhance_latex_table(nice_same[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "a09e3a56-c7db-49f1-9c02-c82d56c20e86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Bayes</th>\n",
       "      <th>Excess</th>\n",
       "      <th>Total</th>\n",
       "      <th>Bregman Information</th>\n",
       "      <th>Reverse Bregman Information</th>\n",
       "      <th>Expected Pairwise Bregman Information</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InD</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>CIFAR10</th>\n",
       "      <td>0.22</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.21</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CIFAR100</th>\n",
       "      <td>0.22</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.21</td>\n",
       "      <td>0.31</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Missed class CIFAR10</th>\n",
       "      <td>6.00</td>\n",
       "      <td>3.73</td>\n",
       "      <td>4.17</td>\n",
       "      <td>3.92</td>\n",
       "      <td>3.64</td>\n",
       "      <td>3.63</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR10</th>\n",
       "      <td>3.42</td>\n",
       "      <td>4.36</td>\n",
       "      <td>3.64</td>\n",
       "      <td>4.39</td>\n",
       "      <td>4.33</td>\n",
       "      <td>4.36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR100</th>\n",
       "      <td>0.46</td>\n",
       "      <td>0.73</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.74</td>\n",
       "      <td>0.73</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     Bayes Excess Total Bregman Information  \\\n",
       "                      mean   mean  mean                mean   \n",
       "InD                                                           \n",
       "CIFAR10               0.22   0.25  0.21                0.24   \n",
       "CIFAR100              0.22   0.32  0.21                0.31   \n",
       "Missed class CIFAR10  6.00   3.73  4.17                3.92   \n",
       "Noisy CIFAR10         3.42   4.36  3.64                4.39   \n",
       "Noisy CIFAR100        0.46   0.73  0.48                0.72   \n",
       "\n",
       "                     Reverse Bregman Information  \\\n",
       "                                            mean   \n",
       "InD                                                \n",
       "CIFAR10                                     0.25   \n",
       "CIFAR100                                    0.32   \n",
       "Missed class CIFAR10                        3.64   \n",
       "Noisy CIFAR10                               4.33   \n",
       "Noisy CIFAR100                              0.74   \n",
       "\n",
       "                     Expected Pairwise Bregman Information  \n",
       "                                                      mean  \n",
       "InD                                                         \n",
       "CIFAR10                                               0.25  \n",
       "CIFAR100                                              0.32  \n",
       "Missed class CIFAR10                                  3.63  \n",
       "Noisy CIFAR10                                         4.36  \n",
       "Noisy CIFAR100                                        0.73  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      "\\multicolumn{2}{c}{Dataset} & \\multicolumn{5}{c}{Metrics} \\\\\n",
      "\\cmidrule(lr){1-2} \\cmidrule(lr){3-7}\n",
      " & Bayes & Excess & Total & Bregman Information & Reverse Bregman Information & Expected Pairwise Bregman Information \\\\\n",
      " & mean & mean & mean & mean & mean & mean \\\\\n",
      "InD &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "\\rowcolor{gray!10}\n",
      "CIFAR10 & 0.22 & 0.25 & 0.21 & 0.24 & 0.25 & 0.25 \\\\\n",
      "CIFAR100 & 0.22 & 0.32 & 0.21 & 0.31 & 0.32 & 0.32 \\\\\n",
      "Missed class CIFAR10 & 6.00 & 3.73 & 4.17 & 3.92 & 3.64 & 3.63 \\\\\n",
      "Noisy CIFAR10 & 3.42 & 4.36 & 3.64 & 4.39 & 4.33 & 4.36 \\\\\n",
      "Noisy CIFAR100 & 0.46 & 0.73 & 0.48 & 0.72 & 0.74 & 0.73 \\\\\n",
      "\\end{tabular}\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "nice_same = get_nice_df(same_agg_df_std[measures].copy())\n",
    "enhanced_latex = enhance_latex_table(nice_same[1])\n",
    "print(enhanced_latex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4be7d78e-9397-429a-bd46-37c6bc1135f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39bc6f41-b1b2-4f29-9eab-8e7eebb7559a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "e01352ea-123f-49f2-a5ca-bc2096a03ca2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Bayes</th>\n",
       "      <th>Excess</th>\n",
       "      <th>Total</th>\n",
       "      <th>Bregman Information</th>\n",
       "      <th>Reverse Bregman Information</th>\n",
       "      <th>Expected Pairwise Bregman Information</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InD</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>CIFAR10</th>\n",
       "      <td>93.42</td>\n",
       "      <td>88.88</td>\n",
       "      <td>93.52</td>\n",
       "      <td>88.82</td>\n",
       "      <td>88.80</td>\n",
       "      <td>89.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CIFAR100</th>\n",
       "      <td>85.27</td>\n",
       "      <td>70.75</td>\n",
       "      <td>83.51</td>\n",
       "      <td>71.66</td>\n",
       "      <td>69.81</td>\n",
       "      <td>70.78</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Missed class CIFAR10</th>\n",
       "      <td>89.86</td>\n",
       "      <td>87.26</td>\n",
       "      <td>91.10</td>\n",
       "      <td>86.91</td>\n",
       "      <td>87.38</td>\n",
       "      <td>87.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR10</th>\n",
       "      <td>78.15</td>\n",
       "      <td>67.77</td>\n",
       "      <td>77.40</td>\n",
       "      <td>67.90</td>\n",
       "      <td>67.39</td>\n",
       "      <td>68.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR100</th>\n",
       "      <td>81.62</td>\n",
       "      <td>61.68</td>\n",
       "      <td>79.61</td>\n",
       "      <td>62.10</td>\n",
       "      <td>61.28</td>\n",
       "      <td>61.65</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      Bayes Excess  Total Bregman Information  \\\n",
       "                       mean   mean   mean                mean   \n",
       "InD                                                             \n",
       "CIFAR10               93.42  88.88  93.52               88.82   \n",
       "CIFAR100              85.27  70.75  83.51               71.66   \n",
       "Missed class CIFAR10  89.86  87.26  91.10               86.91   \n",
       "Noisy CIFAR10         78.15  67.77  77.40               67.90   \n",
       "Noisy CIFAR100        81.62  61.68  79.61               62.10   \n",
       "\n",
       "                     Reverse Bregman Information  \\\n",
       "                                            mean   \n",
       "InD                                                \n",
       "CIFAR10                                    88.80   \n",
       "CIFAR100                                   69.81   \n",
       "Missed class CIFAR10                       87.38   \n",
       "Noisy CIFAR10                              67.39   \n",
       "Noisy CIFAR100                             61.28   \n",
       "\n",
       "                     Expected Pairwise Bregman Information  \n",
       "                                                      mean  \n",
       "InD                                                         \n",
       "CIFAR10                                              89.01  \n",
       "CIFAR100                                             70.78  \n",
       "Missed class CIFAR10                                 87.48  \n",
       "Noisy CIFAR10                                        68.01  \n",
       "Noisy CIFAR100                                       61.65  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      "\\multicolumn{2}{c}{Dataset} & \\multicolumn{5}{c}{Metrics} \\\\\n",
      "\\cmidrule(lr){1-2} \\cmidrule(lr){3-7}\n",
      " & Bayes & Excess & Total & Bregman Information & Reverse Bregman Information & Expected Pairwise Bregman Information \\\\\n",
      " & mean & mean & mean & mean & mean & mean \\\\\n",
      "InD &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "\\rowcolor{gray!10}\n",
      "CIFAR10 & 93.42 & 88.88 & 93.52 & 88.82 & 88.80 & 89.01 \\\\\n",
      "CIFAR100 & 85.27 & 70.75 & 83.51 & 71.66 & 69.81 & 70.78 \\\\\n",
      "Missed class CIFAR10 & 89.86 & 87.26 & 91.10 & 86.91 & 87.38 & 87.48 \\\\\n",
      "Noisy CIFAR10 & 78.15 & 67.77 & 77.40 & 67.90 & 67.39 & 68.01 \\\\\n",
      "Noisy CIFAR100 & 81.62 & 61.68 & 79.61 & 62.10 & 61.28 & 61.65 \\\\\n",
      "\\end{tabular}\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Build and print the LaTeX table for the `measures` columns of `different_agg_df_mean`.\n",
    "# `get_nice_df` and `enhance_latex_table` are defined outside this cell; element [1] of\n",
    "# the `get_nice_df` result is what gets forwarded as the LaTeX source\n",
    "# (presumably a (frame, latex) pair, and the rendered frame above comes from it -- TODO confirm).\n",
    "# The frame is copied before the call, so the helper receives its own copy.\n",
    "nice_same = get_nice_df(different_agg_df_mean[measures].copy())\n",
    "# NOTE(review): the recorded output of this cell ends with two closing tabular lines, and\n",
    "# its header spans 2 + 5 columns while the data rows have one label column followed by\n",
    "# six values -- check `enhance_latex_table` before pasting the table into a document.\n",
    "enhanced_latex = enhance_latex_table(nice_same[1])\n",
    "print(enhanced_latex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "97fb3e7d-7fc8-4fe2-b791-47fc8440df81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>Bayes</th>\n",
       "      <th>Excess</th>\n",
       "      <th>Total</th>\n",
       "      <th>Bregman Information</th>\n",
       "      <th>Reverse Bregman Information</th>\n",
       "      <th>Expected Pairwise Bregman Information</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "      <th>mean</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InD</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>CIFAR10</th>\n",
       "      <td>0.22</td>\n",
       "      <td>0.49</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.50</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CIFAR100</th>\n",
       "      <td>0.27</td>\n",
       "      <td>0.46</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.45</td>\n",
       "      <td>0.46</td>\n",
       "      <td>0.46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Missed class CIFAR10</th>\n",
       "      <td>6.09</td>\n",
       "      <td>2.79</td>\n",
       "      <td>4.00</td>\n",
       "      <td>3.12</td>\n",
       "      <td>2.56</td>\n",
       "      <td>2.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR10</th>\n",
       "      <td>3.35</td>\n",
       "      <td>4.91</td>\n",
       "      <td>3.48</td>\n",
       "      <td>5.01</td>\n",
       "      <td>4.87</td>\n",
       "      <td>4.84</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Noisy CIFAR100</th>\n",
       "      <td>0.47</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.71</td>\n",
       "      <td>0.73</td>\n",
       "      <td>0.71</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     Bayes Excess Total Bregman Information  \\\n",
       "                      mean   mean  mean                mean   \n",
       "InD                                                           \n",
       "CIFAR10               0.22   0.49  0.22                0.50   \n",
       "CIFAR100              0.27   0.46  0.32                0.45   \n",
       "Missed class CIFAR10  6.09   2.79  4.00                3.12   \n",
       "Noisy CIFAR10         3.35   4.91  3.48                5.01   \n",
       "Noisy CIFAR100        0.47   0.72  0.48                0.71   \n",
       "\n",
       "                     Reverse Bregman Information  \\\n",
       "                                            mean   \n",
       "InD                                                \n",
       "CIFAR10                                     0.48   \n",
       "CIFAR100                                    0.46   \n",
       "Missed class CIFAR10                        2.56   \n",
       "Noisy CIFAR10                               4.87   \n",
       "Noisy CIFAR100                              0.73   \n",
       "\n",
       "                     Expected Pairwise Bregman Information  \n",
       "                                                      mean  \n",
       "InD                                                         \n",
       "CIFAR10                                               0.48  \n",
       "CIFAR100                                              0.46  \n",
       "Missed class CIFAR10                                  2.67  \n",
       "Noisy CIFAR10                                         4.84  \n",
       "Noisy CIFAR100                                        0.71  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      "\\multicolumn{2}{c}{Dataset} & \\multicolumn{5}{c}{Metrics} \\\\\n",
      "\\cmidrule(lr){1-2} \\cmidrule(lr){3-7}\n",
      " & Bayes & Excess & Total & Bregman Information & Reverse Bregman Information & Expected Pairwise Bregman Information \\\\\n",
      " & mean & mean & mean & mean & mean & mean \\\\\n",
      "InD &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "\\rowcolor{gray!10}\n",
      "CIFAR10 & 0.22 & 0.49 & 0.22 & 0.50 & 0.48 & 0.48 \\\\\n",
      "CIFAR100 & 0.27 & 0.46 & 0.32 & 0.45 & 0.46 & 0.46 \\\\\n",
      "Missed class CIFAR10 & 6.09 & 2.79 & 4.00 & 3.12 & 2.56 & 2.67 \\\\\n",
      "Noisy CIFAR10 & 3.35 & 4.91 & 3.48 & 5.01 & 4.87 & 4.84 \\\\\n",
      "Noisy CIFAR100 & 0.47 & 0.72 & 0.48 & 0.71 & 0.73 & 0.71 \\\\\n",
      "\\end{tabular}\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Same pipeline as for the mean table, applied to the `measures` columns of\n",
    "# `different_agg_df_std`: format with `get_nice_df`, post-process element [1] of the\n",
    "# result with `enhance_latex_table`, then print the LaTeX text.\n",
    "# NOTE(review): this rebinds `nice_same` and `enhanced_latex`, which the previous cell\n",
    "# also assigns, so after this cell they hold the std table only.\n",
    "nice_same = get_nice_df(different_agg_df_std[measures].copy())\n",
    "# NOTE(review): in the recorded output the second header row still reads `mean` although\n",
    "# the input frame is the `_std` one, and the closing tabular line is printed twice --\n",
    "# presumably both come from the helpers / column labels; verify before publishing.\n",
    "enhanced_latex = enhance_latex_table(nice_same[1])\n",
    "print(enhanced_latex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f515a14d-89be-4733-a4cc-5a86c372fa69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['Brier', 'Logscore', 'Neglog', 'Maxprob', 'Spherical'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: list the distinct `base_rule` values present in `full_mis_rocauc`\n",
    "# (object defined outside this cell; presumably a DataFrame with a `base_rule` column).\n",
    "full_mis_rocauc.base_rule.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089a3497-9589-49c6-9cd2-b57ecbf26299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5611af5a-4b1a-4e1b-8961-1fd2f6951a16",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e93dd55b-1f65-4869-8bc4-9f5c60ae78d3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cca3eb07-96ef-4b76-ae81-098305b40369",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
