{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a30ffccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import torch\n",
    "\n",
    "from source.utils.uncertainty_measures import (\n",
    "        calculate_uncertainties_crps,\n",
    "        calculate_uncertainties_log,\n",
    "        calculate_uncertainties_mse,\n",
    "        calculate_uncertainties_quadratic\n",
    ")\n",
    "\n",
    "from source.constants import RESULTS_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "baa77f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "task = \"git_qood\"\n",
    "network = \"cnn\"\n",
    "method = \"ensemble\"\n",
    "method_seeds = [42, 142, 242]\n",
    "anti_regularization_weight = 0.0\n",
    "use_natural = True\n",
    "\n",
    "# Scoring rules that get one table row each.\n",
    "scoring_rules = [\"crps\", \"log\", \"mse\", \"quadratic\"]\n",
    "\n",
    "# Every evaluation pairs the 'test' subset with exactly one other subset.\n",
    "subset_picker = [\n",
    "    ['test', other]\n",
    "    for other in (\n",
    "        'cifar', 'svhn', 'fashion', 'mixture',\n",
    "        'emnist_0', 'emnist_1', 'emnist_2', 'emnist_3',\n",
    "    )\n",
    "]\n",
    "\n",
    "# (scoring rule, uncertainty measure) pairs that are excluded from the evaluation.\n",
    "not_considered = [\n",
    "    *((\"log\", measure) for measure in (\"total_3a_2\", \"total_3b_2\", \"bayes_2\", \"excess_3a_2\", \"excess_3b_2\")),\n",
    "    *((\"mse\", measure) for measure in (\"excess_3a_2\", \"excess_3b_2\")),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eafb87a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX column headers, keyed by uncertainty-measure name.\n",
    "# The insertion order of this dict defines the column order of the table.\n",
    "labels = {\n",
    "    \"total_1_1\": r\"$\\hat{R}_{\\text{Tot}}^{1,1}$\",\n",
    "    \"total_2_1\": r\"$\\hat{R}_{\\text{Tot}}^{2,1}$\",\n",
    "    \"total_3a_1\": r\"$\\hat{R}_{\\text{Tot}}^{3a,1}$\",\n",
    "    \"total_3b_1\": r\"$\\hat{R}_{\\text{Tot}}^{3b,1}$\",\n",
    "    \"total_3a_2\": r\"$\\hat{R}_{\\text{Tot}}^{3a,2}$\",\n",
    "    \"total_3b_2\": r\"$\\hat{R}_{\\text{Tot}}^{3b,2}$\",\n",
    "    \"bayes_1\": r\"$\\hat{R}_{\\text{Bayes}}^{1}$\",\n",
    "    \"bayes_2\": r\"$\\hat{R}_{\\text{Bayes}}^{2}$\",\n",
    "    \"bayes_3a\": r\"$\\hat{R}_{\\text{Bayes}}^{3a}$\",\n",
    "    \"bayes_3b\": r\"$\\hat{R}_{\\text{Bayes}}^{3b}$\",\n",
    "    \"excess_1_1\": r\"$\\hat{R}_{\\text{Exc}}^{1,1}$\",\n",
    "    \"excess_2_1\": r\"$\\hat{R}_{\\text{Exc}}^{2,1}$\",\n",
    "    \"excess_3a_1\": r\"$\\hat{R}_{\\text{Exc}}^{3a,1}$\",\n",
    "    \"excess_3b_1\": r\"$\\hat{R}_{\\text{Exc}}^{3b,1}$\",\n",
    "    \"excess_3a_2\": r\"$\\hat{R}_{\\text{Exc}}^{3a,2}$\",\n",
    "    \"excess_3b_2\": r\"$\\hat{R}_{\\text{Exc}}^{3b,2}$\",\n",
    "}\n",
    "\n",
    "# Row labels, keyed by scoring-rule name.\n",
    "sr_labels = {\n",
    "    \"crps\": \"CRPS\",\n",
    "    \"log\": \"Log\",\n",
    "    \"mse\": \"SE\",\n",
    "    \"quadratic\": \"Quad.\",\n",
    "}\n",
    "\n",
    "# Display names, keyed by subset name.\n",
    "subset_labels = {\n",
    "    \"cifar\": \"CIFAR-10\",\n",
    "    \"svhn\": \"SVHN\",\n",
    "    \"fashion\": \"Fashion-MNIST\",\n",
    "    \"mixture\": \"Mixture (C,S,F)\",\n",
    "    \"emnist_0\": \"EMNIST (tl)\",\n",
    "    \"emnist_1\": \"EMNIST (tr)\",\n",
    "    \"emnist_2\": \"EMNIST (bl)\",\n",
    "    \"emnist_3\": \"EMNIST (br)\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f23d1f72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{c|cccccc|cccc|cccccc}\n",
      "\\toprule\n",
      "SR & $\\hat{R}_{\\text{Tot}}^{1,1}$ & $\\hat{R}_{\\text{Tot}}^{2,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,1}$ & $\\hat{R}_{\\text{Tot}}^{3b,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,2}$ & $\\hat{R}_{\\text{Tot}}^{3b,2}$ & $\\hat{R}_{\\text{Bayes}}^{1}$ & $\\hat{R}_{\\text{Bayes}}^{2}$ & $\\hat{R}_{\\text{Bayes}}^{3a}$ & $\\hat{R}_{\\text{Bayes}}^{3b}$ & $\\hat{R}_{\\text{Exc}}^{1,1}$ & $\\hat{R}_{\\text{Exc}}^{2,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,1}$ & $\\hat{R}_{\\text{Exc}}^{3b,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,2}$ & $\\hat{R}_{\\text{Exc}}^{3b,2}$ \\\\\n",
      "\\midrule\n",
      "CRPS & \\valvar{0.794}{.012} & \\valvar{0.794}{.012} & \\valvar{0.794}{.012} & \\valvar{0.779}{.013} & \\valvar{0.777}{.012} & \\valvar{0.727}{.009} & \\valvar{0.714}{.010} & \\valvar{0.777}{.012} & \\valvar{0.777}{.012} & \\valvar{0.714}{.010} & \\valvar{0.825}{.010} & \\valvar{0.825}{.010} & \\valvar{0.825}{.010} & \\valvar{0.825}{.010} & \\valvar{0.795}{.006} & \\valvar{0.826}{.010} \\\\\n",
      "Log & \\valvar{0.802}{.011} & \\valvar{0.802}{.011} & \\valvar{0.797}{.012} & \\valvar{0.785}{.013} & - & - & \\valvar{0.713}{.010} & - & \\valvar{0.777}{.012} & \\valvar{0.714}{.010} & \\valvar{0.827}{.011} & \\valvar{0.827}{.011} & \\valvar{0.826}{.011} & \\valvar{0.826}{.011} & - & - \\\\\n",
      "MSE & \\valvar{0.792}{.012} & \\valvar{0.792}{.012} & \\valvar{0.792}{.012} & \\valvar{0.777}{.012} & \\valvar{0.777}{.012} & \\valvar{0.714}{.010} & \\valvar{0.714}{.010} & \\valvar{0.777}{.012} & \\valvar{0.777}{.012} & \\valvar{0.714}{.010} & \\valvar{0.820}{.010} & \\valvar{0.820}{.010} & \\valvar{0.820}{.010} & \\valvar{0.820}{.010} & - & - \\\\\n",
      "Quad. & \\valvar{0.800}{.012} & \\valvar{0.800}{.012} & \\valvar{0.800}{.012} & \\valvar{0.784}{.013} & \\valvar{0.777}{.013} & \\valvar{0.740}{.011} & \\valvar{0.713}{.010} & \\valvar{0.777}{.012} & \\valvar{0.777}{.012} & \\valvar{0.714}{.010} & \\valvar{0.816}{.016} & \\valvar{0.816}{.016} & \\valvar{0.816}{.016} & \\valvar{0.817}{.015} & \\valvar{0.783}{.005} & \\valvar{0.822}{.012} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# OOD-detection AUROC per scoring rule (rows) and uncertainty measure (columns),\n",
    "# averaged over all subset pairs and printed as a LaTeX table.\n",
    "\n",
    "# scoring rule -> function computing all uncertainty measures from (means, variances)\n",
    "uncertainty_fns = {\n",
    "    \"crps\": calculate_uncertainties_crps,\n",
    "    \"log\": calculate_uncertainties_log,\n",
    "    \"mse\": calculate_uncertainties_mse,\n",
    "    \"quadratic\": calculate_uncertainties_quadratic,\n",
    "}\n",
    "# fail early instead of silently reusing the uncertainties of a previous scoring rule\n",
    "unknown_rules = [sr for sr in scoring_rules if sr not in uncertainty_fns]\n",
    "if unknown_rules:\n",
    "    raise ValueError(f\"Unknown scoring rule(s): {unknown_rules}\")\n",
    "\n",
    "excluded = set(not_considered)  # (scoring rule, measure) pairs shown as \"-\"\n",
    "n_seeds = len(method_seeds)\n",
    "\n",
    "\n",
    "def load_predictions(subsets):\n",
    "    \"\"\"Load the saved predictions of every seed for the given subsets.\n",
    "\n",
    "    Args:\n",
    "        subsets: subset names; the predictions of all of them are concatenated\n",
    "            along dim 0.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (means, variances, ood_identity). means and variances are\n",
    "        preds[..., 0] and preds[..., 1], stacked over seeds along dim -2.\n",
    "        ood_identity is stacked over seeds along dim -1 and is 0 for samples\n",
    "        of the 'test' subset and 1 for samples of any other subset.\n",
    "    \"\"\"\n",
    "    means, variances, ood_identity = [], [], []\n",
    "    for method_seed in method_seeds:\n",
    "        run_dir = os.path.join(\n",
    "            RESULTS_PATH,\n",
    "            f\"{task}_{network}_natural{use_natural}_mseed{method_seed}_arw{anti_regularization_weight}\",\n",
    "        )\n",
    "        means_seed, variances_seed, ood_identity_seed = [], [], []\n",
    "        for subset in subsets:\n",
    "            # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted files.\n",
    "            preds = torch.load(os.path.join(run_dir, f\"preds{subset}.pt\"), weights_only=False)\n",
    "            means_seed.append(preds[..., 0])\n",
    "            variances_seed.append(preds[..., 1])\n",
    "            make_flags = torch.zeros if subset == 'test' else torch.ones\n",
    "            ood_identity_seed.append(make_flags(size=(len(preds), )))\n",
    "        means.append(torch.cat(means_seed, dim=0))\n",
    "        variances.append(torch.cat(variances_seed, dim=0))\n",
    "        ood_identity.append(torch.cat(ood_identity_seed, dim=0))\n",
    "    return (\n",
    "        torch.stack(means, dim=-2),\n",
    "        torch.stack(variances, dim=-2),\n",
    "        torch.stack(ood_identity, dim=-1),\n",
    "    )\n",
    "\n",
    "\n",
    "def format_cell(values):\n",
    "    \"\"\"Format per-seed AUROCs as \\\\valvar{mean}{std}, or \"-\" if any value is NaN.\"\"\"\n",
    "    if np.isnan(values).any():\n",
    "        return \"-\"\n",
    "    return \"\\\\valvar{\" + f\"{np.mean(values):.3f}\" + \"}{\" + f\"{np.std(values):.3f}\".lstrip(\"0\") + \"}\"\n",
    "\n",
    "\n",
    "rocs_overall = np.zeros((len(subset_picker), len(scoring_rules), len(labels), n_seeds))\n",
    "for i, subsets in enumerate(subset_picker):\n",
    "    # the predictions do not depend on the scoring rule, so load them once per subset pair\n",
    "    means, variances, ood_identity = load_predictions(subsets)\n",
    "    ood_targets = ood_identity.numpy()\n",
    "\n",
    "    for j, sr in enumerate(scoring_rules):\n",
    "        # clone so that no scoring rule can observe tensors modified in place by another one\n",
    "        uncertainties = uncertainty_fns[sr](means.clone(), variances.clone())\n",
    "\n",
    "        rocs = {}\n",
    "        for unc, scores in uncertainties.items():\n",
    "            if (sr, unc) in excluded:\n",
    "                rocs[unc] = [float('nan')] * n_seeds\n",
    "            else:\n",
    "                rocs[unc] = [\n",
    "                    roc_auc_score(ood_targets[:, seed_idx], scores[:, seed_idx].numpy())\n",
    "                    for seed_idx in range(n_seeds)\n",
    "                ]\n",
    "\n",
    "        # accumulate results for overall average (columns follow the order of `labels`)\n",
    "        rocs_overall[i, j, :, :] = np.array([rocs[unc] for unc in labels])\n",
    "\n",
    "rocs_overall = np.mean(rocs_overall, axis=0)  # average over subsets\n",
    "\n",
    "# print the table only after everything was computed, so a failure leaves no partial table\n",
    "print(\"\\\\begin{tabular}{c|cccccc|cccc|cccccc}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\"SR & \" + \" & \".join(labels.values()) + \" \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "\n",
    "for j, sr in enumerate(scoring_rules):\n",
    "    # mean and std are taken over seeds of the subset-averaged AUROCs\n",
    "    row = [format_cell(rocs_overall[j, k, :]) for k in range(len(labels))]\n",
    "    print(f\"{sr_labels[sr]} & \" + \" & \".join(row) + \" \\\\\\\\\")\n",
    "\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dd2feeb",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
