{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7457490b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from itertools import product\n",
    "\n",
    "import torch\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from source.constants import RESULTS_PATH_AL\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad2f009b",
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELS = {\n",
    "    \"total_1_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{1,1}$\",\n",
    "    \"total_2_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{2,1}$\",\n",
    "    \"total_3a_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,1}$\",\n",
    "    \"total_3b_1\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,1}$\",\n",
    "    \"total_3a_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3a,2}$\",\n",
    "    \"total_3b_2\": \"$\\\\hat{R}_{\\\\text{Tot}}^{3b,2}$\",\n",
    "    \"bayes_1\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{1}$\",\n",
    "    \"bayes_2\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{2}$\",\n",
    "    \"bayes_3a\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3a}$\",\n",
    "    \"bayes_3b\": \"$\\\\hat{R}_{\\\\text{Bayes}}^{3b}$\",\n",
    "    \"excess_1_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{1,1}$\",\n",
    "    \"excess_2_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{2,1}$\",\n",
    "    \"excess_3a_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,1}$\",\n",
    "    \"excess_3b_1\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,1}$\",\n",
    "    \"excess_3a_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3a,2}$\",\n",
    "    \"excess_3b_2\": \"$\\\\hat{R}_{\\\\text{Exc}}^{3b,2}$\",\n",
    "    \"random\": \"Random\",\n",
    "}\n",
    "\n",
    "sr_labels = {\n",
    "            \"crps\": \"CRPS\",\n",
    "            \"log\": \"Log\",\n",
    "            \"mse\": \"SE\",\n",
    "            \"quadratic\": \"Quad.\",\n",
    "}\n",
    "\n",
    "ACQ_ORDER = list(LABELS.keys())\n",
    "\n",
    "scoring_rules = [\"crps\", \"log\", \"mse\", \"quadratic\"]\n",
    "seeds = [42, 142, 242, 342, 442]\n",
    "datasets = ['ymsd','sgemm', 'ccpp', 'casp','news', 'blog']\n",
    "method = \"ensemble\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3a7a9d77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{c|cccccc|cccc|cccccc|c}\n",
      "\\toprule\n",
      " SR & $\\hat{R}_{\\text{Tot}}^{1,1}$ & $\\hat{R}_{\\text{Tot}}^{2,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,1}$ & $\\hat{R}_{\\text{Tot}}^{3b,1}$ & $\\hat{R}_{\\text{Tot}}^{3a,2}$ & $\\hat{R}_{\\text{Tot}}^{3b,2}$ & $\\hat{R}_{\\text{Bayes}}^{1}$ & $\\hat{R}_{\\text{Bayes}}^{2}$ & $\\hat{R}_{\\text{Bayes}}^{3a}$ & $\\hat{R}_{\\text{Bayes}}^{3b}$ & $\\hat{R}_{\\text{Exc}}^{1,1}$ & $\\hat{R}_{\\text{Exc}}^{2,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,1}$ & $\\hat{R}_{\\text{Exc}}^{3b,1}$ & $\\hat{R}_{\\text{Exc}}^{3a,2}$ & $\\hat{R}_{\\text{Exc}}^{3b,2}$ & Random \\\\\n",
      "\\midrule\n",
      "CRPS & \\valvar{16.60}{3.57} & \\valvar{16.60}{3.57} & \\valvar{16.57}{4.31} & \\valvar{17.83}{3.26} & \\valvar{17.60}{4.19} & \\valvar{18.27}{2.75} & \\valvar{19.07}{4.38} & \\valvar{16.60}{2.75} & \\valvar{17.70}{3.64} & \\valvar{19.20}{3.04} & \\valvar{17.20}{3.58} & \\valvar{17.20}{3.58} & \\valvar{15.90}{3.68} & \\valvar{15.07}{4.01} & \\valvar{15.27}{3.00} & \\valvar{15.87}{3.53} & \\valvar{20.50}{5.89} \\\\\n",
      "Log & \\valvar{15.73}{2.63} & \\valvar{15.73}{2.63} & \\valvar{14.97}{2.38} & \\valvar{15.47}{3.60} & - & - & \\valvar{19.13}{2.20} & - & \\valvar{17.97}{3.61} & \\valvar{21.47}{2.55} & \\valvar{13.90}{1.33} & \\valvar{13.90}{1.33} & \\valvar{17.33}{5.97} & \\valvar{14.37}{3.93} & - & - & \\valvar{20.50}{5.89} \\\\\n",
      "MSE & \\valvar{16.90}{4.80} & \\valvar{16.90}{4.80} & \\valvar{16.90}{4.80} & \\valvar{16.87}{5.26} & \\valvar{16.67}{5.03} & \\valvar{16.47}{2.91} & \\valvar{16.47}{2.91} & \\valvar{16.67}{5.03} & \\valvar{16.67}{5.03} & \\valvar{16.47}{2.91} & \\valvar{17.50}{3.67} & \\valvar{17.50}{3.67} & \\valvar{16.00}{5.21} & \\valvar{16.00}{5.21} & - & - & \\valvar{20.50}{5.89} \\\\\n",
      "Quad. & \\valvar{29.23}{2.06} & \\valvar{29.23}{2.06} & \\valvar{32.57}{2.34} & \\valvar{29.47}{5.12} & \\valvar{31.70}{1.93} & \\valvar{31.27}{4.23} & \\valvar{31.07}{3.07} & \\valvar{32.47}{1.10} & \\valvar{30.33}{3.81} & \\valvar{31.83}{4.17} & \\valvar{16.10}{3.37} & \\valvar{16.10}{3.37} & \\valvar{12.60}{2.19} & \\valvar{13.43}{3.93} & \\valvar{17.00}{2.34} & \\valvar{15.90}{2.46} & \\valvar{20.50}{5.89} \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n"
     ]
    }
   ],
   "source": [
    "# Build a LaTeX table with the mean rank of every acquisition function per scoring rule.\n",
    "# Ranks are computed per dataset and seed (the smallest value receives rank 1), averaged\n",
    "# over datasets, and reported as mean and standard deviation across seeds.\n",
    "\n",
    "# print latex table header\n",
    "print(\"\\\\begin{tabular}{c|cccccc|cccc|cccccc|c}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\" SR & \" + \" & \".join(LABELS.values()) + \" \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "\n",
    "perfs_dir = os.path.join(RESULTS_PATH_AL, \"PERFS\")\n",
    "seed_tag = \"_\".join(map(str, seeds))  # result file names end with all seeds joined by \"_\"\n",
    "\n",
    "all_ranks = []\n",
    "for dataset in datasets:\n",
    "    all_perfs = []\n",
    "    for sr in scoring_rules:\n",
    "        run_tag = f\"{dataset}_{method}_{sr}_{seed_tag}\"\n",
    "        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted result files.\n",
    "        perfs = torch.load(os.path.join(perfs_dir, f\"nll_{run_tag}.pt\"), weights_only=False)\n",
    "        acqs = torch.load(os.path.join(perfs_dir, f\"acqs_{run_tag}.pt\"), weights_only=False)\n",
    "\n",
    "        for idx in range(len(acqs)):\n",
    "            if not perfs[idx].mean(dim=(0, 1)).isfinite():\n",
    "                # turn infinite entries into NaN so that nanmean below ignores them\n",
    "                perfs[idx][perfs[idx].isinf()] = torch.nan\n",
    "\n",
    "        # align to the canonical column order; acquisition functions without results become all-NaN\n",
    "        perfs_full = torch.stack(\n",
    "            [\n",
    "                perfs[acqs.index(lab)] if lab in acqs else torch.full_like(perfs[0], torch.nan)\n",
    "                for lab in ACQ_ORDER\n",
    "            ],\n",
    "            dim=0,\n",
    "        )  # (n_acq, n_seeds, n_iterations)\n",
    "\n",
    "        all_perfs.append(perfs_full.nanmean(dim=-1))  # average over acquisition iterations\n",
    "    all_perfs = torch.stack(all_perfs, dim=0)  # (n_scoring_rules, n_acq, n_seeds)\n",
    "\n",
    "    # compute dense ranks jointly over all (scoring rule, acquisition) cells of each seed:\n",
    "    # ties share a rank, consecutive distinct values differ by one, NaN cells stay NaN\n",
    "    ranks = torch.full_like(all_perfs, torch.nan)\n",
    "    for s in range(all_perfs.shape[-1]):\n",
    "        seed_perfs = all_perfs[..., s]\n",
    "        valid = ~seed_perfs.isnan()\n",
    "        if not valid.any():\n",
    "            continue  # nothing to rank for this seed; leave its ranks as NaN\n",
    "        # torch.unique returns the sorted distinct values, so the inverse index of an\n",
    "        # element is its zero-based dense rank\n",
    "        _, dense_rank = torch.unique(seed_perfs[valid], sorted=True, return_inverse=True)\n",
    "        seed_ranks = torch.full_like(seed_perfs, torch.nan)\n",
    "        seed_ranks[valid] = (dense_rank + 1).to(seed_ranks.dtype)\n",
    "        ranks[..., s] = seed_ranks\n",
    "    all_ranks.append(ranks)\n",
    "\n",
    "all_ranks = torch.stack(all_ranks, dim=0)  # (n_datasets, n_scoring_rules, n_acq, n_seeds)\n",
    "ranks_per_seed = all_ranks.mean(dim=0)  # average over datasets; NaN if missing for any dataset\n",
    "std_ranks = ranks_per_seed.std(dim=-1)  # spread across seeds\n",
    "mean_ranks = ranks_per_seed.mean(dim=-1)\n",
    "\n",
    "for i, sr in enumerate(scoring_rules):\n",
    "    cells = [\n",
    "        \"-\"\n",
    "        if mean_ranks[i, j].isnan()\n",
    "        else f\"\\\\valvar{{{mean_ranks[i, j]:.2f}}}{{{std_ranks[i, j]:.2f}}}\"\n",
    "        for j in range(len(ACQ_ORDER))\n",
    "    ]\n",
    "    print(sr_labels[sr] + \" & \" + \" & \".join(cells) + \" \\\\\\\\\")\n",
    "\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
