{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a30ffccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import torch\n",
    "\n",
    "from source.utils.uncertainty_measures import (\n",
    "        calculate_uncertainties_crps,\n",
    "        calculate_uncertainties_log,\n",
    "        calculate_uncertainties_mse,\n",
    "        calculate_uncertainties_quadratic\n",
    ")\n",
    "\n",
    "from source.constants import RESULTS_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "baa77f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which stored experiment to evaluate; task, network, use_natural and\n",
    "# anti_regularization_weight are interpolated into the results directory name.\n",
    "task = \"git_qood\"\n",
    "method = \"ensemble\"\n",
    "network = \"cnn\"\n",
    "use_natural = True\n",
    "anti_regularization_weight = 0.0\n",
    "\n",
    "# Seeds of the stored runs (one results directory per seed) and the scoring rules to tabulate.\n",
    "method_seeds = [42, 142, 242]\n",
    "scoring_rules = [\"crps\", \"log\", \"mse\", \"quadratic\"]\n",
    "\n",
    "# Every entry pairs the 'test' subset with one other subset.\n",
    "subset_picker = [\n",
    "    ['test', other_subset]\n",
    "    for other_subset in (\n",
    "        'cifar', 'svhn', 'fashion', 'mixture',\n",
    "        'emnist_0', 'emnist_1', 'emnist_2', 'emnist_3',\n",
    "    )\n",
    "]\n",
    "\n",
    "# (scoring rule, uncertainty measure) pairs that are shown as '-' in the table.\n",
    "not_considered = [\n",
    "    (rule, measure)\n",
    "    for rule, measures in (\n",
    "        (\"log\", (\"total_3a_2\", \"total_3b_2\", \"bayes_2\", \"excess_3a_2\", \"excess_3b_2\")),\n",
    "        (\"mse\", (\"excess_3a_2\", \"excess_3b_2\")),\n",
    "    )\n",
    "    for measure in measures\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eafb87a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX column headers keyed by uncertainty-measure name; the insertion order sets the header order.\n",
    "labels = {\n",
    "    # total_* measures\n",
    "    \"total_1_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{1,1}$\",\n",
    "    \"total_2_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{2,1}$\",\n",
    "    \"total_3a_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,1}$\",\n",
    "    \"total_3b_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,1}$\",\n",
    "    \"total_3a_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,2}$\",\n",
    "    \"total_3b_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,2}$\",\n",
    "    # bayes_* measures\n",
    "    \"bayes_1\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{1}$\",\n",
    "    \"bayes_2\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{2}$\",\n",
    "    \"bayes_3a\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3a}$\",\n",
    "    \"bayes_3b\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3b}$\",\n",
    "    # excess_* measures\n",
    "    \"excess_1_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{1,1}$\",\n",
    "    \"excess_2_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{2,1}$\",\n",
    "    \"excess_3a_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,1}$\",\n",
    "    \"excess_3b_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,1}$\",\n",
    "    \"excess_3a_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,2}$\",\n",
    "    \"excess_3b_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,2}$\",\n",
    "}\n",
    "\n",
    "# Display names of the scoring rules (second table column).\n",
    "sr_labels = dict(crps=\"CRPS\", log=\"Log\", mse=\"SE\", quadratic=\"Quad.\")\n",
    "\n",
    "# Display names of the subsets used as row-group labels.\n",
    "subset_labels = dict(\n",
    "    cifar=\"CIFAR-10\",\n",
    "    svhn=\"SVHN\",\n",
    "    fashion=\"Fashion-MNIST\",\n",
    "    mixture=\"Mixture (C,S,F)\",\n",
    "    emnist_0=\"EMNIST (tl)\",\n",
    "    emnist_1=\"EMNIST (tr)\",\n",
    "    emnist_2=\"EMNIST (bl)\",\n",
    "    emnist_3=\"EMNIST (br)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f23d1f72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{cccccccccccccccccc}\n",
      "\\toprule\n",
      "$\\mathcal{D_{{\\text{ood}}}}$ & SR & $\\hat{R}_{\\text{Tot}}^{1,1}$ & $\\hat{R}_{\\text{Tot}}^{2,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,1}$ & $\\hat{R}_{\\text{Tot}}^{3b,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,2}$ & $\\hat{R}_{\\text{Tot}}^{3b,2}$ & $\\hat{R}_{\\text{Bayes}}^{1}$ & $\\hat{R}_{\\text{Bayes}}^{2}$ & $\\hat{R}_{\\text{Bayes}}^{3a}$ & $\\hat{R}_{\\text{Bayes}}^{3b}$ & $\\hat{R}_{\\text{Exc}}^{1,1}$ & $\\hat{R}_{\\text{Exc}}^{2,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,1}$ & $\\hat{R}_{\\text{Exc}}^{3b,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,2}$ & $\\hat{R}_{\\text{Exc}}^{3b,2}$ \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{CIFAR-10}} & CRPS & \\valvar{0.969}{.018} & \\valvar{0.969}{.018} & \\valvar{0.969}{.018} & \\valvar{0.953}{.022} & \\valvar{0.949}{.021} & \\valvar{0.864}{.006} & \\valvar{0.831}{.015} & \\valvar{0.949}{.021} & \\valvar{0.949}{.021} & \\valvar{0.832}{.015} & \\valvar{0.984}{.011} & \\valvar{0.984}{.011} & \\valvar{0.984}{.011} & \\valvar{0.985}{.011} & \\valvar{0.963}{.016} & \\valvar{0.984}{.012} \\\\\n",
      " & Log & \\valvar{0.977}{.016} & \\valvar{0.977}{.016} & \\valvar{0.972}{.017} & \\valvar{0.962}{.022} & - & - & \\valvar{0.830}{.015} & - & \\valvar{0.949}{.021} & \\valvar{0.832}{.015} & \\valvar{0.984}{.012} & \\valvar{0.984}{.012} & \\valvar{0.984}{.013} & \\valvar{0.984}{.013} & - & - \\\\\n",
      " & MSE & \\valvar{0.966}{.018} & \\valvar{0.966}{.018} & \\valvar{0.966}{.018} & \\valvar{0.949}{.021} & \\valvar{0.949}{.021} & \\valvar{0.832}{.015} & \\valvar{0.832}{.015} & \\valvar{0.949}{.021} & \\valvar{0.949}{.021} & \\valvar{0.832}{.015} & \\valvar{0.982}{.011} & \\valvar{0.982}{.011} & \\valvar{0.982}{.011} & \\valvar{0.982}{.011} & - & - \\\\\n",
      " & Quad. & \\valvar{0.976}{.017} & \\valvar{0.976}{.017} & \\valvar{0.976}{.017} & \\valvar{0.962}{.022} & \\valvar{0.950}{.021} & \\valvar{0.893}{.018} & \\valvar{0.829}{.016} & \\valvar{0.949}{.021} & \\valvar{0.949}{.021} & \\valvar{0.832}{.015} & \\valvar{0.972}{.022} & \\valvar{0.972}{.022} & \\valvar{0.972}{.022} & \\valvar{0.974}{.021} & \\valvar{0.950}{.019} & \\valvar{0.980}{.017} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{SVHN}} & CRPS & \\valvar{0.984}{.011} & \\valvar{0.984}{.011} & \\valvar{0.983}{.011} & \\valvar{0.978}{.013} & \\valvar{0.976}{.012} & \\valvar{0.938}{.000} & \\valvar{0.919}{.012} & \\valvar{0.976}{.013} & \\valvar{0.976}{.012} & \\valvar{0.920}{.012} & \\valvar{0.986}{.012} & \\valvar{0.986}{.012} & \\valvar{0.986}{.012} & \\valvar{0.986}{.012} & \\valvar{0.964}{.022} & \\valvar{0.983}{.016} \\\\\n",
      " & Log & \\valvar{0.986}{.011} & \\valvar{0.986}{.011} & \\valvar{0.985}{.011} & \\valvar{0.981}{.013} & - & - & \\valvar{0.919}{.013} & - & \\valvar{0.976}{.012} & \\valvar{0.920}{.012} & \\valvar{0.980}{.018} & \\valvar{0.980}{.018} & \\valvar{0.980}{.019} & \\valvar{0.979}{.019} & - & - \\\\\n",
      " & MSE & \\valvar{0.982}{.011} & \\valvar{0.982}{.011} & \\valvar{0.982}{.011} & \\valvar{0.976}{.012} & \\valvar{0.976}{.012} & \\valvar{0.920}{.012} & \\valvar{0.920}{.012} & \\valvar{0.976}{.012} & \\valvar{0.976}{.012} & \\valvar{0.920}{.012} & \\valvar{0.987}{.010} & \\valvar{0.987}{.010} & \\valvar{0.987}{.010} & \\valvar{0.987}{.010} & - & - \\\\\n",
      " & Quad. & \\valvar{0.986}{.011} & \\valvar{0.986}{.011} & \\valvar{0.986}{.011} & \\valvar{0.981}{.014} & \\valvar{0.976}{.013} & \\valvar{0.953}{.010} & \\valvar{0.918}{.013} & \\valvar{0.976}{.013} & \\valvar{0.976}{.012} & \\valvar{0.920}{.012} & \\valvar{0.950}{.045} & \\valvar{0.950}{.045} & \\valvar{0.950}{.045} & \\valvar{0.954}{.043} & \\valvar{0.943}{.034} & \\valvar{0.968}{.031} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{Fashion-MNIST}} & CRPS & \\valvar{0.936}{.014} & \\valvar{0.936}{.014} & \\valvar{0.936}{.014} & \\valvar{0.911}{.017} & \\valvar{0.908}{.018} & \\valvar{0.825}{.025} & \\valvar{0.804}{.028} & \\valvar{0.908}{.017} & \\valvar{0.907}{.018} & \\valvar{0.805}{.028} & \\valvar{0.971}{.010} & \\valvar{0.971}{.010} & \\valvar{0.971}{.010} & \\valvar{0.972}{.010} & \\valvar{0.936}{.005} & \\valvar{0.973}{.010} \\\\\n",
      " & Log & \\valvar{0.950}{.012} & \\valvar{0.950}{.012} & \\valvar{0.941}{.013} & \\valvar{0.922}{.016} & - & - & \\valvar{0.804}{.028} & - & \\valvar{0.907}{.018} & \\valvar{0.805}{.028} & \\valvar{0.974}{.011} & \\valvar{0.974}{.011} & \\valvar{0.974}{.011} & \\valvar{0.973}{.011} & - & - \\\\\n",
      " & MSE & \\valvar{0.932}{.014} & \\valvar{0.932}{.014} & \\valvar{0.932}{.014} & \\valvar{0.907}{.018} & \\valvar{0.907}{.018} & \\valvar{0.805}{.028} & \\valvar{0.805}{.028} & \\valvar{0.907}{.018} & \\valvar{0.907}{.018} & \\valvar{0.805}{.028} & \\valvar{0.966}{.010} & \\valvar{0.966}{.010} & \\valvar{0.966}{.010} & \\valvar{0.966}{.010} & - & - \\\\\n",
      " & Quad. & \\valvar{0.946}{.012} & \\valvar{0.946}{.012} & \\valvar{0.947}{.013} & \\valvar{0.921}{.016} & \\valvar{0.908}{.017} & \\valvar{0.844}{.023} & \\valvar{0.803}{.028} & \\valvar{0.907}{.017} & \\valvar{0.907}{.018} & \\valvar{0.805}{.028} & \\valvar{0.963}{.017} & \\valvar{0.963}{.017} & \\valvar{0.963}{.017} & \\valvar{0.964}{.017} & \\valvar{0.922}{.006} & \\valvar{0.970}{.013} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{Mixture (C,S,F)}} & CRPS & \\valvar{0.963}{.013} & \\valvar{0.963}{.013} & \\valvar{0.963}{.013} & \\valvar{0.947}{.015} & \\valvar{0.944}{.015} & \\valvar{0.875}{.010} & \\valvar{0.851}{.017} & \\valvar{0.944}{.015} & \\valvar{0.944}{.015} & \\valvar{0.852}{.017} & \\valvar{0.981}{.010} & \\valvar{0.981}{.010} & \\valvar{0.981}{.010} & \\valvar{0.981}{.010} & \\valvar{0.955}{.014} & \\valvar{0.980}{.012} \\\\\n",
      " & Log & \\valvar{0.971}{.012} & \\valvar{0.971}{.012} & \\valvar{0.966}{.013} & \\valvar{0.955}{.015} & - & - & \\valvar{0.850}{.017} & - & \\valvar{0.944}{.015} & \\valvar{0.852}{.017} & \\valvar{0.979}{.013} & \\valvar{0.979}{.013} & \\valvar{0.979}{.014} & \\valvar{0.979}{.014} & - & - \\\\\n",
      " & MSE & \\valvar{0.960}{.013} & \\valvar{0.960}{.013} & \\valvar{0.960}{.013} & \\valvar{0.944}{.015} & \\valvar{0.944}{.015} & \\valvar{0.852}{.017} & \\valvar{0.852}{.017} & \\valvar{0.944}{.015} & \\valvar{0.944}{.015} & \\valvar{0.852}{.017} & \\valvar{0.979}{.010} & \\valvar{0.979}{.010} & \\valvar{0.979}{.010} & \\valvar{0.979}{.010} & - & - \\\\\n",
      " & Quad. & \\valvar{0.969}{.012} & \\valvar{0.969}{.012} & \\valvar{0.970}{.012} & \\valvar{0.954}{.015} & \\valvar{0.945}{.015} & \\valvar{0.896}{.013} & \\valvar{0.849}{.017} & \\valvar{0.944}{.015} & \\valvar{0.944}{.015} & \\valvar{0.852}{.017} & \\valvar{0.962}{.028} & \\valvar{0.962}{.028} & \\valvar{0.962}{.028} & \\valvar{0.964}{.027} & \\valvar{0.938}{.019} & \\valvar{0.973}{.020} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{EMNIST (tl)}} & CRPS & \\valvar{0.836}{.012} & \\valvar{0.836}{.012} & \\valvar{0.836}{.012} & \\valvar{0.800}{.012} & \\valvar{0.797}{.012} & \\valvar{0.718}{.013} & \\valvar{0.709}{.014} & \\valvar{0.797}{.012} & \\valvar{0.796}{.012} & \\valvar{0.709}{.014} & \\valvar{0.915}{.012} & \\valvar{0.915}{.012} & \\valvar{0.915}{.012} & \\valvar{0.915}{.012} & \\valvar{0.846}{.020} & \\valvar{0.919}{.012} \\\\\n",
      " & Log & \\valvar{0.855}{.012} & \\valvar{0.855}{.012} & \\valvar{0.842}{.012} & \\valvar{0.811}{.012} & - & - & \\valvar{0.709}{.014} & - & \\valvar{0.796}{.012} & \\valvar{0.709}{.014} & \\valvar{0.922}{.011} & \\valvar{0.922}{.011} & \\valvar{0.922}{.011} & \\valvar{0.921}{.011} & - & - \\\\\n",
      " & MSE & \\valvar{0.831}{.012} & \\valvar{0.831}{.012} & \\valvar{0.831}{.012} & \\valvar{0.796}{.012} & \\valvar{0.796}{.012} & \\valvar{0.709}{.014} & \\valvar{0.709}{.014} & \\valvar{0.796}{.012} & \\valvar{0.796}{.012} & \\valvar{0.709}{.014} & \\valvar{0.903}{.012} & \\valvar{0.903}{.012} & \\valvar{0.903}{.012} & \\valvar{0.903}{.012} & - & - \\\\\n",
      " & Quad. & \\valvar{0.850}{.012} & \\valvar{0.850}{.012} & \\valvar{0.850}{.012} & \\valvar{0.809}{.012} & \\valvar{0.797}{.012} & \\valvar{0.731}{.013} & \\valvar{0.708}{.014} & \\valvar{0.797}{.012} & \\valvar{0.796}{.012} & \\valvar{0.709}{.014} & \\valvar{0.910}{.011} & \\valvar{0.910}{.011} & \\valvar{0.910}{.011} & \\valvar{0.911}{.011} & \\valvar{0.827}{.026} & \\valvar{0.919}{.012} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{EMNIST (tr)}} & CRPS & \\valvar{0.559}{.010} & \\valvar{0.559}{.010} & \\valvar{0.559}{.010} & \\valvar{0.550}{.009} & \\valvar{0.550}{.009} & \\valvar{0.534}{.007} & \\valvar{0.533}{.007} & \\valvar{0.550}{.009} & \\valvar{0.550}{.009} & \\valvar{0.533}{.007} & \\valvar{0.594}{.014} & \\valvar{0.594}{.014} & \\valvar{0.594}{.014} & \\valvar{0.594}{.014} & \\valvar{0.580}{.010} & \\valvar{0.595}{.013} \\\\\n",
      " & Log & \\valvar{0.563}{.011} & \\valvar{0.563}{.011} & \\valvar{0.561}{.010} & \\valvar{0.552}{.009} & - & - & \\valvar{0.533}{.007} & - & \\valvar{0.550}{.009} & \\valvar{0.533}{.007} & \\valvar{0.600}{.016} & \\valvar{0.600}{.016} & \\valvar{0.600}{.016} & \\valvar{0.599}{.016} & - & - \\\\\n",
      " & MSE & \\valvar{0.558}{.010} & \\valvar{0.558}{.010} & \\valvar{0.558}{.010} & \\valvar{0.550}{.009} & \\valvar{0.550}{.009} & \\valvar{0.533}{.007} & \\valvar{0.533}{.007} & \\valvar{0.550}{.009} & \\valvar{0.550}{.009} & \\valvar{0.533}{.007} & \\valvar{0.589}{.012} & \\valvar{0.589}{.012} & \\valvar{0.589}{.012} & \\valvar{0.589}{.012} & - & - \\\\\n",
      " & Quad. & \\valvar{0.562}{.010} & \\valvar{0.562}{.010} & \\valvar{0.562}{.010} & \\valvar{0.552}{.009} & \\valvar{0.550}{.009} & \\valvar{0.536}{.007} & \\valvar{0.533}{.007} & \\valvar{0.550}{.009} & \\valvar{0.550}{.009} & \\valvar{0.533}{.007} & \\valvar{0.598}{.016} & \\valvar{0.598}{.016} & \\valvar{0.599}{.016} & \\valvar{0.598}{.016} & \\valvar{0.578}{.009} & \\valvar{0.597}{.014} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{EMNIST (bl)}} & CRPS & \\valvar{0.557}{.022} & \\valvar{0.557}{.022} & \\valvar{0.557}{.022} & \\valvar{0.550}{.021} & \\valvar{0.549}{.021} & \\valvar{0.535}{.019} & \\valvar{0.534}{.018} & \\valvar{0.549}{.021} & \\valvar{0.549}{.021} & \\valvar{0.534}{.018} & \\valvar{0.591}{.022} & \\valvar{0.591}{.022} & \\valvar{0.591}{.022} & \\valvar{0.591}{.022} & \\valvar{0.561}{.013} & \\valvar{0.593}{.021} \\\\\n",
      " & Log & \\valvar{0.560}{.022} & \\valvar{0.560}{.022} & \\valvar{0.558}{.022} & \\valvar{0.551}{.021} & - & - & \\valvar{0.534}{.018} & - & \\valvar{0.549}{.021} & \\valvar{0.534}{.018} & \\valvar{0.595}{.021} & \\valvar{0.595}{.021} & \\valvar{0.594}{.021} & \\valvar{0.594}{.021} & - & - \\\\\n",
      " & MSE & \\valvar{0.556}{.021} & \\valvar{0.556}{.021} & \\valvar{0.556}{.021} & \\valvar{0.549}{.021} & \\valvar{0.549}{.021} & \\valvar{0.534}{.018} & \\valvar{0.534}{.018} & \\valvar{0.549}{.021} & \\valvar{0.549}{.021} & \\valvar{0.534}{.018} & \\valvar{0.586}{.022} & \\valvar{0.586}{.022} & \\valvar{0.586}{.022} & \\valvar{0.586}{.022} & - & - \\\\\n",
      " & Quad. & \\valvar{0.560}{.022} & \\valvar{0.560}{.022} & \\valvar{0.560}{.022} & \\valvar{0.551}{.021} & \\valvar{0.549}{.021} & \\valvar{0.537}{.019} & \\valvar{0.534}{.018} & \\valvar{0.549}{.021} & \\valvar{0.549}{.021} & \\valvar{0.534}{.018} & \\valvar{0.593}{.018} & \\valvar{0.593}{.018} & \\valvar{0.593}{.018} & \\valvar{0.593}{.018} & \\valvar{0.557}{.014} & \\valvar{0.595}{.019} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{EMNIST (br)}} & CRPS & \\valvar{0.548}{.013} & \\valvar{0.548}{.013} & \\valvar{0.548}{.013} & \\valvar{0.541}{.014} & \\valvar{0.541}{.013} & \\valvar{0.529}{.013} & \\valvar{0.528}{.013} & \\valvar{0.541}{.013} & \\valvar{0.541}{.013} & \\valvar{0.528}{.013} & \\valvar{0.577}{.010} & \\valvar{0.577}{.010} & \\valvar{0.577}{.010} & \\valvar{0.577}{.010} & \\valvar{0.553}{.019} & \\valvar{0.577}{.009} \\\\\n",
      " & Log & \\valvar{0.551}{.013} & \\valvar{0.551}{.013} & \\valvar{0.549}{.013} & \\valvar{0.542}{.014} & - & - & \\valvar{0.528}{.013} & - & \\valvar{0.541}{.013} & \\valvar{0.528}{.013} & \\valvar{0.581}{.009} & \\valvar{0.581}{.009} & \\valvar{0.581}{.009} & \\valvar{0.581}{.009} & - & - \\\\\n",
      " & MSE & \\valvar{0.547}{.013} & \\valvar{0.547}{.013} & \\valvar{0.547}{.013} & \\valvar{0.541}{.013} & \\valvar{0.541}{.013} & \\valvar{0.528}{.013} & \\valvar{0.528}{.013} & \\valvar{0.541}{.013} & \\valvar{0.541}{.013} & \\valvar{0.528}{.013} & \\valvar{0.573}{.011} & \\valvar{0.573}{.011} & \\valvar{0.573}{.011} & \\valvar{0.573}{.011} & - & - \\\\\n",
      " & Quad. & \\valvar{0.550}{.013} & \\valvar{0.550}{.013} & \\valvar{0.550}{.013} & \\valvar{0.542}{.014} & \\valvar{0.541}{.013} & \\valvar{0.530}{.013} & \\valvar{0.528}{.013} & \\valvar{0.541}{.013} & \\valvar{0.541}{.013} & \\valvar{0.528}{.013} & \\valvar{0.579}{.008} & \\valvar{0.579}{.008} & \\valvar{0.579}{.008} & \\valvar{0.579}{.008} & \\valvar{0.552}{.020} & \\valvar{0.578}{.008} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# Build the LaTeX AUROC table: one row group per entry of subset_picker, one row per scoring rule.\n",
    "# NOTE(review): header columns follow the order of labels, while each row follows the key order of\n",
    "# the dict returned by calculate_uncertainties_*; this assumes both orders agree - TODO confirm.\n",
    "\n",
    "# Scoring rule name -> function computing its uncertainty measures; an unknown name raises KeyError\n",
    "# instead of silently reusing the result of the previous rule.\n",
    "uncertainty_fns = {\n",
    "    \"crps\": calculate_uncertainties_crps,\n",
    "    \"log\": calculate_uncertainties_log,\n",
    "    \"mse\": calculate_uncertainties_mse,\n",
    "    \"quadratic\": calculate_uncertainties_quadratic,\n",
    "}\n",
    "\n",
    "# print latex table header\n",
    "print(\"\\\\begin{tabular}{cccccccccccccccccc}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\"$\\\\mathcal{D_{{\\\\text{ood}}}}$ & SR & \" + \" & \".join(labels.values()) + \" \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "\n",
    "for subsets in subset_picker:\n",
    "    # Load the predictions once per subset group; they do not depend on the scoring rule.\n",
    "    means_per_seed, variances_per_seed, ood_identity_per_seed = [], [], []\n",
    "    for method_seed in method_seeds:\n",
    "        means_seed, variances_seed, ood_identity_seed = [], [], []\n",
    "        for subset in subsets:\n",
    "            # weights_only=False permits full unpickling: only load result files you trust.\n",
    "            preds = torch.load(\n",
    "                os.path.join(\n",
    "                    RESULTS_PATH,\n",
    "                    f\"{task}_{network}_natural{use_natural}_mseed{method_seed}_arw{anti_regularization_weight}\",\n",
    "                    f\"preds{subset}.pt\",\n",
    "                ), weights_only=False\n",
    "            )\n",
    "            # The last axis of preds holds the mean at index 0 and the variance at index 1.\n",
    "            means_seed.append(preds[..., 0])\n",
    "            variances_seed.append(preds[..., 1])\n",
    "            # Binary target for the AUROC: 0 for the 'test' subset, 1 for every other subset.\n",
    "            make_target = torch.zeros if subset == 'test' else torch.ones\n",
    "            ood_identity_seed.append(make_target(size=(len(preds), )))\n",
    "        means_per_seed.append(torch.cat(means_seed, dim=0))\n",
    "        variances_per_seed.append(torch.cat(variances_seed, dim=0))\n",
    "        ood_identity_per_seed.append(torch.cat(ood_identity_seed, dim=0))\n",
    "    # One target column per seed, indexed below as ood_identity[:, i].\n",
    "    ood_identity = torch.stack(ood_identity_per_seed, dim=-1)\n",
    "\n",
    "    # The row-group label must span exactly as many rows as are printed below (one per scoring rule).\n",
    "    print(f\"\\\\multirow{{{len(scoring_rules)}}}{{*}}{{\\\\rotatebox{{90}}{{{subset_labels[subsets[-1]]}}}}}\", end=\"\")\n",
    "    for sr in scoring_rules:\n",
    "        # Stack anew for every rule so each calculate_uncertainties_* call receives its own tensors.\n",
    "        means = torch.stack(means_per_seed, dim=-2)\n",
    "        variances = torch.stack(variances_per_seed, dim=-2)\n",
    "        uncertainties = uncertainty_fns[sr](means, variances)\n",
    "\n",
    "        # One AUROC per seed for every measure; excluded (rule, measure) pairs are filled with NaN.\n",
    "        rocs = {}\n",
    "        for unc in uncertainties.keys():\n",
    "            if (sr, unc) in not_considered:\n",
    "                rocs[unc] = [float('nan')] * len(method_seeds)\n",
    "            else:\n",
    "                rocs[unc] = [roc_auc_score(ood_identity[:, i].numpy(), uncertainties[unc][:, i].numpy()) for i in range(len(method_seeds))]\n",
    "\n",
    "        # Each cell is a valvar macro call with the mean and the std over seeds (std printed without\n",
    "        # its leading zero), or '-' when any per-seed value is NaN.\n",
    "        cells = []\n",
    "        for v in rocs.values():\n",
    "            if np.isnan(v).any():\n",
    "                cells.append(\"-\")\n",
    "            else:\n",
    "                cells.append(\"\\\\valvar{\" + f\"{np.mean(v):.3f}\" + \"}{\" + f\"{np.std(v):.3f}\".lstrip(\"0\") + \"}\")\n",
    "        print(f\" & {sr_labels[sr]} & \" + \" & \".join(cells) + \" \\\\\\\\\")\n",
    "\n",
    "    # Separate row groups with a rule, except after the last group.\n",
    "    if subsets != subset_picker[-1]:\n",
    "        print(\"\\\\midrule\")\n",
    "\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dd2feeb",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
