{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a30ffccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "from source.utils.uncertainty_measures import (\n",
    "        calculate_uncertainties_crps,\n",
    "        calculate_uncertainties_log,\n",
    "        calculate_uncertainties_mse,\n",
    "        calculate_uncertainties_quadratic\n",
    ")\n",
    "\n",
    "from source.constants import RESULTS_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "baa77f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = [\n",
    "    \"dots\",\n",
    "    \"arrow\",\n",
    "    # \"git_qood\",\n",
    "    \"city_street\",\n",
    "    \"city_building\",\n",
    "    \"city_sky\",\n",
    "    \"city_car\",\n",
    "    \"city_vegetation\",\n",
    "]\n",
    "scoring_rules = [\"crps\", \"log\", \"mse\", \"quadratic\"]\n",
    "method_seeds = [42, 142, 242]\n",
    "\n",
    "method = \"ensemble\"\n",
    "anti_regularization_weight = 0.0\n",
    "use_natural = True\n",
    "min_retention_rate = 0.5\n",
    "\n",
    "loss_fn = lambda x, y: torch.nn.functional.mse_loss(x, y, reduction=\"none\")\n",
    "# loss_fn = lambda x, y: torch.nn.functional.l1_loss(x, y, reduction=\"none\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eafb87a",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = {\n",
    "            \"total_1_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{1,1}$\",\n",
    "            \"total_2_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{2,1}$\",\n",
    "            \"total_3a_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,1}$\",\n",
    "            \"total_3b_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,1}$\",\n",
    "            \"total_3a_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,2}$\",\n",
    "            \"total_3b_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,2}$\",\n",
    "            \"bayes_1\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{1}$\",\n",
    "            \"bayes_2\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{2}$\",\n",
    "            \"bayes_3a\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3a}$\",\n",
    "            \"bayes_3b\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3b}$\",\n",
    "            \"excess_1_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{1,1}$\",\n",
    "            \"excess_2_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{2,1}$\",\n",
    "            \"excess_3a_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,1}$\",\n",
    "            \"excess_3b_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,1}$\",\n",
    "            \"excess_3a_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,2}$\",\n",
    "            \"excess_3b_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,2}$\",\n",
    "        }\n",
    "\n",
    "sr_labels = {\n",
    "            \"crps\": \"CRPS\",\n",
    "            \"log\": \"Log\",\n",
    "            \"mse\": \"SE\",\n",
    "            \"quadratic\": \"Quad.\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f23d1f72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{ccccccccccccccccccc}\n",
      "\\toprule\n",
      "$\\mathcal{D}$ & SR & $\\hat{R}_{\\text{Tot}}^{1,1}$ & $\\hat{R}_{\\text{Tot}}^{2,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,1}$ & $\\hat{R}_{\\text{Tot}}^{3b,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,2}$ & $\\hat{R}_{\\text{Tot}}^{3b,2}$ & $\\hat{R}_{\\text{Bayes}}^{1}$ & $\\hat{R}_{\\text{Bayes}}^{2}$ & $\\hat{R}_{\\text{Bayes}}^{3a}$ & $\\hat{R}_{\\text{Bayes}}^{3b}$ & $\\hat{R}_{\\text{Exc}}^{1,1}$ & $\\hat{R}_{\\text{Exc}}^{2,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,1}$ & $\\hat{R}_{\\text{Exc}}^{3b,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,2}$ & $\\hat{R}_{\\text{Exc}}^{3b,2}$ \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{dots}} & CRPS & \\valvar{0.631}{.002} & \\valvar{0.631}{.002} & \\valvar{0.631}{.002} & \\valvar{0.644}{.002} & \\valvar{0.644}{.002} & \\valvar{0.664}{.002} & \\valvar{0.664}{.002} & \\valvar{0.644}{.002} & \\valvar{0.644}{.002} & \\valvar{0.664}{.002} & \\valvar{0.704}{.001} & \\valvar{0.704}{.001} & \\valvar{0.704}{.001} & \\valvar{0.705}{.001} & \\valvar{0.782}{.002} & \\valvar{0.716}{.001} \\\\\n",
      " & Log & \\valvar{0.631}{.002} & \\valvar{0.631}{.002} & \\valvar{0.631}{.002} & \\valvar{0.643}{.002} & - & - & \\valvar{0.664}{.002} & - & \\valvar{0.644}{.002} & \\valvar{0.664}{.002} & \\valvar{0.729}{.001} & \\valvar{0.729}{.001} & \\valvar{0.729}{.001} & \\valvar{0.729}{.001} & - & - \\\\\n",
      " & MSE & \\valvar{0.632}{.002} & \\valvar{0.632}{.002} & \\valvar{0.632}{.002} & \\valvar{0.644}{.002} & \\valvar{0.644}{.002} & \\valvar{0.664}{.002} & \\valvar{0.664}{.002} & \\valvar{0.644}{.002} & \\valvar{0.644}{.002} & \\valvar{0.664}{.002} & \\valvar{0.685}{.001} & \\valvar{0.685}{.001} & \\valvar{0.685}{.001} & \\valvar{0.685}{.001} & - & - \\\\\n",
      " & Quad. & \\valvar{0.631}{.002} & \\valvar{0.631}{.002} & \\valvar{0.631}{.002} & \\valvar{0.643}{.002} & \\valvar{0.644}{.002} & \\valvar{0.663}{.002} & \\valvar{0.664}{.002} & \\valvar{0.644}{.002} & \\valvar{0.644}{.002} & \\valvar{0.664}{.002} & \\valvar{0.760}{.002} & \\valvar{0.760}{.002} & \\valvar{0.760}{.002} & \\valvar{0.759}{.002} & \\valvar{0.805}{.002} & \\valvar{0.743}{.001} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{arrow}} & CRPS & \\valvar{0.776}{.017} & \\valvar{0.776}{.017} & \\valvar{0.776}{.017} & \\valvar{0.785}{.018} & \\valvar{0.786}{.018} & \\valvar{0.806}{.018} & \\valvar{0.810}{.017} & \\valvar{0.786}{.018} & \\valvar{0.786}{.018} & \\valvar{0.810}{.017} & \\valvar{0.756}{.012} & \\valvar{0.756}{.012} & \\valvar{0.756}{.012} & \\valvar{0.756}{.012} & \\valvar{0.804}{.014} & \\valvar{0.762}{.011} \\\\\n",
      " & Log & \\valvar{0.773}{.017} & \\valvar{0.773}{.017} & \\valvar{0.775}{.017} & \\valvar{0.783}{.018} & - & - & \\valvar{0.811}{.017} & - & \\valvar{0.786}{.018} & \\valvar{0.810}{.017} & \\valvar{0.765}{.010} & \\valvar{0.765}{.010} & \\valvar{0.765}{.010} & \\valvar{0.766}{.011} & - & - \\\\\n",
      " & MSE & \\valvar{0.777}{.017} & \\valvar{0.777}{.017} & \\valvar{0.777}{.017} & \\valvar{0.786}{.018} & \\valvar{0.786}{.018} & \\valvar{0.810}{.017} & \\valvar{0.810}{.017} & \\valvar{0.786}{.018} & \\valvar{0.786}{.018} & \\valvar{0.810}{.017} & \\valvar{0.753}{.013} & \\valvar{0.753}{.013} & \\valvar{0.753}{.013} & \\valvar{0.753}{.013} & - & - \\\\\n",
      " & Quad. & \\valvar{0.774}{.017} & \\valvar{0.774}{.017} & \\valvar{0.774}{.017} & \\valvar{0.784}{.018} & \\valvar{0.785}{.018} & \\valvar{0.801}{.018} & \\valvar{0.811}{.017} & \\valvar{0.786}{.018} & \\valvar{0.786}{.018} & \\valvar{0.810}{.017} & \\valvar{0.783}{.010} & \\valvar{0.783}{.010} & \\valvar{0.783}{.010} & \\valvar{0.782}{.010} & \\valvar{0.815}{.014} & \\valvar{0.776}{.010} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{street}} & CRPS & \\valvar{0.203}{.013} & \\valvar{0.203}{.013} & \\valvar{0.204}{.013} & \\valvar{0.209}{.014} & \\valvar{0.209}{.014} & \\valvar{0.232}{.014} & \\valvar{0.236}{.014} & \\valvar{0.209}{.014} & \\valvar{0.209}{.014} & \\valvar{0.237}{.014} & \\valvar{0.295}{.012} & \\valvar{0.295}{.012} & \\valvar{0.295}{.012} & \\valvar{0.298}{.012} & \\valvar{0.605}{.016} & \\valvar{0.407}{.012} \\\\\n",
      " & Log & \\valvar{0.215}{.013} & \\valvar{0.215}{.013} & \\valvar{0.206}{.013} & \\valvar{0.209}{.014} & - & - & \\valvar{0.237}{.014} & - & \\valvar{0.209}{.014} & \\valvar{0.237}{.014} & \\valvar{0.479}{.014} & \\valvar{0.479}{.014} & \\valvar{0.488}{.014} & \\valvar{0.488}{.014} & - & - \\\\\n",
      " & MSE & \\valvar{0.204}{.013} & \\valvar{0.204}{.013} & \\valvar{0.204}{.013} & \\valvar{0.209}{.014} & \\valvar{0.209}{.014} & \\valvar{0.237}{.014} & \\valvar{0.237}{.014} & \\valvar{0.209}{.014} & \\valvar{0.209}{.014} & \\valvar{0.237}{.014} & \\valvar{0.238}{.011} & \\valvar{0.238}{.011} & \\valvar{0.238}{.011} & \\valvar{0.238}{.011} & - & - \\\\\n",
      " & Quad. & \\valvar{0.206}{.013} & \\valvar{0.206}{.013} & \\valvar{0.211}{.013} & \\valvar{0.208}{.014} & \\valvar{0.210}{.014} & \\valvar{0.228}{.014} & \\valvar{0.238}{.014} & \\valvar{0.208}{.014} & \\valvar{0.209}{.014} & \\valvar{0.237}{.014} & \\valvar{0.896}{.010} & \\valvar{0.896}{.010} & \\valvar{0.907}{.011} & \\valvar{0.883}{.011} & \\valvar{1.063}{.011} & \\valvar{0.787}{.013} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{building}} & CRPS & \\valvar{0.227}{.012} & \\valvar{0.227}{.012} & \\valvar{0.227}{.012} & \\valvar{0.235}{.012} & \\valvar{0.236}{.012} & \\valvar{0.264}{.012} & \\valvar{0.273}{.012} & \\valvar{0.237}{.012} & \\valvar{0.236}{.012} & \\valvar{0.270}{.012} & \\valvar{0.240}{.011} & \\valvar{0.240}{.011} & \\valvar{0.242}{.011} & \\valvar{0.242}{.011} & \\valvar{0.551}{.012} & \\valvar{0.318}{.013} \\\\\n",
      " & Log & \\valvar{0.227}{.012} & \\valvar{0.227}{.012} & \\valvar{0.227}{.012} & \\valvar{0.232}{.012} & - & - & \\valvar{0.277}{.012} & - & \\valvar{0.236}{.012} & \\valvar{0.270}{.012} & \\valvar{0.336}{.013} & \\valvar{0.336}{.013} & \\valvar{0.338}{.013} & \\valvar{0.332}{.013} & - & - \\\\\n",
      " & MSE & \\valvar{0.228}{.012} & \\valvar{0.228}{.012} & \\valvar{0.228}{.012} & \\valvar{0.236}{.012} & \\valvar{0.236}{.012} & \\valvar{0.270}{.012} & \\valvar{0.270}{.012} & \\valvar{0.236}{.012} & \\valvar{0.236}{.012} & \\valvar{0.270}{.012} & \\valvar{0.229}{.011} & \\valvar{0.229}{.011} & \\valvar{0.229}{.011} & \\valvar{0.229}{.011} & - & - \\\\\n",
      " & Quad. & \\valvar{0.226}{.012} & \\valvar{0.226}{.012} & \\valvar{0.227}{.012} & \\valvar{0.233}{.013} & \\valvar{0.235}{.012} & \\valvar{0.257}{.013} & \\valvar{0.282}{.011} & \\valvar{0.242}{.013} & \\valvar{0.236}{.012} & \\valvar{0.270}{.012} & \\valvar{0.517}{.012} & \\valvar{0.517}{.012} & \\valvar{0.538}{.012} & \\valvar{0.523}{.012} & \\valvar{0.888}{.011} & \\valvar{0.568}{.013} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{sky}} & CRPS & \\valvar{0.043}{.003} & \\valvar{0.043}{.003} & \\valvar{0.044}{.003} & \\valvar{0.045}{.004} & \\valvar{0.045}{.004} & \\valvar{0.050}{.006} & \\valvar{0.051}{.006} & \\valvar{0.044}{.004} & \\valvar{0.045}{.004} & \\valvar{0.052}{.006} & \\valvar{0.044}{.001} & \\valvar{0.044}{.001} & \\valvar{0.045}{.001} & \\valvar{0.044}{.001} & \\valvar{0.099}{.004} & \\valvar{0.050}{.001} \\\\\n",
      " & Log & \\valvar{0.046}{.003} & \\valvar{0.046}{.003} & \\valvar{0.045}{.003} & \\valvar{0.046}{.004} & - & - & \\valvar{0.050}{.006} & - & \\valvar{0.045}{.004} & \\valvar{0.052}{.006} & \\valvar{0.054}{.001} & \\valvar{0.054}{.001} & \\valvar{0.057}{.001} & \\valvar{0.056}{.001} & - & - \\\\\n",
      " & MSE & \\valvar{0.044}{.004} & \\valvar{0.044}{.004} & \\valvar{0.044}{.004} & \\valvar{0.045}{.004} & \\valvar{0.045}{.004} & \\valvar{0.052}{.006} & \\valvar{0.052}{.006} & \\valvar{0.045}{.004} & \\valvar{0.045}{.004} & \\valvar{0.052}{.006} & \\valvar{0.043}{.002} & \\valvar{0.043}{.002} & \\valvar{0.043}{.002} & \\valvar{0.043}{.002} & - & - \\\\\n",
      " & Quad. & \\valvar{0.041}{.003} & \\valvar{0.041}{.003} & \\valvar{0.045}{.003} & \\valvar{0.044}{.004} & \\valvar{0.046}{.004} & \\valvar{0.048}{.006} & \\valvar{0.049}{.006} & \\valvar{0.042}{.004} & \\valvar{0.045}{.004} & \\valvar{0.052}{.006} & \\valvar{0.075}{.003} & \\valvar{0.075}{.003} & \\valvar{0.079}{.003} & \\valvar{0.073}{.003} & \\valvar{0.172}{.004} & \\valvar{0.072}{.002} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{car}} & CRPS & \\valvar{0.149}{.003} & \\valvar{0.149}{.003} & \\valvar{0.154}{.003} & \\valvar{0.162}{.003} & \\valvar{0.162}{.003} & \\valvar{0.189}{.004} & \\valvar{0.186}{.005} & \\valvar{0.156}{.003} & \\valvar{0.161}{.003} & \\valvar{0.196}{.005} & \\valvar{0.151}{.006} & \\valvar{0.151}{.006} & \\valvar{0.154}{.006} & \\valvar{0.154}{.006} & \\valvar{0.347}{.008} & \\valvar{0.195}{.008} \\\\\n",
      " & Log & \\valvar{0.176}{.004} & \\valvar{0.176}{.004} & \\valvar{0.163}{.003} & \\valvar{0.172}{.003} & - & - & \\valvar{0.177}{.004} & - & \\valvar{0.161}{.003} & \\valvar{0.196}{.005} & \\valvar{0.219}{.007} & \\valvar{0.219}{.007} & \\valvar{0.226}{.008} & \\valvar{0.219}{.009} & - & - \\\\\n",
      " & MSE & \\valvar{0.153}{.003} & \\valvar{0.153}{.003} & \\valvar{0.153}{.003} & \\valvar{0.161}{.003} & \\valvar{0.161}{.003} & \\valvar{0.196}{.005} & \\valvar{0.196}{.005} & \\valvar{0.161}{.003} & \\valvar{0.161}{.003} & \\valvar{0.196}{.005} & \\valvar{0.141}{.005} & \\valvar{0.141}{.005} & \\valvar{0.141}{.005} & \\valvar{0.141}{.005} & - & - \\\\\n",
      " & Quad. & \\valvar{0.153}{.005} & \\valvar{0.153}{.005} & \\valvar{0.174}{.004} & \\valvar{0.174}{.003} & \\valvar{0.170}{.003} & \\valvar{0.187}{.004} & \\valvar{0.170}{.004} & \\valvar{0.148}{.003} & \\valvar{0.161}{.003} & \\valvar{0.196}{.005} & \\valvar{0.352}{.016} & \\valvar{0.352}{.016} & \\valvar{0.377}{.014} & \\valvar{0.342}{.015} & \\valvar{0.617}{.012} & \\valvar{0.334}{.013} \\\\\n",
      "\\midrule\n",
      "\\multirow{6}{*}{\\rotatebox{90}{vegetation}} & CRPS & \\valvar{0.200}{.011} & \\valvar{0.200}{.011} & \\valvar{0.199}{.011} & \\valvar{0.209}{.011} & \\valvar{0.210}{.011} & \\valvar{0.260}{.011} & \\valvar{0.274}{.012} & \\valvar{0.211}{.011} & \\valvar{0.210}{.011} & \\valvar{0.272}{.011} & \\valvar{0.187}{.009} & \\valvar{0.187}{.009} & \\valvar{0.186}{.009} & \\valvar{0.187}{.009} & \\valvar{0.337}{.005} & \\valvar{0.199}{.008} \\\\\n",
      " & Log & \\valvar{0.195}{.011} & \\valvar{0.195}{.011} & \\valvar{0.197}{.011} & \\valvar{0.203}{.010} & - & - & \\valvar{0.276}{.012} & - & \\valvar{0.210}{.011} & \\valvar{0.272}{.011} & \\valvar{0.195}{.008} & \\valvar{0.195}{.008} & \\valvar{0.195}{.008} & \\valvar{0.196}{.008} & - & - \\\\\n",
      " & MSE & \\valvar{0.200}{.011} & \\valvar{0.200}{.011} & \\valvar{0.200}{.011} & \\valvar{0.210}{.011} & \\valvar{0.210}{.011} & \\valvar{0.272}{.011} & \\valvar{0.272}{.011} & \\valvar{0.210}{.011} & \\valvar{0.210}{.011} & \\valvar{0.272}{.011} & \\valvar{0.187}{.011} & \\valvar{0.187}{.011} & \\valvar{0.187}{.011} & \\valvar{0.187}{.011} & - & - \\\\\n",
      " & Quad. & \\valvar{0.199}{.012} & \\valvar{0.199}{.012} & \\valvar{0.196}{.011} & \\valvar{0.204}{.011} & \\valvar{0.209}{.011} & \\valvar{0.247}{.011} & \\valvar{0.278}{.012} & \\valvar{0.214}{.012} & \\valvar{0.210}{.011} & \\valvar{0.272}{.011} & \\valvar{0.217}{.008} & \\valvar{0.217}{.008} & \\valvar{0.216}{.008} & \\valvar{0.216}{.008} & \\valvar{0.462}{.005} & \\valvar{0.241}{.008} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# print latex table header\n",
    "print(\"\\\\begin{tabular}{cc|cccccc|cccc|cccccc}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\"$\\\\mathcal{D}$ & SR & \" + \" & \".join(labels.values()) + \" \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "\n",
    "for task in tasks:\n",
    "    network = [\"cnn\", \"resnet\"][0 if task in [\"dots\", \"arrow\", \"git_qood\"] else 1]\n",
    "    print(f\"\\\\multirow{{6}}{{*}}{{\\\\rotatebox{{90}}{{{task.replace('git_qood', 'mosaic').replace('city_', '')}}}}}\", end=\"\")\n",
    "    for sr in scoring_rules:\n",
    "        y_test = torch.load(\n",
    "            os.path.join(\n",
    "                RESULTS_PATH,\n",
    "                f\"{task}_{network}_natural{use_natural}_mseed{method_seeds[0]}_arw{anti_regularization_weight}\",\n",
    "                f\"y_test.pt\",\n",
    "            ), weights_only=False\n",
    "        )\n",
    "\n",
    "        # extend y_test to match the number of method seeds\n",
    "        y_test = y_test.unsqueeze(-1).expand(-1, len(method_seeds))\n",
    "\n",
    "        means, variances = [], []\n",
    "        for method_seed in method_seeds:\n",
    "            preds = torch.load(\n",
    "                os.path.join(\n",
    "                    RESULTS_PATH,\n",
    "                    f\"{task}_{network}_natural{use_natural}_mseed{method_seed}_arw{anti_regularization_weight}\",\n",
    "                    f\"preds.pt\",\n",
    "                ), weights_only=False\n",
    "            )\n",
    "            means.append(preds[..., 0])\n",
    "            variances.append(preds[..., 1])\n",
    "\n",
    "        means = torch.stack(means, dim=-2)\n",
    "        variances = torch.stack(variances, dim=-2)\n",
    "\n",
    "        if sr == \"crps\":\n",
    "            uncertainties = calculate_uncertainties_crps(means, variances)\n",
    "        elif sr == \"log\":\n",
    "            uncertainties = calculate_uncertainties_log(means, variances)\n",
    "        elif sr == \"mse\":\n",
    "            uncertainties = calculate_uncertainties_mse(means, variances)\n",
    "        elif sr == \"quadratic\":\n",
    "            uncertainties = calculate_uncertainties_quadratic(means, variances)\n",
    "        \n",
    "        # calculate retention rates\n",
    "        error = loss_fn(means.mean(dim=-1), y_test)\n",
    "        torch.manual_seed(2357)\n",
    "        error_rand = error[torch.randperm(error.shape[0])].view(error.shape[0], -1)\n",
    "        error_sort = error.sort(dim=0).values\n",
    "\n",
    "        retention_rates, oracle_errors, random_errors = [], [], []\n",
    "        risk_errors = {k: [] for k in uncertainties.keys()}\n",
    "        all_indices = {k: v[..., 0].sort().indices for k, v in uncertainties.items()}\n",
    "\n",
    "        for i in tqdm(range(int(len(means) * min_retention_rate), len(means)), disable=True):\n",
    "            retention_rates.append((i + 1) / len(means))\n",
    "\n",
    "            _oracle_errors = []\n",
    "            _random_errors = []\n",
    "            _risk_errors = {k: [] for k in uncertainties.keys()}\n",
    "\n",
    "            for m in range(len(method_seeds)):\n",
    "                _oracle_errors.append(error_sort[:i, m].mean().item())\n",
    "                _random_errors.append(error_rand[:i, m].mean().item())\n",
    "\n",
    "                for k, v in uncertainties.items():\n",
    "                    if len(v) <= 1 and torch.isnan(v):\n",
    "                        continue\n",
    "                    indices = all_indices[k][:i]\n",
    "                    _risk_errors[k].append(error[indices, m].mean().item())\n",
    "\n",
    "            oracle_errors.append(_oracle_errors)\n",
    "            random_errors.append(_random_errors)\n",
    "            for k, v in _risk_errors.items():\n",
    "                risk_errors[k].append(v)\n",
    "\n",
    "        oracle_errors = torch.tensor(oracle_errors)\n",
    "        random_errors = torch.tensor(random_errors)\n",
    "        risk_errors = {k: torch.tensor(v) for k, v in risk_errors.items()}\n",
    "\n",
    "        # calculate areas\n",
    "        random_area = random_errors.mean(dim=(0))\n",
    "        oracle_area = oracle_errors.mean(dim=(0))\n",
    "\n",
    "        areas = {\n",
    "            \"total_1_1\": [(risk_errors[\"total_1_1\"].mean(dim=(0))[i] - oracle_area[i]) / \n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_2_1\": [(risk_errors[\"total_2_1\"].mean(dim=(0))[i] - oracle_area[i]) / \n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_3a_1\": [(risk_errors[\"total_3a_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_3b_1\": [(risk_errors[\"total_3b_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_3a_2\": [(risk_errors[\"total_3a_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr != \"log\" else [float('nan')] * len(method_seeds),\n",
    "            \"total_3b_2\": [(risk_errors[\"total_3b_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr != \"log\" else [float('nan')] * len(method_seeds),\n",
    "            \"bayes_1\": [(risk_errors[\"bayes_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"bayes_2\": [(risk_errors[\"bayes_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                        if sr != \"log\" else [float('nan')] * len(method_seeds),\n",
    "            \"bayes_3a\": [(risk_errors[\"bayes_3a\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"bayes_3b\": [(risk_errors[\"bayes_3b\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_1_1\": [(risk_errors[\"excess_1_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_2_1\": [(risk_errors[\"excess_2_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_3a_1\": [(risk_errors[\"excess_3a_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_3b_1\": [(risk_errors[\"excess_3b_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_3a_2\": [(risk_errors[\"excess_3a_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr not in [\"log\", \"mse\"] else [float('nan')] * len(method_seeds),\n",
    "            \"excess_3b_2\": [(risk_errors[\"excess_3b_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr not in [\"log\", \"mse\"] else [float('nan')] * len(method_seeds),\n",
    "        }\n",
    "\n",
    "        print(f\" & {sr_labels[sr]} & \" + \" & \".join([\"\\\\valvar{\"f\"{np.mean(v):.3f}\" + \"}{\" + f\"{np.std(v):.3f}\".lstrip(\"0\") + \"}\" if not torch.isnan(torch.tensor(v)).any() else \"-\" for v in areas.values()]) + \" \\\\\\\\\")\n",
    "    print(\"\\\\midrule\") if task != tasks[-1] else None\n",
    "\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dd2feeb",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
