{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fdd0aacc-2c95-453d-9bd8-d7f3abb0fee0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "stty: 'standard input': Inappropriate ioctl for device\n"
     ]
    }
   ],
   "source": [
    "from source.source.postprocessing_utils import (\n",
    "    get_sampled_combinations_uncertainty_scores,\n",
    "    get_predicted_labels,\n",
    "    get_missclassification_dataframe,\n",
    "    get_ood_detection_dataframe,\n",
    "    get_raw_scores_dataframe,\n",
    ")\n",
    "from source.datasets.constants import DatasetName\n",
    "from source.losses.constants import LossName\n",
    "from source.models.constants import ModelName, ModelSource\n",
    "from source.metrics import (\n",
    "    ApproximationType,\n",
    "    GName,\n",
    "    RiskType,\n",
    ")\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "from source.source.path_config import REPOSITORY_ROOT\n",
    "import os\n",
    "from source.source.postprocessing_utils import remove_and_expand_list\n",
    "\n",
    "pd.set_option(\"display.max_rows\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "616c37e2-af1d-4198-aee4-4e69d426c12f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "76806699-ee5f-4d66-8b7b-f5f43bf73379",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_source = ModelSource.OUR_MODELS.value\n",
    "\n",
    "training_dataset_names = [\n",
    "    # DatasetName.CIFAR10.value,\n",
    "    # DatasetName.TINY_IMAGENET.value,\n",
    "    # DatasetName.CIFAR100.value,\n",
    "    # DatasetName.CIFAR10_NOISY_LABEL.value,\n",
    "    DatasetName.CIFAR100_NOISY_LABEL.value,\n",
    "]\n",
    "temperature = 1.0\n",
    "model_ids = np.arange(20)\n",
    "if model_source == ModelSource.TORCH_UNCERTAINTY.value:\n",
    "    loss_function_names = [LossName.CROSS_ENTROPY]\n",
    "else:\n",
    "    loss_function_names = [el for el in LossName]\n",
    "architectures = [ModelName.RESNET18]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "75540573-6734-4875-b4d6-7700a0dc5f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lists_of_extracted_datasets(training_dataset_name: str) -> tuple[list, list]:\n",
    "    if training_dataset_name == DatasetName.CIFAR10.value:\n",
    "        list_extraction_datasets = [\n",
    "            DatasetName.CIFAR10_NOISY_LABEL.value,\n",
    "            DatasetName.CIFAR10.value,\n",
    "            DatasetName.CIFAR100.value,\n",
    "            DatasetName.SVHN.value,\n",
    "            DatasetName.TINY_IMAGENET.value,\n",
    "            DatasetName.CIFAR10C.value,\n",
    "        ]\n",
    "    elif training_dataset_name == DatasetName.TINY_IMAGENET.value:\n",
    "        list_extraction_datasets = [\n",
    "            DatasetName.TINY_IMAGENET.value,\n",
    "            DatasetName.IMAGENET_A.value,\n",
    "            DatasetName.IMAGENET_R.value,\n",
    "            DatasetName.IMAGENET_O.value,\n",
    "        ]\n",
    "\n",
    "    elif training_dataset_name == DatasetName.CIFAR100.value:\n",
    "        list_extraction_datasets = [\n",
    "            DatasetName.CIFAR10.value,\n",
    "            DatasetName.CIFAR100.value,\n",
    "            DatasetName.SVHN.value,\n",
    "            # DatasetName.CIFAR100C.value,\n",
    "        ]\n",
    "\n",
    "    elif training_dataset_name == DatasetName.CIFAR100_NOISY_LABEL.value:\n",
    "        list_extraction_datasets = [\n",
    "            DatasetName.CIFAR100.value,\n",
    "        ]\n",
    "    elif training_dataset_name == DatasetName.CIFAR10_NOISY_LABEL.value:\n",
    "        list_extraction_datasets = [\n",
    "            DatasetName.CIFAR10.value,\n",
    "        ]\n",
    "        \n",
    "    list_extraction_datasets = remove_and_expand_list(list_extraction_datasets)\n",
    "    list_ood_datasets = [el for el in list_extraction_datasets]\n",
    "    return list_extraction_datasets, list_ood_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eece4440-bca9-4cbc-986b-3d7d05cac393",
   "metadata": {},
   "outputs": [],
   "source": [
    "def postprocess_tables(\n",
    "    full_ood_rocauc_dataframe_,\n",
    "    full_mis_rocauc_dataframe_,\n",
    "    full_dataframe_,\n",
    "):\n",
    "    pattern_baserule = r\"(LogScore|BrierScore|ZeroOneScore|SphericalScore)\"\n",
    "    pattern_risk = r\"(outer outer|outer inner|outer central|inner outer|inner inner|inner central|central outer|central inner|central central|energy inner|energy outer|outer|inner|central)\"\n",
    "\n",
    "    full_ood_rocauc_dataframe_[\"base_rule\"] = full_ood_rocauc_dataframe_[\n",
    "        \"UQMetric\"\n",
    "    ].str.extract(pattern_baserule)\n",
    "    full_ood_rocauc_dataframe_[\"RiskType\"] = full_ood_rocauc_dataframe_[\n",
    "        \"UQMetric\"\n",
    "    ].str.extract(pattern_risk)\n",
    "\n",
    "    full_mis_rocauc_dataframe_[\"base_rule\"] = full_mis_rocauc_dataframe_[\n",
    "        \"UQMetric\"\n",
    "    ].str.extract(pattern_baserule)\n",
    "    full_mis_rocauc_dataframe_[\"RiskType\"] = full_mis_rocauc_dataframe_[\n",
    "        \"UQMetric\"\n",
    "    ].str.extract(pattern_risk)\n",
    "\n",
    "    full_dataframe_[\"base_rule\"] = full_dataframe_[\"UQMetric\"].str.extract(\n",
    "        pattern_baserule\n",
    "    )\n",
    "    full_dataframe_[\"RiskType\"] = full_dataframe_[\"UQMetric\"].str.extract(pattern_risk)\n",
    "\n",
    "    return full_ood_rocauc_dataframe_, full_mis_rocauc_dataframe_, full_dataframe_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322cf1db-5211-4d87-8b23-cee5681a45b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a40f2985-866d-470b-80ca-011d717702ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8b8ab8e6b9a4852886e780eb5665ee6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/110 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for training_dataset_name in training_dataset_names:\n",
    "    full_dataframe = None\n",
    "    full_ood_rocauc_dataframe = None\n",
    "    full_mis_rocauc_dataframe = None\n",
    "\n",
    "    list_extraction_datasets, list_ood_datasets = get_lists_of_extracted_datasets(\n",
    "        training_dataset_name\n",
    "    )\n",
    "\n",
    "    if training_dataset_name not in [\n",
    "        DatasetName.CIFAR10_NOISY_LABEL.value,\n",
    "        DatasetName.CIFAR100_NOISY_LABEL.value,\n",
    "    ]:\n",
    "        training_dataset_name_aux = training_dataset_name\n",
    "    else:\n",
    "        training_dataset_name_aux = training_dataset_name.split(\"_\")[0]\n",
    "\n",
    "    for architecture in architectures:\n",
    "        uq_results, embeddings_per_dataset, targets_per_dataset = (\n",
    "            get_sampled_combinations_uncertainty_scores(\n",
    "                loss_function_names=loss_function_names,\n",
    "                training_dataset_name=training_dataset_name,\n",
    "                architecture=architecture,\n",
    "                model_ids=model_ids,\n",
    "                list_extraction_datasets=list_extraction_datasets,\n",
    "                temperature=temperature,\n",
    "                model_source=model_source,\n",
    "                use_cached=False,\n",
    "            )\n",
    "        )\n",
    "\n",
    "        df_ood = get_ood_detection_dataframe(\n",
    "            ind_dataset=training_dataset_name_aux,\n",
    "            uq_results=uq_results,\n",
    "            list_ood_datasets=list_ood_datasets,\n",
    "        )\n",
    "\n",
    "        max_ind = int(\n",
    "            targets_per_dataset[training_dataset_name_aux].shape[0] / len(model_ids)\n",
    "        )\n",
    "        true_labels = targets_per_dataset[training_dataset_name_aux][:max_ind]\n",
    "\n",
    "        pred_labels = get_predicted_labels(\n",
    "            embeddings_per_dataset=embeddings_per_dataset,\n",
    "            training_dataset_name=training_dataset_name_aux,\n",
    "        )\n",
    "\n",
    "        df_misclassification = get_missclassification_dataframe(\n",
    "            ind_dataset=training_dataset_name_aux,\n",
    "            uq_results=uq_results,\n",
    "            true_labels=true_labels,\n",
    "            pred_labels=pred_labels,\n",
    "        )\n",
    "\n",
    "        scores_df_unravel = get_raw_scores_dataframe(uq_results=uq_results)\n",
    "        scores_df_unravel[\"architecture\"] = architecture.value\n",
    "        scores_df_unravel[\"training_dataset\"] = training_dataset_name\n",
    "        df_ood[\"architecture\"] = architecture.value\n",
    "        df_ood[\"training_dataset\"] = training_dataset_name\n",
    "        df_misclassification[\"architecture\"] = architecture.value\n",
    "        df_misclassification[\"training_dataset\"] = training_dataset_name\n",
    "\n",
    "        if full_dataframe is None:\n",
    "            full_dataframe = scores_df_unravel\n",
    "            full_ood_rocauc_dataframe = df_ood\n",
    "            full_mis_rocauc_dataframe = df_misclassification\n",
    "        else:\n",
    "            full_dataframe = pd.concat([full_dataframe, scores_df_unravel])\n",
    "            full_ood_rocauc_dataframe = pd.concat([full_ood_rocauc_dataframe, df_ood])\n",
    "            full_mis_rocauc_dataframe = pd.concat(\n",
    "                [full_mis_rocauc_dataframe, df_misclassification]\n",
    "            )\n",
    "\n",
    "    full_ood_rocauc_dataframe_, full_mis_rocauc_dataframe_, full_dataframe_ = (\n",
    "        postprocess_tables(\n",
    "            full_ood_rocauc_dataframe_=full_ood_rocauc_dataframe,\n",
    "            full_mis_rocauc_dataframe_=full_mis_rocauc_dataframe,\n",
    "            full_dataframe_=full_dataframe,\n",
    "        )\n",
    "    )\n",
    "    prefix = training_dataset_name + \"_\" + model_source + \"_\"\n",
    "    full_dataframe.to_pickle(\n",
    "        os.path.join(\n",
    "            REPOSITORY_ROOT, f\"tables/central_tables/final/{prefix}full_dataframe.pkl\"\n",
    "        )\n",
    "    )\n",
    "    full_ood_rocauc_dataframe.to_pickle(\n",
    "        os.path.join(\n",
    "            REPOSITORY_ROOT, f\"tables/central_tables/final/{prefix}full_ood_rocauc.pkl\"\n",
    "        )\n",
    "    )\n",
    "    full_mis_rocauc_dataframe.to_pickle(\n",
    "        os.path.join(\n",
    "            REPOSITORY_ROOT, f\"tables/central_tables/final/{prefix}full_mis_rocauc.pkl\"\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "48431686-4fb4-4f5a-9eda-5c92532898c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cifar100'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_dataset_name_aux"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8426235-d510-4352-b950-99a9ee1dd984",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e0763a2-14c7-4e79-8304-32ae4e5f84a8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "28dcae15-662a-4dad-8e1d-54119431bbb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['tiny_imagenet', 'imagenet_a', 'imagenet_r', 'imagenet_o'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "uq_results[\"LogScore energy outer\"][\"CrossEntropy\"].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d261629c-dc39-40de-9697-b02cd8d3f985",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb922ee-1624-40dd-9079-a44257c5bc83",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c33b06-28ee-4417-bf44-a6a63882705c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "08a23651-82bf-4af0-a654-1238d3dd2b1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>Dataset</th>\n",
       "      <th>Scores</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>980</th>\n",
       "      <td>SphericalScore BayesRisk outer</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>[[0.17789865, 0.7583793, 0.5098355, 0.02870331...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>385</th>\n",
       "      <td>BrierScore ExcessRisk outer central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>imagenet_a</td>\n",
       "      <td>[[0.23723434, 0.03039942, 0.054669026, 0.09136...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>489</th>\n",
       "      <td>BrierScore BayesRisk inner</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_a</td>\n",
       "      <td>[[0.65724003, 0.9308876, 0.912091, 0.93532413,...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>LogScore TotalRisk outer outer</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_r</td>\n",
       "      <td>[[4.902204, 2.9508655, 1.2938819, 3.739118, 4....</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>323</th>\n",
       "      <td>BrierScore TotalRisk inner central</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_o</td>\n",
       "      <td>[[0.6010382, 0.431612, 0.1471363, 0.49661836, ...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>LogScore ExcessRisk outer central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>[[0.23100778, 0.56910986, 0.4589969, 0.0679755...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>319</th>\n",
       "      <td>BrierScore TotalRisk inner central</td>\n",
       "      <td>BrierScore</td>\n",
       "      <td>imagenet_o</td>\n",
       "      <td>[[0.6010382, 0.431612, 0.1471363, 0.49661836, ...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>418</th>\n",
       "      <td>BrierScore ExcessRisk inner inner</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_r</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>838</th>\n",
       "      <td>SphericalScore TotalRisk central outer</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_r</td>\n",
       "      <td>[[0.9545533436269606, 0.8859911670965148, 0.38...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>919</th>\n",
       "      <td>SphericalScore ExcessRisk inner inner</td>\n",
       "      <td>BrierScore</td>\n",
       "      <td>imagenet_o</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   UQMetric    LossFunction        Dataset  \\\n",
       "980          SphericalScore BayesRisk outer  SphericalScore  tiny_imagenet   \n",
       "385     BrierScore ExcessRisk outer central    CrossEntropy     imagenet_a   \n",
       "489              BrierScore BayesRisk inner  SphericalScore     imagenet_a   \n",
       "10           LogScore TotalRisk outer outer  SphericalScore     imagenet_r   \n",
       "323      BrierScore TotalRisk inner central  SphericalScore     imagenet_o   \n",
       "132       LogScore ExcessRisk outer central    CrossEntropy  tiny_imagenet   \n",
       "319      BrierScore TotalRisk inner central      BrierScore     imagenet_o   \n",
       "418       BrierScore ExcessRisk inner inner  SphericalScore     imagenet_r   \n",
       "838  SphericalScore TotalRisk central outer  SphericalScore     imagenet_r   \n",
       "919   SphericalScore ExcessRisk inner inner      BrierScore     imagenet_o   \n",
       "\n",
       "                                                Scores architecture  \\\n",
       "980  [[0.17789865, 0.7583793, 0.5098355, 0.02870331...     resnet18   \n",
       "385  [[0.23723434, 0.03039942, 0.054669026, 0.09136...     resnet18   \n",
       "489  [[0.65724003, 0.9308876, 0.912091, 0.93532413,...     resnet18   \n",
       "10   [[4.902204, 2.9508655, 1.2938819, 3.739118, 4....     resnet18   \n",
       "323  [[0.6010382, 0.431612, 0.1471363, 0.49661836, ...     resnet18   \n",
       "132  [[0.23100778, 0.56910986, 0.4589969, 0.0679755...     resnet18   \n",
       "319  [[0.6010382, 0.431612, 0.1471363, 0.49661836, ...     resnet18   \n",
       "418  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...     resnet18   \n",
       "838  [[0.9545533436269606, 0.8859911670965148, 0.38...     resnet18   \n",
       "919  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...     resnet18   \n",
       "\n",
       "    training_dataset  \n",
       "980    tiny_imagenet  \n",
       "385    tiny_imagenet  \n",
       "489    tiny_imagenet  \n",
       "10     tiny_imagenet  \n",
       "323    tiny_imagenet  \n",
       "132    tiny_imagenet  \n",
       "319    tiny_imagenet  \n",
       "418    tiny_imagenet  \n",
       "838    tiny_imagenet  \n",
       "919    tiny_imagenet  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_dataframe.sample(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00661d1f-8c14-42b8-89f3-ae703f87379f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bdfd4925-f6e6-4b72-957b-f654c2056b01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['tiny_imagenet', 'imagenet_a', 'imagenet_r', 'imagenet_o'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_ood_rocauc_dataframe.Dataset.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75b53dd-c9f8-4631-9883-56c2d5665491",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "223730f1-946d-4768-97a7-7109877c24c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "pattern_baserule = r\"(LogScore|BrierScore|ZeroOneScore|SphericalScore)\"\n",
    "pattern_risk = r\"(outer outer|outer inner|outer central|inner outer|inner inner|inner central|central outer|central inner|central central|energy inner|energy outer|outer|inner|central)\"\n",
    "\n",
    "full_ood_rocauc_dataframe[\"base_rule\"] = full_ood_rocauc_dataframe[\n",
    "    \"UQMetric\"\n",
    "].str.extract(pattern_baserule)\n",
    "full_ood_rocauc_dataframe[\"RiskType\"] = full_ood_rocauc_dataframe[\n",
    "    \"UQMetric\"\n",
    "].str.extract(pattern_risk)\n",
    "\n",
    "full_mis_rocauc_dataframe[\"base_rule\"] = full_mis_rocauc_dataframe[\n",
    "    \"UQMetric\"\n",
    "].str.extract(pattern_baserule)\n",
    "full_mis_rocauc_dataframe[\"RiskType\"] = full_mis_rocauc_dataframe[\n",
    "    \"UQMetric\"\n",
    "].str.extract(pattern_risk)\n",
    "\n",
    "full_dataframe[\"base_rule\"] = full_dataframe[\"UQMetric\"].str.extract(pattern_baserule)\n",
    "full_dataframe[\"RiskType\"] = full_dataframe[\"UQMetric\"].str.extract(pattern_risk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0ccd9ca-e748-4134-9541-b1c3ced4e5d9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7e01f2e7-6808-4b6f-b1fa-1321e25e747d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UQMetric</th>\n",
       "      <th>LossFunction</th>\n",
       "      <th>Dataset</th>\n",
       "      <th>Scores</th>\n",
       "      <th>architecture</th>\n",
       "      <th>training_dataset</th>\n",
       "      <th>base_rule</th>\n",
       "      <th>RiskType</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>896</th>\n",
       "      <td>SphericalScore ExcessRisk outer central</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>[[0.008429126962601258, 0.08840444922807775, 0...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>outer central</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>433</th>\n",
       "      <td>BrierScore ExcessRisk central outer</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>imagenet_a</td>\n",
       "      <td>[[0.23723432, 0.030399434, 0.054669186, 0.0913...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>BrierScore</td>\n",
       "      <td>central outer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>816</th>\n",
       "      <td>SphericalScore TotalRisk inner central</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>[[0.18632775730031337, 0.8467837293621233, 0.5...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>inner central</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>237</th>\n",
       "      <td>LogScore BayesRisk inner</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_a</td>\n",
       "      <td>[[1.6223768, 3.8559604, 3.3577938, 3.7510993, ...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>LogScore</td>\n",
       "      <td>inner</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>400</th>\n",
       "      <td>BrierScore ExcessRisk inner outer</td>\n",
       "      <td>BrierScore</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>[[0.026536442, 0.029567286, 0.060794458, 0.001...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>BrierScore</td>\n",
       "      <td>inner outer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>975</th>\n",
       "      <td>SphericalScore BayesRisk outer</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>imagenet_o</td>\n",
       "      <td>[[0.32719886, 0.22051707, 0.0755226, 0.2691950...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>outer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>227</th>\n",
       "      <td>LogScore BayesRisk outer</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_o</td>\n",
       "      <td>[[1.6531043, 0.9266964, 0.34597713, 1.156508, ...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>LogScore</td>\n",
       "      <td>outer</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>626</th>\n",
       "      <td>ZeroOneScore ExcessRisk outer inner</td>\n",
       "      <td>CrossEntropy</td>\n",
       "      <td>imagenet_r</td>\n",
       "      <td>[[0.17329028, 0.091238156, 0.0, 0.073936634, 0...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>ZeroOneScore</td>\n",
       "      <td>outer inner</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>LogScore ExcessRisk outer inner</td>\n",
       "      <td>BrierScore</td>\n",
       "      <td>imagenet_o</td>\n",
       "      <td>[[0.3859374, 0.15969387, 0.021598311, 0.128690...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>LogScore</td>\n",
       "      <td>outer inner</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>693</th>\n",
       "      <td>ZeroOneScore ExcessRisk central outer</td>\n",
       "      <td>SphericalScore</td>\n",
       "      <td>imagenet_a</td>\n",
       "      <td>[[0.0090661645, 0.08915358, 0.0853706, 0.07697...</td>\n",
       "      <td>resnet18</td>\n",
       "      <td>tiny_imagenet</td>\n",
       "      <td>ZeroOneScore</td>\n",
       "      <td>central outer</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    UQMetric    LossFunction        Dataset  \\\n",
       "896  SphericalScore ExcessRisk outer central  SphericalScore  tiny_imagenet   \n",
       "433      BrierScore ExcessRisk central outer    CrossEntropy     imagenet_a   \n",
       "816   SphericalScore TotalRisk inner central    CrossEntropy  tiny_imagenet   \n",
       "237                 LogScore BayesRisk inner  SphericalScore     imagenet_a   \n",
       "400        BrierScore ExcessRisk inner outer      BrierScore  tiny_imagenet   \n",
       "975           SphericalScore BayesRisk outer    CrossEntropy     imagenet_o   \n",
       "227                 LogScore BayesRisk outer  SphericalScore     imagenet_o   \n",
       "626      ZeroOneScore ExcessRisk outer inner    CrossEntropy     imagenet_r   \n",
       "127          LogScore ExcessRisk outer inner      BrierScore     imagenet_o   \n",
       "693    ZeroOneScore ExcessRisk central outer  SphericalScore     imagenet_a   \n",
       "\n",
       "                                                Scores architecture  \\\n",
       "896  [[0.008429126962601258, 0.08840444922807775, 0...     resnet18   \n",
       "433  [[0.23723432, 0.030399434, 0.054669186, 0.0913...     resnet18   \n",
       "816  [[0.18632775730031337, 0.8467837293621233, 0.5...     resnet18   \n",
       "237  [[1.6223768, 3.8559604, 3.3577938, 3.7510993, ...     resnet18   \n",
       "400  [[0.026536442, 0.029567286, 0.060794458, 0.001...     resnet18   \n",
       "975  [[0.32719886, 0.22051707, 0.0755226, 0.2691950...     resnet18   \n",
       "227  [[1.6531043, 0.9266964, 0.34597713, 1.156508, ...     resnet18   \n",
       "626  [[0.17329028, 0.091238156, 0.0, 0.073936634, 0...     resnet18   \n",
       "127  [[0.3859374, 0.15969387, 0.021598311, 0.128690...     resnet18   \n",
       "693  [[0.0090661645, 0.08915358, 0.0853706, 0.07697...     resnet18   \n",
       "\n",
       "    training_dataset       base_rule       RiskType  \n",
       "896    tiny_imagenet  SphericalScore  outer central  \n",
       "433    tiny_imagenet      BrierScore  central outer  \n",
       "816    tiny_imagenet  SphericalScore  inner central  \n",
       "237    tiny_imagenet        LogScore          inner  \n",
       "400    tiny_imagenet      BrierScore    inner outer  \n",
       "975    tiny_imagenet  SphericalScore          outer  \n",
       "227    tiny_imagenet        LogScore          outer  \n",
       "626    tiny_imagenet    ZeroOneScore    outer inner  \n",
       "127    tiny_imagenet        LogScore    outer inner  \n",
       "693    tiny_imagenet    ZeroOneScore  central outer  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_dataframe.sample(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "827d1cbf-9062-4578-8972-894698f17c49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed random_state so the preview shows the same rows on every re-run.\n",
    "full_ood_rocauc_dataframe.sample(10, random_state=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "708f67f2-e478-4a58-a4c2-6ef28947a5a6",
   "metadata": {},
   "source": [
    "## Save the aggregated tables\n",
    "\n",
    "The cells below choose a filename `prefix` from `training_dataset_names` (`\"imagenet_\"` when Tiny ImageNet is the only training dataset, `\"\"` otherwise). They then pickle `full_dataframe`, `full_ood_rocauc_dataframe` and `full_mis_rocauc_dataframe` into `tables/central_tables/` under `REPOSITORY_ROOT`.\n",
    "\n",
    "**Note:** every run with an empty prefix writes to the same file names. Re-running with a different (non-Tiny-ImageNet) selection of training datasets therefore overwrites the previously saved tables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d217b1b0-0620-42b9-933b-d06cc00099f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tiny ImageNet-only runs get their own \"imagenet_\" file names; any other\n",
    "# selection of training datasets writes to the unprefixed names (overwriting them).\n",
    "is_tiny_imagenet_only = (\n",
    "    len(training_dataset_names) == 1\n",
    "    and training_dataset_names[0] == DatasetName.TINY_IMAGENET.value\n",
    ")\n",
    "prefix = \"imagenet_\" if is_tiny_imagenet_only else \"\"\n",
    "prefix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea43cfa-c155-48b9-ac9b-c2f39a67d8ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "central_tables_dir = os.path.join(REPOSITORY_ROOT, \"tables\", \"central_tables\")\n",
    "# Create the target directory on a fresh checkout instead of failing inside to_pickle.\n",
    "os.makedirs(central_tables_dir, exist_ok=True)\n",
    "\n",
    "# File stem -> table; each is written as <prefix><stem>.pkl.\n",
    "tables_to_save = {\n",
    "    \"full_dataframe\": full_dataframe,\n",
    "    \"full_ood_rocauc\": full_ood_rocauc_dataframe,\n",
    "    \"full_mis_rocauc\": full_mis_rocauc_dataframe,\n",
    "}\n",
    "for table_name, table in tables_to_save.items():\n",
    "    table.to_pickle(os.path.join(central_tables_dir, f\"{prefix}{table_name}.pkl\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a824609-f7f3-40d3-940b-766b0eb2cb33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional CSV copies of the saved tables, off by default: CSV stringifies\n",
    "# array-valued columns such as RocAucScores_array, so the pickle files remain the saved output.\n",
    "EXPORT_CSV = False\n",
    "if EXPORT_CSV:\n",
    "    csv_dir = os.path.join(REPOSITORY_ROOT, \"tables\", \"central_tables\")\n",
    "    os.makedirs(csv_dir, exist_ok=True)\n",
    "    for table_name, table in {\n",
    "        \"full_dataframe\": full_dataframe,\n",
    "        \"full_ood_rocauc\": full_ood_rocauc_dataframe,\n",
    "        \"full_mis_rocauc\": full_mis_rocauc_dataframe,\n",
    "    }.items():\n",
    "        table.to_csv(os.path.join(csv_dir, f\"{prefix}{table_name}.csv\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
