{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e9ae672d",
   "metadata": {},
   "source": [
    "# Query Trigger Strategies (Uncertainty Methods) - Figures\n",
    "\n",
    "Ablation study on different uncertainty/query trigger methods."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81427115",
   "metadata": {},
   "source": [
    "## Single Dataset - All Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a50b55f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --- Parameters to adjust ---\n",
    "dataset  = \"carrot-bowl\"\n",
    "T        = 2000\n",
    "num_runs = 20\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.size': 14,\n",
    "    'axes.titlesize': 16,\n",
    "    'axes.labelsize': 14,\n",
    "    'xtick.labelsize': 12,\n",
    "    'ytick.labelsize': 12,\n",
    "    'legend.fontsize': 12\n",
    "})\n",
    "\n",
    "# --- Chemin du fichier pickle ---\n",
    "data_path = f\"../experiments/results/query_trigger_strategies/data/raw_data_uncertainty_methods_{dataset}_{T}_{num_runs}runs.pkl\"\n",
    "\n",
    "# --- Loading data ---\n",
    "with open(data_path, \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "all_OtB       = data[\"all_o2b\"]\n",
    "all_opr       = data[\"all_opr\"]\n",
    "budgets_accum = data[\"budgets_accum\"]\n",
    "\n",
    "# --- Computing averages ---\n",
    "avg_OtB = {a: np.mean(np.stack(all_OtB[a]), axis=0) for a in all_OtB}\n",
    "avg_opr = {a: np.mean(np.stack(all_opr[a]), axis=0) for a in all_opr}\n",
    "avg_bud = {a: np.mean(np.stack(budgets_accum[a]), axis=0) for a in budgets_accum}\n",
    "\n",
    "# --- Styles de tracé ---\n",
    "styles = {\n",
    "    \"optimal\":    {\"linestyle\": \"-.\", \"color\": \"green\"},\n",
    "    \"always\":     {\"linestyle\": \"--\", \"color\": \"blue\"},\n",
    "    \"random\":     {\"linestyle\": \":\",  \"color\": \"orange\"},\n",
    "    \"Delta\":      {\"linestyle\": \"-\",  \"color\": \"red\"},\n",
    "    \"Warm-start\": {\"linestyle\": \"--\", \"color\": \"purple\"},\n",
    "    \"UCB\":        {\"linestyle\": \"-.\", \"color\": \"brown\"},\n",
    "    \"No AL\":      {\"linestyle\": \"-\",  \"color\": \"black\"},\n",
    "    \"NN-Var\":     {\"linestyle\": \"-\",  \"color\": \"green\"},\n",
    "}\n",
    "\n",
    "# --- Tracé des 4 subplots ---\n",
    "fig, axes = plt.subplots(1, 4, figsize=(24, 6))\n",
    "window = T // 10\n",
    "\n",
    "# 1) Cumulative Regret\n",
    "ax = axes[0]\n",
    "for alg, series in avg_OtB.items():\n",
    "    if alg != \"optimal\":\n",
    "        regret = np.cumsum(avg_OtB[\"optimal\"] - series)\n",
    "        ax.plot(np.arange(1, len(regret)+1), regret, label=alg, **styles[alg])\n",
    "ax.set_title(\"Cumulative Regret\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Regret\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "\n",
    "# 2) Sliding-window Avg OPR\n",
    "ax = axes[1]\n",
    "for alg, series in avg_opr.items():\n",
    "    if len(series) >= window:\n",
    "        mov = np.convolve(series, np.ones(window)/window, mode=\"valid\")\n",
    "        idx = np.linspace(0, len(mov)-1, 100, dtype=int)\n",
    "        ax.plot(np.arange(window, window+len(mov))[idx], mov[idx], label=alg, **styles[alg])\n",
    "ax.set_title(f\"{window}-Sliding Avg OPR\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Avg OPR\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "\n",
    "# 3) Budget Consumption\n",
    "ax = axes[2]\n",
    "for alg, series in avg_bud.items():\n",
    "    ax.plot(np.arange(1, len(series)+1), series, label=alg, **styles[alg])\n",
    "ax.set_title(\"Budget Consumption\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Requêtes GT\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "\n",
    "# 4) Sliding-window Avg OtB\n",
    "ax = axes[3]\n",
    "for alg, series in avg_OtB.items():\n",
    "    if alg != \"optimal\" and len(series) >= window:\n",
    "        mov = np.convolve(series, np.ones(window)/window, mode=\"valid\")\n",
    "        idx = np.linspace(0, len(mov)-1, 100, dtype=int)\n",
    "        ax.plot(np.arange(window, window+len(mov))[idx], mov[idx], label=alg, **styles[alg])\n",
    "ax.set_title(f\"{window}-Sliding Avg OtB\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Avg OtB\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs(\"plots/query_trigger_strategies\", exist_ok=True)\n",
    "plt.savefig(f\"plots/query_trigger_strategies/{dataset}_{T}_{num_runs}runs.pdf\", dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0e26103",
   "metadata": {},
   "source": [
    "## OPR and OtB Side-by-Side (Paper Plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6399bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --- Parameters ---\n",
    "dataset  = \"carrot-bowl\"\n",
    "T        = 2000\n",
    "num_runs = 20\n",
    "\n",
    "# --- Style global ---\n",
    "plt.rcParams.update({\n",
    "    'font.size': 26,\n",
    "    'axes.titlesize': 26,\n",
    "    'axes.labelsize': 24,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 22\n",
    "})\n",
    "\n",
    "# --- Chemin du fichier pickle ---\n",
    "data_path = f\"../experiments/results/query_trigger_strategies/data/raw_data_uncertainty_methods_{dataset}_{T}_{num_runs}runs.pkl\"\n",
    "\n",
    "# --- Loading ---\n",
    "with open(data_path, \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "all_OtB = data[\"all_o2b\"]\n",
    "all_opr = data[\"all_opr\"]\n",
    "\n",
    "# --- Moyennes ---\n",
    "avg_OtB = {a: np.mean(np.stack(all_OtB[a]), axis=0) for a in all_OtB if len(all_OtB[a]) > 0}\n",
    "avg_opr = {a: np.mean(np.stack(all_opr[a]), axis=0) for a in all_opr if len(all_opr[a]) > 0}\n",
    "\n",
    "# --- Styles ---\n",
    "styles = {\n",
    "    \"optimal\":    {\"linestyle\": \"-.\", \"color\": \"green\",  \"linewidth\": 2.0},\n",
    "    \"always\":     {\"linestyle\": \"--\", \"color\": \"blue\",   \"linewidth\": 1.8},\n",
    "    \"random\":     {\"linestyle\": \":\",  \"color\": \"orange\", \"linewidth\": 1.8},\n",
    "    \"Delta\":      {\"linestyle\": \"-\",  \"color\": \"red\",    \"linewidth\": 2.0},\n",
    "    \"Warm-start\": {\"linestyle\": \"--\", \"color\": \"purple\", \"linewidth\": 1.8},\n",
    "    \"UCB\":        {\"linestyle\": \"-.\", \"color\": \"brown\",  \"linewidth\": 1.8},\n",
    "    \"No AL\":      {\"linestyle\": \"-\",  \"color\": \"black\",  \"linewidth\": 1.8},\n",
    "    \"NN-Var\":     {\"linestyle\": \"-\",  \"color\": \"darkgreen\",\"linewidth\": 1.8},\n",
    "}\n",
    "\n",
    "# Noms à afficher\n",
    "display_names = {\n",
    "    \"optimal\": \"Optimal\",\n",
    "    \"always\": \"Always\",\n",
    "    \"random\": \"Random\",\n",
    "    \"Delta\": \"Delta\",\n",
    "    \"Warm-start\": \"Warm-start\",\n",
    "    \"UCB\": \"UCB\",\n",
    "    \"No AL\": \"No AL\",\n",
    "    \"NN-Var\": \"NN-Var\",\n",
    "}\n",
    "\n",
    "def annotate_under_axis(ax, label):\n",
    "    ax.annotate(label, xy=(0.5, -0.18), xycoords='axes fraction',\n",
    "                ha='center', va='top', fontsize=34)\n",
    "\n",
    "# --- Figure ---\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharey=False)\n",
    "window = T // 10\n",
    "\n",
    "legend_handles, legend_labels = [], []\n",
    "\n",
    "# 1) OPR\n",
    "ax = axes[0]\n",
    "idx = None\n",
    "for alg, series in avg_opr.items():\n",
    "    if len(series) >= window:\n",
    "        mov = np.convolve(series, np.ones(window)/window, mode=\"valid\")\n",
    "        if idx is None:\n",
    "            idx = np.linspace(0, len(mov) - 1, 100, dtype=int)\n",
    "        x_vals = np.arange(window, window + len(mov))[idx]\n",
    "        ln, = ax.plot(x_vals, mov[idx], label=display_names.get(alg, alg), **styles.get(alg, {}))\n",
    "        legend_handles.append(ln)\n",
    "        legend_labels.append(display_names.get(alg, alg))\n",
    "\n",
    "ax.set_title(f\"{window}-Sliding Avg OPR\")\n",
    "ax.grid(True)\n",
    "annotate_under_axis(ax, \" \")\n",
    "\n",
    "# 2) OtB\n",
    "ax = axes[1]\n",
    "idx2 = None\n",
    "for alg, series in avg_OtB.items():\n",
    "    if alg != \"optimal\" and len(series) >= window:\n",
    "        mov = np.convolve(series, np.ones(window)/window, mode=\"valid\")\n",
    "        if idx2 is None:\n",
    "            idx2 = np.linspace(0, len(mov) - 1, 100, dtype=int)\n",
    "        x_vals = np.arange(window, window + len(mov))[idx2]\n",
    "        ax.plot(x_vals, mov[idx2], label=display_names.get(alg, alg), **styles.get(alg, {}))\n",
    "\n",
    "ax.set_title(f\"{window}-Sliding Avg OtB\")\n",
    "ax.grid(True)\n",
    "annotate_under_axis(ax, \" \")\n",
    "\n",
    "# Label Y global\n",
    "fig.supylabel(\"Score\", x=0.02, y=0.53, fontsize=24)\n",
    "\n",
    "# Ajustement des marges\n",
    "fig.subplots_adjust(left=0.09, right=0.99, bottom=0.22, top=0.88, wspace=0.08)\n",
    "\n",
    "# Légende\n",
    "fig.legend(\n",
    "    legend_handles, legend_labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, 0.19),\n",
    "    ncol=len(legend_labels),\n",
    "    frameon=False,\n",
    "    handlelength=2.4, columnspacing=1.0, handletextpad=0.6\n",
    ")\n",
    "\n",
    "# Save\n",
    "os.makedirs(\"plots/query_trigger_strategies\", exist_ok=True)\n",
    "out_path = f\"plots/query_trigger_strategies/{dataset}_{T}_{num_runs}runs_OPR_OtB.pdf\"\n",
    "plt.savefig(out_path, dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "print(\"Saved to:\", out_path)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
