{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7263fda6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import itertools\n",
    "import os\n",
    "from typing import Final\n",
    "\n",
    "import colorcet as cc\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from source.utils.metrics import mse, nll, crps\n",
    "from source.constants import RESULTS_PATH_AL_RND, PLOTS_PATH\n",
    "\n",
    "sns.set_theme()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a3ad87f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed36f6cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX legend label for every acquisition function, keyed by the identifier\n",
    "# used in the result directory names. Insertion order defines the order of\n",
    "# `acquisition_functions` and of the palette colours.\n",
    "labels = dict([\n",
    "    (\"random\", \"Random\"),\n",
    "    # total risk estimates\n",
    "    (\"total_1_1\", \"$\\\\hat{R}_{Tot}^{1,1}$\"),\n",
    "    (\"total_2_1\", \"$\\\\hat{R}_{Tot}^{2,1}$\"),\n",
    "    (\"total_3a_1\", \"$\\\\hat{R}_{Tot}^{3a,1}$\"),\n",
    "    (\"total_3b_1\", \"$\\\\hat{R}_{Tot}^{3b,1}$\"),\n",
    "    (\"total_3a_2\", \"$\\\\hat{R}_{Tot}^{3a,2}$\"),\n",
    "    (\"total_3b_2\", \"$\\\\hat{R}_{Tot}^{3b,2}$\"),\n",
    "    # Bayes risk estimates\n",
    "    (\"bayes_1\", \"$\\\\hat{R}_{Bayes}^1$\"),\n",
    "    (\"bayes_2\", \"$\\\\hat{R}_{Bayes}^2$\"),\n",
    "    (\"bayes_3a\", \"$\\\\hat{R}_{Bayes}^{3a}$\"),\n",
    "    (\"bayes_3b\", \"$\\\\hat{R}_{Bayes}^{3b}$\"),\n",
    "    # excess risk estimates\n",
    "    (\"excess_1_1\", \"$\\\\hat{R}_{Exc}^{1,1}$\"),\n",
    "    (\"excess_2_1\", \"$\\\\hat{R}_{Exc}^{2,1}$\"),\n",
    "    (\"excess_3a_1\", \"$\\\\hat{R}_{Exc}^{3a,1}$\"),\n",
    "    (\"excess_3b_1\", \"$\\\\hat{R}_{Exc}^{3b,1}$\"),\n",
    "    (\"excess_3a_2\", \"$\\\\hat{R}_{Exc}^{3a,2}$\"),\n",
    "    (\"excess_3b_2\", \"$\\\\hat{R}_{Exc}^{3b,2}$\"),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ef6a4bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = [\"ensemble\"]\n",
    "method_names = [\"Deep Ensemble\"]\n",
    "selection = 0  # index into `methods` / `method_names`\n",
    "method = methods[selection]\n",
    "method_name = method_names[selection]\n",
    "\n",
    "# Seeds of the active-learning runs to aggregate (a previous configuration of\n",
    "# this cell used seeds 142-146).\n",
    "seeds = range(1, 11)\n",
    "\n",
    "# Acquisition-function identifiers, in the insertion order of `labels`.\n",
    "acquisition_functions = list(labels)\n",
    "linestyles = [\"-\"] * len(acquisition_functions)\n",
    "# Black for the random baseline, then Blues / Oranges / Greens shades for the\n",
    "# 6 total, 4 Bayes and 6 excess risk estimates. NOTE: the plots below build\n",
    "# their own colorcet palette; `colors` and `linestyles` are currently unused.\n",
    "colors = [\"black\"]\n",
    "colors += sns.color_palette(\"Blues\", n_colors=7)[1:]\n",
    "colors += sns.color_palette(\"Oranges\", n_colors=5)[1:]\n",
    "colors += sns.color_palette(\"Greens\", n_colors=7)[1:]\n",
    "\n",
    "datasets = [\"ymsd\", \"sgemm\", \"ccpp\", \"casp\", \"news\", \"blog\"]\n",
    "dataset = datasets[0]\n",
    "# Other scoring rules (log, quadratic) exist, but only crps and mse are loaded\n",
    "# in this analysis.\n",
    "scoring_rules = [\"crps\", \"mse\"]\n",
    "scoring_rule = scoring_rules[1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4907400",
   "metadata": {},
   "outputs": [],
   "source": [
    "#os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e422dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-iteration average MSE / NLL of every\n",
    "# (seed, dataset, scoring rule, acquisition function) run into one long-format frame.\n",
    "# NOTE(review): RESULTS_PATH_AL_RND is imported but unused; confirm which results\n",
    "# directory is intended.\n",
    "RESULTS_DIR = \"./results_al_rnd3\"\n",
    "N_ITERATIONS = 30  # active-learning iterations to read per run\n",
    "\n",
    "records = []\n",
    "missing_runs = []  # configurations without any matching result directory\n",
    "\n",
    "configs = list(itertools.product(seeds, datasets, scoring_rules, acquisition_functions))\n",
    "for run_seed, run_dataset, run_rule, af in tqdm(configs, desc=\"loading runs\"):\n",
    "    # Random-baseline runs are looked up under the \"crps\" scoring rule only, but\n",
    "    # are recorded under every scoring rule so they show up in each facet.\n",
    "    sr = \"crps\" if af == \"random\" else run_rule\n",
    "    pattern = os.path.join(RESULTS_DIR, f\"{run_dataset}_{method}_{sr}_{af}_seed{run_seed}_*\")\n",
    "    run_dirs = sorted(glob.glob(pattern))\n",
    "    if not run_dirs:\n",
    "        # Previously this crashed the whole cell with an IndexError.\n",
    "        missing_runs.append((run_seed, run_dataset, sr, af))\n",
    "        continue\n",
    "    # Lexicographically last match (presumably the most recent run - TODO confirm).\n",
    "    run_dir = run_dirs[-1]\n",
    "\n",
    "    for i in range(N_ITERATIONS):\n",
    "        try:\n",
    "            mse_i = np.load(os.path.join(run_dir, f\"avg_mse_{i}.npy\")).item()\n",
    "            nll_i = np.load(os.path.join(run_dir, f\"avg_nll_{i}.npy\")).item()\n",
    "        except (OSError, EOFError, ValueError):\n",
    "            # Missing or unreadable file: iteration i was not (fully) computed,\n",
    "            # so stop reading this run.\n",
    "            break\n",
    "        records.append(dict(seed=run_seed,\n",
    "                            dataset=run_dataset,\n",
    "                            scoring_rule=run_rule,\n",
    "                            acquisition_function=af,\n",
    "                            iteration=i,\n",
    "                            mse=mse_i,\n",
    "                            nll=nll_i))\n",
    "\n",
    "if missing_runs:\n",
    "    print(f\"{len(missing_runs)} configurations have no result directory (see `missing_runs`).\")\n",
    "df = pd.DataFrame(records)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0c0af49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: column dtypes of the collected results.\n",
    "df.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdb8abaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First rows of the long-format results frame (one row per run and iteration).\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07a42bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frame used only for plotting: adds a legend label per acquisition function\n",
    "# and a flag separating the random baseline from the risk estimates, without\n",
    "# mutating `df`.\n",
    "df_plot = df.assign(**{\n",
    "    \"Is random?\": df[\"acquisition_function\"] == \"random\",\n",
    "    \"Risk estimate\": df[\"acquisition_function\"].replace(labels),\n",
    "})\n",
    "\n",
    "# One glasbey colour per risk-estimate label; the random baseline is black.\n",
    "palette = dict(zip(labels.values(),\n",
    "                   sns.color_palette(cc.glasbey, n_colors=len(labels))))\n",
    "palette[\"Random\"] = \"k\"\n",
    "\n",
    "\n",
    "def plot_median_curves(data, metric, filename, palette):\n",
    "    \"\"\"Plot the per-iteration median of ``metric`` per risk estimate and save it.\n",
    "\n",
    "    Draws one facet per (dataset, scoring rule) and writes the figure to\n",
    "    ``os.path.join(PLOTS_PATH, filename)`` at 300 dpi.\n",
    "\n",
    "    Args:\n",
    "        data: long-format frame with ``iteration``, ``dataset``,\n",
    "            ``scoring_rule``, ``Risk estimate``, ``Is random?`` and\n",
    "            ``metric`` columns.\n",
    "        metric: column to plot on the y axis, e.g. ``\"mse\"`` or ``\"nll\"``.\n",
    "        filename: file name of the saved figure inside ``PLOTS_PATH``.\n",
    "        palette: mapping from ``Risk estimate`` label to line colour.\n",
    "\n",
    "    Returns:\n",
    "        The seaborn ``FacetGrid`` holding the figure.\n",
    "    \"\"\"\n",
    "    grid = sns.relplot(data, x=\"iteration\", y=metric, kind=\"line\",\n",
    "                       row=\"dataset\", col=\"scoring_rule\", hue=\"Risk estimate\",\n",
    "                       palette=palette, style=\"Is random?\", alpha=0.3,\n",
    "                       errorbar=None, estimator=\"median\",\n",
    "                       facet_kws={\"sharey\": False, \"sharex\": True})\n",
    "    # seaborn gives each level of \"Is random?\" its own line style; the dashed\n",
    "    # lines are taken to be the random baseline and are redrawn opaque black so\n",
    "    # they stand out from the alpha=0.3 risk-estimate curves.\n",
    "    for ax in grid.figure.axes:\n",
    "        for line in ax.lines:\n",
    "            if line.get_linestyle() == \"--\":\n",
    "                line.set_alpha(1.0)\n",
    "                line.set_color(\"k\")\n",
    "    grid.figure.savefig(os.path.join(PLOTS_PATH, filename), dpi=300)\n",
    "    return grid\n",
    "\n",
    "\n",
    "rel_mse = plot_median_curves(df_plot, \"mse\", \"active_learning_norm_mse_median.png\", palette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f85317",
   "metadata": {},
   "outputs": [],
   "source": [
    "rel_nll = plot_median_curves(df_plot, \"nll\", \"active_learning_norm_nll_median.png\", palette)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c045e0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "reg_uncertainty",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
