{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a30ffccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "from source.utils.uncertainty_measures import (\n",
    "        calculate_uncertainties_crps,\n",
    "        calculate_uncertainties_log,\n",
    "        calculate_uncertainties_mse,\n",
    "        calculate_uncertainties_quadratic\n",
    ")\n",
    "\n",
    "from source.constants import RESULTS_PATH, PLOTS_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "baa77f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = [\n",
    "    \"dots\",\n",
    "    \"arrow\",\n",
    "    # \"git_qood\",\n",
    "    \"city_street\",\n",
    "    \"city_building\",\n",
    "    \"city_sky\",\n",
    "    \"city_car\",\n",
    "    \"city_vegetation\",\n",
    "]\n",
    "scoring_rules = [\"crps\", \"log\" , \"mse\", \"quadratic\"]\n",
    "method_seeds = [42, 142, 242]\n",
    "\n",
    "method = \"ensemble\"\n",
    "anti_regularization_weight = 0.0\n",
    "use_natural = True\n",
    "min_retention_rate = 0.5\n",
    "\n",
    "avg_rank = False\n",
    "\n",
    "loss_fn = lambda x, y: torch.nn.functional.mse_loss(x, y, reduction=\"none\")\n",
    "# loss_fn = lambda x, y: torch.nn.functional.l1_loss(x, y, reduction=\"none\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eafb87a",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = {\n",
    "            \"total_1_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{1,1}$\",\n",
    "            \"total_2_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{2,1}$\",\n",
    "            \"total_3a_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,1}$\",\n",
    "            \"total_3b_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,1}$\",\n",
    "            \"total_3a_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,2}$\",\n",
    "            \"total_3b_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,2}$\",\n",
    "            \"bayes_1\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{1}$\",\n",
    "            \"bayes_2\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{2}$\",\n",
    "            \"bayes_3a\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3a}$\",\n",
    "            \"bayes_3b\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3b}$\",\n",
    "            \"excess_1_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{1,1}$\",\n",
    "            \"excess_2_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{2,1}$\",\n",
    "            \"excess_3a_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,1}$\",\n",
    "            \"excess_3b_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,1}$\",\n",
    "            \"excess_3a_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,2}$\",\n",
    "            \"excess_3b_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,2}$\",\n",
    "        }\n",
    "\n",
    "sr_labels = {\n",
    "            \"crps\": \"CRPS\",\n",
    "            \"log\": \"Log\",\n",
    "            \"mse\": \"SE\",\n",
    "            \"quadratic\": \"Quad.\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f23d1f72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{c|cccccc|cccc|cccccc}\n",
      "\\toprule\n",
      " SR & $\\hat{R}_{\\text{Tot}}^{1,1}$ & $\\hat{R}_{\\text{Tot}}^{2,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,1}$ & $\\hat{R}_{\\text{Tot}}^{3b,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,2}$ & $\\hat{R}_{\\text{Tot}}^{3b,2}$ & $\\hat{R}_{\\text{Bayes}}^{1}$ & $\\hat{R}_{\\text{Bayes}}^{2}$ & $\\hat{R}_{\\text{Bayes}}^{3a}$ & $\\hat{R}_{\\text{Bayes}}^{3b}$ & $\\hat{R}_{\\text{Exc}}^{1,1}$ & $\\hat{R}_{\\text{Exc}}^{2,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,1}$ & $\\hat{R}_{\\text{Exc}}^{3b,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,2}$ & $\\hat{R}_{\\text{Exc}}^{3b,2}$ \\\\\n",
      "\\midrule\n",
      " CRPS & \\valvar{0.318}{.007} & \\valvar{0.318}{.007} & \\valvar{0.319}{.007} & \\valvar{0.327}{.007} & \\valvar{0.327}{.007} & \\valvar{0.352}{.007} & \\valvar{0.356}{.007} & \\valvar{0.327}{.007} & \\valvar{0.327}{.007} & \\valvar{0.357}{.007} & \\valvar{0.339}{.006} & \\valvar{0.339}{.006} & \\valvar{0.340}{.006} & \\valvar{0.341}{.006} & \\valvar{0.504}{.006} & \\valvar{0.378}{.006} \\\\\n",
      " Log & \\valvar{0.323}{.007} & \\valvar{0.323}{.007} & \\valvar{0.321}{.007} & \\valvar{0.327}{.007} & - & - & \\valvar{0.356}{.007} & - & \\valvar{0.327}{.007} & \\valvar{0.357}{.007} & \\valvar{0.397}{.006} & \\valvar{0.397}{.006} & \\valvar{0.400}{.006} & \\valvar{0.398}{.007} & - & - \\\\\n",
      " MSE & \\valvar{0.320}{.007} & \\valvar{0.320}{.007} & \\valvar{0.320}{.007} & \\valvar{0.327}{.007} & \\valvar{0.327}{.007} & \\valvar{0.357}{.007} & \\valvar{0.357}{.007} & \\valvar{0.327}{.007} & \\valvar{0.327}{.007} & \\valvar{0.357}{.007} & \\valvar{0.325}{.006} & \\valvar{0.325}{.006} & \\valvar{0.325}{.006} & \\valvar{0.325}{.006} & - & - \\\\\n",
      " Quad. & \\valvar{0.319}{.007} & \\valvar{0.319}{.007} & \\valvar{0.323}{.007} & \\valvar{0.327}{.007} & \\valvar{0.329}{.007} & \\valvar{0.347}{.007} & \\valvar{0.356}{.007} & \\valvar{0.326}{.007} & \\valvar{0.327}{.007} & \\valvar{0.357}{.007} & \\valvar{0.514}{.007} & \\valvar{0.514}{.007} & \\valvar{0.523}{.006} & \\valvar{0.511}{.007} & \\valvar{0.689}{.004} & \\valvar{0.503}{.007} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# print latex table header\n",
    "print(\"\\\\begin{tabular}{c|cccccc|cccc|cccccc}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\" SR & \" + \" & \".join(labels.values()) + \" \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "\n",
    "\n",
    "for sr in scoring_rules:\n",
    "    all_areas = {k: {} for k in scoring_rules}    \n",
    "    for task in tasks:\n",
    "        network = [\"cnn\", \"resnet\"][0 if task in [\"dots\", \"arrow\", \"git_qood\"] else 1]   \n",
    "        y_test = torch.load(\n",
    "            os.path.join(\n",
    "                RESULTS_PATH,\n",
    "                f\"{task}_{network}_natural{use_natural}_mseed{method_seeds[0]}_arw{anti_regularization_weight}\",\n",
    "                f\"y_test.pt\",\n",
    "            ), weights_only=False\n",
    "        )\n",
    "\n",
    "        # extend y_test to match the number of method seeds\n",
    "        y_test = y_test.unsqueeze(-1).expand(-1, len(method_seeds))\n",
    "\n",
    "        means, variances = [], []\n",
    "        for method_seed in method_seeds:\n",
    "            preds = torch.load(\n",
    "                os.path.join(\n",
    "                    RESULTS_PATH,\n",
    "                    f\"{task}_{network}_natural{use_natural}_mseed{method_seed}_arw{anti_regularization_weight}\",\n",
    "                    f\"preds.pt\",\n",
    "                ), weights_only=False\n",
    "            )\n",
    "            means.append(preds[..., 0])\n",
    "            variances.append(preds[..., 1])\n",
    "\n",
    "        means = torch.stack(means, dim=-2)\n",
    "        variances = torch.stack(variances, dim=-2)\n",
    "\n",
    "        if sr == \"crps\":\n",
    "            uncertainties = calculate_uncertainties_crps(means, variances)\n",
    "        elif sr == \"log\":\n",
    "            uncertainties = calculate_uncertainties_log(means, variances)\n",
    "        elif sr == \"mse\":\n",
    "            uncertainties = calculate_uncertainties_mse(means, variances)\n",
    "        elif sr == \"quadratic\":\n",
    "            uncertainties = calculate_uncertainties_quadratic(means, variances)\n",
    "        \n",
    "        # calculate retention rates\n",
    "        error = loss_fn(means.mean(dim=-1), y_test)\n",
    "        torch.manual_seed(2357)\n",
    "        error_rand = error[torch.randperm(error.shape[0])].view(error.shape[0], -1)\n",
    "        error_sort = error.sort(dim=0).values\n",
    "\n",
    "        retention_rates, oracle_errors, random_errors = [], [], []\n",
    "        risk_errors = {k: [] for k in uncertainties.keys()}\n",
    "        all_indices = {k: v[..., 0].sort().indices for k, v in uncertainties.items()}\n",
    "\n",
    "        for i in tqdm(range(int(len(means) * min_retention_rate), len(means)), disable=True):\n",
    "            retention_rates.append((i + 1) / len(means))\n",
    "\n",
    "            _oracle_errors = []\n",
    "            _random_errors = []\n",
    "            _risk_errors = {k: [] for k in uncertainties.keys()}\n",
    "\n",
    "            for m in range(len(method_seeds)):\n",
    "                _oracle_errors.append(error_sort[:i, m].mean().item())\n",
    "                _random_errors.append(error_rand[:i, m].mean().item())\n",
    "\n",
    "                for k, v in uncertainties.items():\n",
    "                    if len(v) <= 1 and torch.isnan(v):\n",
    "                        continue\n",
    "                    indices = all_indices[k][:i]\n",
    "                    _risk_errors[k].append(error[indices, m].mean().item())\n",
    "\n",
    "            oracle_errors.append(_oracle_errors)\n",
    "            random_errors.append(_random_errors)\n",
    "            for k, v in _risk_errors.items():\n",
    "                risk_errors[k].append(v)\n",
    "\n",
    "        oracle_errors = torch.tensor(oracle_errors)\n",
    "        random_errors = torch.tensor(random_errors)\n",
    "        risk_errors = {k: torch.tensor(v) for k, v in risk_errors.items()}\n",
    "\n",
    "        # calculate areas\n",
    "        random_area = random_errors.mean(dim=(0))\n",
    "        oracle_area = oracle_errors.mean(dim=(0))\n",
    "\n",
    "        areas = {\n",
    "            \"total_1_1\": [(risk_errors[\"total_1_1\"].mean(dim=(0))[i] - oracle_area[i]) / \n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_2_1\": [(risk_errors[\"total_2_1\"].mean(dim=(0))[i] - oracle_area[i]) / \n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_3a_1\": [(risk_errors[\"total_3a_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_3b_1\": [(risk_errors[\"total_3b_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"total_3a_2\": [(risk_errors[\"total_3a_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr != \"log\" else [float('nan')] * len(method_seeds),\n",
    "            \"total_3b_2\": [(risk_errors[\"total_3b_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr != \"log\" else [float('nan')] * len(method_seeds),\n",
    "            \"bayes_1\": [(risk_errors[\"bayes_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"bayes_2\": [(risk_errors[\"bayes_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                        if sr != \"log\" else [float('nan')] * len(method_seeds),\n",
    "            \"bayes_3a\": [(risk_errors[\"bayes_3a\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"bayes_3b\": [(risk_errors[\"bayes_3b\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_1_1\": [(risk_errors[\"excess_1_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_2_1\": [(risk_errors[\"excess_2_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_3a_1\": [(risk_errors[\"excess_3a_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_3b_1\": [(risk_errors[\"excess_3b_1\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))],\n",
    "            \"excess_3a_2\": [(risk_errors[\"excess_3a_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr not in [\"log\", \"mse\"] else [float('nan')] * len(method_seeds),\n",
    "            \"excess_3b_2\": [(risk_errors[\"excess_3b_2\"].mean(dim=(0))[i] - oracle_area[i]) /\n",
    "                          (random_area[i] - oracle_area[i]) for i in range(len(method_seeds))]\n",
    "                          if sr not in [\"log\", \"mse\"] else [float('nan')] * len(method_seeds),\n",
    "        }\n",
    "\n",
    "        if all_areas[sr] == {}:\n",
    "            for k, v in areas.items():\n",
    "                all_areas[sr][k] = [v]\n",
    "        else: # append new areas\n",
    "            for k, v in areas.items():\n",
    "                all_areas[sr][k] = all_areas[sr][k] + [v]\n",
    "\n",
    "    # stack to tensor\n",
    "    all_areas[sr] = {k: torch.tensor(v) for k, v in all_areas[sr].items()}\n",
    "\n",
    "    print(f\" {sr_labels[sr]} & \" + \" & \".join([\"\\\\valvar{\"f\"{torch.mean(v, dim=(0,1)):.3f}\" + \"}{\" + f\"{torch.mean(v, dim=0).std(dim=0):.3f}\".lstrip(\"0\") + \"}\" if not torch.isnan(v).any() else \"-\" for v in all_areas[sr].values()]) + \" \\\\\\\\\")\n",
    "\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
