{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4bc24587-c7d7-4004-b06c-36c7ad98c520",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "stty: 'standard input': Inappropriate ioctl for device\n",
      "/home/nkotelevskii/github/uncertainty_from_proper_scoring_rules/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from source.source.postprocessing_utils import (\n",
    "    get_predicted_labels,\n",
    "    get_missclassification_dataframe,\n",
    "    get_ood_detection_dataframe,\n",
    "    get_raw_scores_dataframe,\n",
    ")\n",
    "\n",
    "from source.datasets.constants import DatasetName\n",
    "from source.losses.constants import LossName\n",
    "from source.models.constants import ModelName\n",
    "from source.metrics import (\n",
    "    ApproximationType,\n",
    "    GName,\n",
    "    RiskType,\n",
    ")\n",
    "from torch_uncertainty_models.source.notebook_utils import (\n",
    "    get_new_models_sampled_combinations_uncertainty_scores,\n",
    ")\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Never truncate rows when displaying a DataFrame.\n",
    "pd.options.display.max_rows = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "89e7c74b-fc80-46a5-9ba8-154e4b84d8e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temperature forwarded to the uncertainty-score computation.\n",
    "temperature = 1.0\n",
    "# Model ids 0..n_models-1, passed as model_ids to the score computation.\n",
    "n_models = 20\n",
    "model_ids = np.arange(n_models)\n",
    "\n",
    "loss_function_names = [LossName.CROSS_ENTROPY]\n",
    "training_dataset_names = [\n",
    "    DatasetName.CIFAR10.value,\n",
    "    DatasetName.CIFAR100.value,\n",
    "]\n",
    "\n",
    "# Accumulators for the combined tables; filled by the dataset loop.\n",
    "full_dataframe = full_ood_rocauc_dataframe = full_mis_rocauc_dataframe = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68c6ed01-7eb0-43e8-8e08-803d36b8e98d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a926f6f-6e0f-42af-97ed-27abc5941f05",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d4d6324-54b7-4d3a-adab-e6053bfadbaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 110/110 [09:59<00:00,  5.45s/it]\n"
     ]
    }
   ],
   "source": [
    "# Evaluation datasets scored for each training dataset.\n",
    "# CIFAR-100-C is currently disabled for CIFAR-100 training.\n",
    "extraction_datasets_by_training_dataset = {\n",
    "    DatasetName.CIFAR10.value: [\n",
    "        DatasetName.CIFAR10.value,\n",
    "        DatasetName.CIFAR100.value,\n",
    "        DatasetName.CIFAR10C.value,\n",
    "        DatasetName.TINY_IMAGENET.value,\n",
    "    ],\n",
    "    DatasetName.CIFAR100.value: [\n",
    "        DatasetName.CIFAR10.value,\n",
    "        DatasetName.CIFAR100.value,\n",
    "        # DatasetName.CIFAR100C.value,\n",
    "        DatasetName.TINY_IMAGENET.value,\n",
    "    ],\n",
    "}\n",
    "\n",
    "# Corrupted datasets are scored per severity level: <name>_1 ... <name>_5.\n",
    "corrupted_dataset_names = [DatasetName.CIFAR10C.value, DatasetName.CIFAR100C.value]\n",
    "n_severity_levels = 5\n",
    "architecture = ModelName.RESNET18\n",
    "\n",
    "# Collect the per-run tables and concatenate once after the loop, so re-running\n",
    "# this cell rebuilds the results instead of appending to a previous run.\n",
    "raw_scores_frames = []\n",
    "ood_frames = []\n",
    "misclassification_frames = []\n",
    "\n",
    "for training_dataset_name in training_dataset_names:\n",
    "    if training_dataset_name not in extraction_datasets_by_training_dataset:\n",
    "        raise NotImplementedError(\"Need to implement\")\n",
    "    # Copy so the severity expansion below never mutates the mapping above.\n",
    "    list_extraction_datasets = list(\n",
    "        extraction_datasets_by_training_dataset[training_dataset_name]\n",
    "    )\n",
    "    for corrupted_name in corrupted_dataset_names:\n",
    "        if corrupted_name in list_extraction_datasets:\n",
    "            list_extraction_datasets.remove(corrupted_name)\n",
    "            list_extraction_datasets.extend(\n",
    "                f\"{corrupted_name}_{severity}\"\n",
    "                for severity in range(1, n_severity_levels + 1)\n",
    "            )\n",
    "\n",
    "    # noisy_<name> runs are evaluated against the clean <name> data.\n",
    "    if training_dataset_name in [\"noisy_cifar10\", \"noisy_cifar100\"]:\n",
    "        training_dataset_name_aux = training_dataset_name.split(\"_\")[-1]\n",
    "    else:\n",
    "        training_dataset_name_aux = training_dataset_name\n",
    "\n",
    "    uq_results, embeddings_per_dataset, targets_per_dataset = (\n",
    "        get_new_models_sampled_combinations_uncertainty_scores(\n",
    "            loss_function_names=loss_function_names,\n",
    "            training_dataset_name=training_dataset_name,\n",
    "            model_ids=model_ids,\n",
    "            list_extraction_datasets=list_extraction_datasets,\n",
    "            temperature=temperature,\n",
    "            use_cached=True,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    df_ood = get_ood_detection_dataframe(\n",
    "        ind_dataset=training_dataset_name_aux,\n",
    "        uq_results=uq_results,\n",
    "        list_ood_datasets=list_extraction_datasets,\n",
    "    )\n",
    "\n",
    "    # The target array is split evenly across model_ids; keep the first chunk.\n",
    "    # NOTE(review): assumes the labels are stacked once per model id - confirm.\n",
    "    n_targets = targets_per_dataset[training_dataset_name_aux].shape[0]\n",
    "    assert n_targets % len(model_ids) == 0, (\n",
    "        f\"{n_targets} targets cannot be split evenly across {len(model_ids)} models\"\n",
    "    )\n",
    "    max_ind = n_targets // len(model_ids)\n",
    "    true_labels = targets_per_dataset[training_dataset_name_aux][:max_ind]\n",
    "\n",
    "    pred_labels = get_predicted_labels(\n",
    "        embeddings_per_dataset=embeddings_per_dataset,\n",
    "        training_dataset_name=training_dataset_name_aux,\n",
    "    )\n",
    "\n",
    "    df_misclassification = get_missclassification_dataframe(\n",
    "        ind_dataset=training_dataset_name_aux,\n",
    "        uq_results=uq_results,\n",
    "        true_labels=true_labels,\n",
    "        pred_labels=pred_labels,\n",
    "    )\n",
    "\n",
    "    scores_df_unravel = get_raw_scores_dataframe(uq_results=uq_results)\n",
    "\n",
    "    # Tag every table with the run it came from before stacking.\n",
    "    run_tags = {\n",
    "        \"architecture\": architecture.value,\n",
    "        \"training_dataset\": training_dataset_name,\n",
    "    }\n",
    "    raw_scores_frames.append(scores_df_unravel.assign(**run_tags))\n",
    "    ood_frames.append(df_ood.assign(**run_tags))\n",
    "    misclassification_frames.append(df_misclassification.assign(**run_tags))\n",
    "\n",
    "# ignore_index=True gives each combined table a unique RangeIndex instead of\n",
    "# repeating the per-run row labels.\n",
    "full_dataframe = pd.concat(raw_scores_frames, ignore_index=True)\n",
    "full_ood_rocauc_dataframe = pd.concat(ood_frames, ignore_index=True)\n",
    "full_mis_rocauc_dataframe = pd.concat(misclassification_frames, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7f328db2-0dc2-449d-8c2f-7aad705a75e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['cifar10', 'cifar100', 'tiny_imagenet', 'cifar10c_1', 'cifar10c_2', 'cifar10c_3', 'cifar10c_4', 'cifar10c_5'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluation datasets scored for the last training dataset processed by the loop.\n",
    "energy_inner_scores = uq_results[\"LogScore energy inner\"][\"CrossEntropy\"]\n",
    "energy_inner_scores.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1d04671-005d-4b65-9a4c-fc4a67264439",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d215ede3-5e9e-4c08-9496-dcd227cdc232",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>Dataset</th>\n",
       "      <th>Scores</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>232</th>\n",
       "      <td>BrierScore TotalRisk central central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>[[0.40422922, 0.75544775, 0.75003016, 0.775929...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551</th>\n",
       "      <td>SphericalScore TotalRisk inner central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_5</td>\n",
       "      <td>[[0.5503747513506348, 0.5926362846990155, 0.50...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>303</th>\n",
       "      <td>BrierScore ExcessRisk central inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_5</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>290</th>\n",
       "      <td>BrierScore ExcessRisk central outer</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>[[0.40737048, 0.44480655, 0.4267369, 0.5548498...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>303</th>\n",
       "      <td>BrierScore ExcessRisk central inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_5</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>230</th>\n",
       "      <td>BrierScore TotalRisk central inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_4</td>\n",
       "      <td>[[0.63418686, 0.7032609, 0.69550395, 0.6251498...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>227</th>\n",
       "      <td>BrierScore TotalRisk central inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_1</td>\n",
       "      <td>[[0.6454982, 0.64261407, 0.6028866, 0.7708387,...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>519</th>\n",
       "      <td>SphericalScore TotalRisk outer inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_5</td>\n",
       "      <td>[[0.4969115, 0.55103815, 0.45745933, 0.4439876...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>381</th>\n",
       "      <td>ZeroOneScore TotalRisk inner central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar10c_3</td>\n",
       "      <td>[[0.50417453, 0.50004387, 0.43402344, 0.575773...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>585</th>\n",
       "      <td>SphericalScore ExcessRisk outer inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>[[0.23793499, 0.37648946, 0.22644593, 0.339729...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   UQMetric  LossFunction        Dataset  \\\n",
       "232    BrierScore TotalRisk central central  CrossEntropy        cifar10   \n",
       "551  SphericalScore TotalRisk inner central  CrossEntropy     cifar10c_5   \n",
       "303     BrierScore ExcessRisk central inner  CrossEntropy     cifar10c_5   \n",
       "290     BrierScore ExcessRisk central outer  CrossEntropy  tiny_imagenet   \n",
       "303     BrierScore ExcessRisk central inner  CrossEntropy     cifar10c_5   \n",
       "230      BrierScore TotalRisk central inner  CrossEntropy     cifar10c_4   \n",
       "227      BrierScore TotalRisk central inner  CrossEntropy     cifar10c_1   \n",
       "519    SphericalScore TotalRisk outer inner  CrossEntropy     cifar10c_5   \n",
       "381    ZeroOneScore TotalRisk inner central  CrossEntropy     cifar10c_3   \n",
       "585   SphericalScore ExcessRisk outer inner  CrossEntropy       cifar100   \n",
       "\n",
       "                                                Scores architecture  \\\n",
       "232  [[0.40422922, 0.75544775, 0.75003016, 0.775929...     resnet18   \n",
       "551  [[0.5503747513506348, 0.5926362846990155, 0.50...     resnet18   \n",
       "303  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...     resnet18   \n",
       "290  [[0.40737048, 0.44480655, 0.4267369, 0.5548498...     resnet18   \n",
       "303  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...     resnet18   \n",
       "230  [[0.63418686, 0.7032609, 0.69550395, 0.6251498...     resnet18   \n",
       "227  [[0.6454982, 0.64261407, 0.6028866, 0.7708387,...     resnet18   \n",
       "519  [[0.4969115, 0.55103815, 0.45745933, 0.4439876...     resnet18   \n",
       "381  [[0.50417453, 0.50004387, 0.43402344, 0.575773...     resnet18   \n",
       "585  [[0.23793499, 0.37648946, 0.22644593, 0.339729...     resnet18   \n",
       "\n",
       "    training_dataset  \n",
       "232         cifar100  \n",
       "551          cifar10  \n",
       "303         cifar100  \n",
       "290         cifar100  \n",
       "303          cifar10  \n",
       "230         cifar100  \n",
       "227          cifar10  \n",
       "519         cifar100  \n",
       "381         cifar100  \n",
       "585         cifar100  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed random_state keeps this preview identical across re-runs.\n",
    "full_dataframe.sample(10, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b5f3ecb6-98c5-4f4e-b740-28a8ab29f9ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>Dataset</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>RocAucScores_array</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>100</th>\n",
       "      <td>BrierScore ExcessRisk inner outer</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.49999999999999994, 0.49999999999999994, 0.4...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>183</th>\n",
       "      <td>ZeroOneScore BayesRisk inner</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.53501334, 0.524730345, 0.525307985, 0.52766...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>401</th>\n",
       "      <td>ZeroOneScore TotalRisk central central</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.46498666, 0.47526965499999996, 0.4746920149...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>680</th>\n",
       "      <td>LogScore energy inner</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5, 0.5]</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>180</th>\n",
       "      <td>ZeroOneScore BayesRisk outer</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.067820305, 0.06909691999999999, 0.066266789...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>301</th>\n",
       "      <td>BrierScore ExcessRisk central inner</td>\n",
       "      <td>cifar10c_3</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5, 0.5]</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>206</th>\n",
       "      <td>SphericalScore TotalRisk inner central</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.51673053, 0.5150643, 0.5025341400000001, 0....</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>172</th>\n",
       "      <td>ZeroOneScore ExcessRisk central outer</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5000000000000001, 0.5]</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>LogScore TotalRisk inner inner</td>\n",
       "      <td>cifar10</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5, 0.4999999999999999]</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>BrierScore ExcessRisk central inner</td>\n",
       "      <td>cifar100</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5, 0.5]</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   UQMetric        Dataset  LossFunction  \\\n",
       "100       BrierScore ExcessRisk inner outer       cifar100  CrossEntropy   \n",
       "183            ZeroOneScore BayesRisk inner        cifar10  CrossEntropy   \n",
       "401  ZeroOneScore TotalRisk central central       cifar100  CrossEntropy   \n",
       "680                   LogScore energy inner        cifar10  CrossEntropy   \n",
       "180            ZeroOneScore BayesRisk outer        cifar10  CrossEntropy   \n",
       "301     BrierScore ExcessRisk central inner     cifar10c_3  CrossEntropy   \n",
       "206  SphericalScore TotalRisk inner central  tiny_imagenet  CrossEntropy   \n",
       "172   ZeroOneScore ExcessRisk central outer       cifar100  CrossEntropy   \n",
       "32           LogScore TotalRisk inner inner        cifar10  CrossEntropy   \n",
       "112     BrierScore ExcessRisk central inner       cifar100  CrossEntropy   \n",
       "\n",
       "                                    RocAucScores_array architecture  \\\n",
       "100  [0.49999999999999994, 0.49999999999999994, 0.4...     resnet18   \n",
       "183  [0.53501334, 0.524730345, 0.525307985, 0.52766...     resnet18   \n",
       "401  [0.46498666, 0.47526965499999996, 0.4746920149...     resnet18   \n",
       "680                          [0.5, 0.5, 0.5, 0.5, 0.5]     resnet18   \n",
       "180  [0.067820305, 0.06909691999999999, 0.066266789...     resnet18   \n",
       "301                          [0.5, 0.5, 0.5, 0.5, 0.5]     resnet18   \n",
       "206  [0.51673053, 0.5150643, 0.5025341400000001, 0....     resnet18   \n",
       "172           [0.5, 0.5, 0.5, 0.5000000000000001, 0.5]     resnet18   \n",
       "32            [0.5, 0.5, 0.5, 0.5, 0.4999999999999999]     resnet18   \n",
       "112                          [0.5, 0.5, 0.5, 0.5, 0.5]     resnet18   \n",
       "\n",
       "    training_dataset  \n",
       "100         cifar100  \n",
       "183         cifar100  \n",
       "401          cifar10  \n",
       "680          cifar10  \n",
       "180         cifar100  \n",
       "301          cifar10  \n",
       "206         cifar100  \n",
       "172         cifar100  \n",
       "32           cifar10  \n",
       "112         cifar100  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed random_state keeps this preview identical across re-runs.\n",
    "full_ood_rocauc_dataframe.sample(10, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4152d6c0-1463-486e-b3e7-d82570678ed9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>RocAucScores_array</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>BrierScore BayesRisk central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.4839095403184386, 0.42035389201542195, 0.51...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83</th>\n",
       "      <td>SphericalScore BayesRisk central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.49742277496514775, 0.44295383527143833, 0.5...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>SphericalScore ExcessRisk inner inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5, 0.5, 0.5, 0.5, 0.5]</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>ZeroOneScore TotalRisk inner inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5094209727419473, 0.41703767256389535, 0.50...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>ZeroOneScore TotalRisk outer outer</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.6407233994702266, 0.5109178346306207, 0.491...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>BrierScore TotalRisk inner central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.6431605422166953, 0.5070854470725878, 0.488...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>82</th>\n",
       "      <td>SphericalScore BayesRisk inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.6431591969205616, 0.5070844336085581, 0.488...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>84</th>\n",
       "      <td>LogScore energy outer</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5128014963451294, 0.510691458770523, 0.5072...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>LogScore TotalRisk outer central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.4549673490351457, 0.45748184969840183, 0.50...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>BrierScore BayesRisk outer</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>[0.5130052980324211, 0.5084377281294196, 0.500...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>cifar10</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 UQMetric  LossFunction  \\\n",
       "41           BrierScore BayesRisk central  CrossEntropy   \n",
       "83       SphericalScore BayesRisk central  CrossEntropy   \n",
       "76  SphericalScore ExcessRisk inner inner  CrossEntropy   \n",
       "46     ZeroOneScore TotalRisk inner inner  CrossEntropy   \n",
       "42     ZeroOneScore TotalRisk outer outer  CrossEntropy   \n",
       "26     BrierScore TotalRisk inner central  CrossEntropy   \n",
       "82         SphericalScore BayesRisk inner  CrossEntropy   \n",
       "84                  LogScore energy outer  CrossEntropy   \n",
       "2        LogScore TotalRisk outer central  CrossEntropy   \n",
       "39             BrierScore BayesRisk outer  CrossEntropy   \n",
       "\n",
       "                                   RocAucScores_array architecture  \\\n",
       "41  [0.4839095403184386, 0.42035389201542195, 0.51...     resnet18   \n",
       "83  [0.49742277496514775, 0.44295383527143833, 0.5...     resnet18   \n",
       "76                          [0.5, 0.5, 0.5, 0.5, 0.5]     resnet18   \n",
       "46  [0.5094209727419473, 0.41703767256389535, 0.50...     resnet18   \n",
       "42  [0.6407233994702266, 0.5109178346306207, 0.491...     resnet18   \n",
       "26  [0.6431605422166953, 0.5070854470725878, 0.488...     resnet18   \n",
       "82  [0.6431591969205616, 0.5070844336085581, 0.488...     resnet18   \n",
       "84  [0.5128014963451294, 0.510691458770523, 0.5072...     resnet18   \n",
       "2   [0.4549673490351457, 0.45748184969840183, 0.50...     resnet18   \n",
       "39  [0.5130052980324211, 0.5084377281294196, 0.500...     resnet18   \n",
       "\n",
       "   training_dataset  \n",
       "41         cifar100  \n",
       "83         cifar100  \n",
       "76          cifar10  \n",
       "46         cifar100  \n",
       "42          cifar10  \n",
       "26          cifar10  \n",
       "82          cifar10  \n",
       "84          cifar10  \n",
       "2          cifar100  \n",
       "39          cifar10  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed random_state keeps this preview identical across re-runs.\n",
    "full_mis_rocauc_dataframe.sample(10, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "78bfc94f-510f-45c5-b9fb-2dafafdb677d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pattern_baserule = r\"(LogScore|BrierScore|ZeroOneScore|SphericalScore)\"\n",
    "pattern_risk = r\"(outer outer|outer inner|outer central|inner outer|inner inner|inner central|central outer|central inner|central central|energy inner|energy outer|outer|inner|central)\"\n",
    "\n",
    "# Split each UQMetric name into its base scoring rule and its approximation suffix.\n",
    "# Two-word suffixes are listed before single words, so at a given position the\n",
    "# longer suffix wins the regex alternation.\n",
    "# NOTE: despite its name, the RiskType column holds that suffix (e.g. central inner,\n",
    "# energy outer, inner), not TotalRisk / ExcessRisk / BayesRisk.\n",
    "for uq_frame in (full_ood_rocauc_dataframe, full_mis_rocauc_dataframe, full_dataframe):\n",
    "    metric_text = uq_frame[\"UQMetric\"].str\n",
    "    uq_frame[\"base_rule\"] = metric_text.extract(pattern_baserule, expand=False)\n",
    "    uq_frame[\"RiskType\"] = metric_text.extract(pattern_risk, expand=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "058bdbf1-9ece-4a24-8868-2fa3c8b3d570",
   "metadata": {},
   "outputs": [],
   "source": [
    "from source.source.path_config import REPOSITORY_ROOT\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1d91dd24-cc1c-441e-b2df-48723d41cbeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the combined tables to disk.\n",
    "output_dir = os.path.join(REPOSITORY_ROOT, \"tables/central_tables\")\n",
    "# to_pickle refuses to write into a non-existent directory, so create it first.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "tables_to_save = {\n",
    "    \"new_models_full_dataframe.pkl\": full_dataframe,\n",
    "    \"new_models_full_ood_rocauc.pkl\": full_ood_rocauc_dataframe,\n",
    "    \"new_models_full_mis_rocauc.pkl\": full_mis_rocauc_dataframe,\n",
    "}\n",
    "for file_name, table in tables_to_save.items():\n",
    "    table.to_pickle(os.path.join(output_dir, file_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "428e1ea2-58f1-41d8-875e-eec76ea2029a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
