{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "\n",
    "# Load the per-acquisition metric traces (a dill pickle) that every later cell reads.\n",
    "# NOTE: dill.load can execute arbitrary code -- only load result files you trust.\n",
    "results_path = \"empirical_results/nats_sss_known_cost_metrics_per_acq_updated.pkl\"\n",
    "with open(results_path, \"rb\") as fh:\n",
    "    metrics_per_acq = dill.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Use Times New Roman as the serif face for every figure drawn after this cell.\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': 'Times New Roman',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference final test accuracy (%) per dataset; later cells turn it into the regret\n",
    "# baseline best_error = 100 - accuracy. Presumably the benchmark optimum\n",
    "# (cf. bench.query_best_final) -- TODO confirm.\n",
    "best_acc_per_dataset = dict([\n",
    "    ('cifar10-valid', 90.5),\n",
    "    ('cifar100', 71.34),\n",
    "    ('ImageNet16-120', 47.4),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# === Inputs (defined in earlier cells) ===\n",
    "# metrics_per_acq[d][acq_key][metric][seed] is a per-iteration sequence. The rules below read\n",
    "# \"current best observed\", \"<acq_key> acq\", \"LogEIPC acq\", \"exp min regret gap\",\n",
    "# \"regret upper bound\" and \"PRB\"; the plotting cells read \"final test error\",\n",
    "# \"estimated cumulative cost\" and \"cumulative cost\".\n",
    "\n",
    "# === Config ===\n",
    "dataset_names = ['cifar10-valid', 'cifar100', 'ImageNet16-120']\n",
    "lambdas       = [1e-4, 1e-5, 1e-6]\n",
    "acq_order     = ['LogEIPC', 'PBGI', 'LCB', 'TS']\n",
    "init          = 0  # earliest iteration at which a rule may fire (read when a rule is evaluated)\n",
    "lookback_window = 5  # look-back, in iterations, of the GSS and Convergence rules\n",
    "\n",
    "\n",
    "cost_limit_per_dataset = {\n",
    "    \"cifar10-valid\":    240000,\n",
    "    \"cifar100\":         400000,\n",
    "    \"ImageNet16-120\":   400000\n",
    "}\n",
    "\n",
    "# === Color and marker settings ===\n",
    "color_dict = {\n",
    "    'LogEIPC':      'tab:blue',\n",
    "    'LogEIPC-med':  'tab:blue',\n",
    "    'PBGI(1e-4)':  'tab:orange',\n",
    "    'PBGI(1e-5)':  'tab:orange',\n",
    "    'PBGI(1e-6)':  'tab:orange',\n",
    "    'LCB':         'tab:purple',\n",
    "    'UCB-LCB':     'tab:purple',\n",
    "    'SRGap-med':   'tab:pink',\n",
    "    'TS':          'tab:brown',\n",
    "    'PRB':         'tab:brown',\n",
    "    'GSS':         'tab:olive',\n",
    "    'Convergence': 'tab:gray',\n",
    "    'Immediate':   'tab:cyan',\n",
    "    'Hindsight':   'tab:red'\n",
    "}\n",
    "\n",
    "# === Rule predicates ===\n",
    "# A condition_fn(i, seed, d) answers \"stop at iteration i?\" for one seed of the run\n",
    "# metrics_per_acq[d][fa]. Rules that look back k iterations never fire before i = k:\n",
    "# a negative index would silently wrap around to the END of the sequence and let the\n",
    "# rule peek at future iterations.\n",
    "\n",
    "def _gss_stop(i, seed, d, fa, lookback=lookback_window):\n",
    "    \"\"\"GSS rule: stop when the best observed value stagnates relative to its spread.\n",
    "\n",
    "    Fires when the interquartile range (IQR) of the 'current best observed' values up to\n",
    "    iteration i is zero, or when best[i - lookback] - best[i] is at most 1% of that IQR.\n",
    "    Never fires before iteration max(init, lookback).\n",
    "    \"\"\"\n",
    "    if i < max(init, lookback):\n",
    "        return False\n",
    "    best = metrics_per_acq[d][fa][\"current best observed\"][seed]\n",
    "    history = best[:i + 1]\n",
    "    iqr = np.nanpercentile(history, 75) - np.nanpercentile(history, 25)\n",
    "    if iqr == 0:\n",
    "        return True\n",
    "    return (best[i - lookback] - best[i]) / iqr <= 0.01\n",
    "\n",
    "\n",
    "def _convergence_stop(i, seed, d, fa, lookback=lookback_window):\n",
    "    \"\"\"Convergence rule: stop when the best observed value equals its value `lookback` iterations ago.\n",
    "\n",
    "    Never fires before iteration max(init, lookback).\n",
    "    \"\"\"\n",
    "    if i < max(init, lookback):\n",
    "        return False\n",
    "    best = metrics_per_acq[d][fa][\"current best observed\"][seed]\n",
    "    return best[i] == best[i - lookback]\n",
    "\n",
    "\n",
    "# === Build stopping_rules for each lambda ===\n",
    "stopping_rules = []\n",
    "for lam in lambdas:\n",
    "    lam_str   = f\"1e-{int(round(-np.log10(lam)))}\"  # \"1e-4\", \"1e-5\", \"1e-6\"\n",
    "    fixed_acq = f\"PBGI({lam_str})\"  # the rules below read their statistics from this PBGI run\n",
    "\n",
    "    templates = [\n",
    "        {\n",
    "            # PBGI: stop once the PBGI acquisition value at iteration i reaches the\n",
    "            # 'current best observed' value recorded at i - 1, so it cannot fire before i = 1.\n",
    "            'stp_key':      'PBGI',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: (\n",
    "                i >= max(init, 1) and\n",
    "                metrics_per_acq[d][fa][f\"{fa} acq\"][seed][i] >=\n",
    "                metrics_per_acq[d][fa][\"current best observed\"][seed][i - 1]\n",
    "            ))\n",
    "        },\n",
    "        # Alternative (disabled): fixed LogEIPC threshold at log(lambda).\n",
    "        # {\n",
    "        #     'stp_key':      'LogEIPC',\n",
    "        #     'is_hindsight': False,\n",
    "        #     'condition_fn': (lambda i, seed, d, lam=lam, fa=fixed_acq: (\n",
    "        #         i >= init and\n",
    "        #         metrics_per_acq[d][fa][\"LogEIPC acq\"][seed][i] <= np.log(lam)\n",
    "        #     ))\n",
    "        # },\n",
    "        {\n",
    "            # LogEIPC-med: stop once the 'LogEIPC acq' value falls to log(0.01) plus its\n",
    "            # median over indices 1-20 of this seed.\n",
    "            # NOTE(review): for i < 20 that median includes later iterations of the same seed.\n",
    "            'stp_key':      'LogEIPC-med',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: (\n",
    "                i >= init and\n",
    "                metrics_per_acq[d][fa][\"LogEIPC acq\"][seed][i] <= (\n",
    "                    np.log(0.01) +\n",
    "                    np.nanmedian(metrics_per_acq[d][fa][\"LogEIPC acq\"][seed][1:21])\n",
    "                )\n",
    "            ))\n",
    "        },\n",
    "        {\n",
    "            # SRGap-med: stop once the 'exp min regret gap' falls to 1% of its median over\n",
    "            # indices 1-20 of this seed (same look-ahead caveat as LogEIPC-med).\n",
    "            'stp_key':      'SRGap-med',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: (\n",
    "                i >= init and\n",
    "                metrics_per_acq[d][fa][\"exp min regret gap\"][seed][i] <=\n",
    "                0.01 * np.nanmedian(\n",
    "                    metrics_per_acq[d][fa][\"exp min regret gap\"][seed][1:21]\n",
    "                )\n",
    "            ))\n",
    "        },\n",
    "        {\n",
    "            # UCB-LCB: stop once the 'regret upper bound' is at most 0.01.\n",
    "            'stp_key':      'UCB-LCB',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: (\n",
    "                i >= init and\n",
    "                metrics_per_acq[d][fa][\"regret upper bound\"][seed][i] <= 0.01\n",
    "            ))\n",
    "        },\n",
    "        {\n",
    "            'stp_key':      'GSS',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: _gss_stop(i, seed, d, fa))\n",
    "        },\n",
    "        {\n",
    "            'stp_key':      'Convergence',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: _convergence_stop(i, seed, d, fa))\n",
    "        },\n",
    "        {\n",
    "            # PRB: stop once 'PRB' is at least 0.95.\n",
    "            'stp_key':      'PRB',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: (\n",
    "                i >= init and\n",
    "                metrics_per_acq[d][fa][\"PRB\"][seed][i] >= 0.95\n",
    "            ))\n",
    "        },\n",
    "        {\n",
    "            # Immediate: always true, so it stops at the first evaluated iteration.\n",
    "            'stp_key':      'Immediate',\n",
    "            'is_hindsight': False,\n",
    "            'condition_fn': (lambda i, seed, d, fa=fixed_acq: (\n",
    "                i >= 0\n",
    "            ))\n",
    "        },\n",
    "        {\n",
    "            # Hindsight: no condition; the plotting cells pick the iteration with the\n",
    "            # lowest cost-adjusted regret.\n",
    "            'stp_key':      'Hindsight',\n",
    "            'is_hindsight': True,\n",
    "            'condition_fn': None\n",
    "        }\n",
    "    ]\n",
    "\n",
    "    for temp in templates:\n",
    "        rule = {\n",
    "            'acq_key':      fixed_acq,\n",
    "            'stp_key':      temp['stp_key'],\n",
    "            'is_hindsight': temp['is_hindsight'],\n",
    "            'marker':       'x' if temp['stp_key'] == 'Hindsight' else 's',\n",
    "            # stp_keys without their own color (e.g. 'PBGI') use the PBGI run's color\n",
    "            'color':        color_dict.get(temp['stp_key'], color_dict[fixed_acq]),\n",
    "            'label':        'PBGI/LogEIPC' if temp['stp_key'] == 'PBGI' else temp['stp_key']\n",
    "        }\n",
    "        if temp['condition_fn'] is not None:\n",
    "            rule['condition_fn'] = temp['condition_fn']\n",
    "        stopping_rules.append(rule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from math import ceil\n",
    "\n",
    "# --- Plot style ---\n",
    "sns.set_style('whitegrid', {\n",
    "    'grid.linestyle': '--',\n",
    "    'grid.alpha': 0.4\n",
    "})\n",
    "\n",
    "plt.style.use('seaborn-v0_8-bright')\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = 'Times New Roman'\n",
    "plt.rcParams.update({\n",
    "    'font.size': 12,\n",
    "    'axes.titlesize': 14,\n",
    "    'axes.labelsize': 12,\n",
    "    'legend.fontsize': 10,\n",
    "    'xtick.labelsize': 10,\n",
    "    'ytick.labelsize': 10,\n",
    "    'figure.autolayout': False,  # tight_layout() is called explicitly below\n",
    "})\n",
    "\n",
    "\n",
    "def style_spines(ax, color='black', linewidth=1):\n",
    "    \"\"\"Show all four spines of `ax` in a uniform color and line width.\"\"\"\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "        spine.set_edgecolor(color)\n",
    "        spine.set_linewidth(linewidth)\n",
    "\n",
    "# --- Config ---\n",
    "# Cost-mismatch study: the same stopping rules scored with the proxy runtime\n",
    "# (\"estimated cumulative cost\", top row) and the actual runtime (\"cumulative cost\", bottom row).\n",
    "dataset_names = ['cifar10-valid', 'cifar100', 'ImageNet16-120']\n",
    "lam = 1e-5\n",
    "lam_str = f\"1e-{int(round(-np.log10(lam)))}\"  # derived from lam so the two cannot disagree\n",
    "pbgi_key = f\"PBGI({lam_str})\"  # PBGI run and rule set that belong to this lambda\n",
    "acq_order = ['LogEIPC', 'PBGI', 'LCB', 'TS']\n",
    "init = 0\n",
    "n_rows, n_cols = 2, len(dataset_names)\n",
    "cost_keys = [\"estimated cumulative cost\", \"cumulative cost\"]\n",
    "runtime_labels = [\"w/ proxy runtime\", \"w/ actual runtime\"]\n",
    "\n",
    "# Rules built for this lambda (the lookup is loop-invariant, so do it once).\n",
    "rules_for_lam = [rule for rule in stopping_rules if rule['acq_key'] == pbgi_key]\n",
    "has_immediate = any(rule['stp_key'] == 'Immediate' for rule in rules_for_lam)\n",
    "\n",
    "# --- Create figure ---\n",
    "# squeeze=False keeps `axes` 2-D for any number of rows/columns.\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 4 * n_rows), sharey=True, squeeze=False)\n",
    "\n",
    "for row, (cost_key, runtime_label) in enumerate(zip(cost_keys, runtime_labels)):\n",
    "    for col, d in enumerate(dataset_names):\n",
    "        ax = axes[row][col]\n",
    "        best_acc = best_acc_per_dataset[d]\n",
    "        best_error = 100. - best_acc\n",
    "\n",
    "        # Immediate stops at iteration 0; its cost-adjusted regret is pooled over every\n",
    "        # seed of every acquisition function and drawn as one reference band below.\n",
    "        immediate_vals = []\n",
    "        if has_immediate:\n",
    "            for acq in acq_order:\n",
    "                fixed_acq = pbgi_key if acq == 'PBGI' else acq\n",
    "                runs = metrics_per_acq[d][fixed_acq]\n",
    "                for seed in range(len(runs[cost_key])):\n",
    "                    errs = runs[\"final test error\"][seed]\n",
    "                    costs = runs[cost_key][seed]\n",
    "                    immediate_vals.append((errs[0] - best_error) + lam * costs[0])\n",
    "\n",
    "        for j, acq in enumerate(acq_order):\n",
    "            fixed_acq = pbgi_key if acq == 'PBGI' else acq\n",
    "            runs = metrics_per_acq[d][fixed_acq]\n",
    "\n",
    "            for rule in rules_for_lam:\n",
    "                if rule['stp_key'] == 'Immediate':\n",
    "                    continue  # plotted separately as a horizontal band\n",
    "\n",
    "                stop_vals = []\n",
    "                for seed in range(len(runs[cost_key])):\n",
    "                    errs = runs[\"final test error\"][seed]\n",
    "                    costs = runs[cost_key][seed]\n",
    "                    if rule['is_hindsight']:\n",
    "                        # Oracle: the iteration with the lowest cost-adjusted regret.\n",
    "                        regs = np.array(errs) - best_error + lam * np.array(costs)\n",
    "                        idx = np.argmin(regs)\n",
    "                    else:\n",
    "                        # First iteration at which the rule fires, else the last iteration.\n",
    "                        # NOTE(review): condition_fn reads metrics_per_acq[d][pbgi_key] (the PBGI\n",
    "                        # run), while errs/costs come from `fixed_acq`'s run -- confirm that this\n",
    "                        # pairing is intended for the non-PBGI columns.\n",
    "                        idx = next((k for k in range(init, len(errs)) if rule['condition_fn'](k, seed, d)),\n",
    "                                   len(errs) - 1)\n",
    "                    stop_vals.append((errs[idx] - best_error) + lam * costs[idx])\n",
    "\n",
    "                mean_val = np.mean(stop_vals)\n",
    "                err_val = np.std(stop_vals) / np.sqrt(len(stop_vals))  # standard error of the mean\n",
    "\n",
    "                alpha = 0.6 if (rule['stp_key'] == 'PBGI' or rule['is_hindsight']) else 0.3\n",
    "                zorder = 3 if rule['stp_key'] == 'PBGI' else 2\n",
    "                marker = \"^\" if rule['is_hindsight'] else rule['marker']\n",
    "                ax.errorbar(j, mean_val, yerr=2 * err_val,  # +/- 2 standard errors\n",
    "                            fmt=marker,\n",
    "                            markersize=8,\n",
    "                            linewidth=2,\n",
    "                            color=rule['color'],\n",
    "                            capsize=0,\n",
    "                            alpha=alpha,\n",
    "                            zorder=zorder,\n",
    "                            label=rule['label'])\n",
    "\n",
    "        # Immediate: dashed mean line plus a +/- 2 standard-error band.\n",
    "        if immediate_vals:\n",
    "            immediate_mean = np.mean(immediate_vals)\n",
    "            immediate_err = np.std(immediate_vals) / np.sqrt(len(immediate_vals))\n",
    "            ax.axhline(y=immediate_mean, xmin=0, xmax=1,\n",
    "                       color='tab:cyan', linestyle='--', linewidth=2, alpha=0.6, zorder=1)\n",
    "            ax.fill_between([0, len(acq_order) - 1],\n",
    "                            [immediate_mean - 2 * immediate_err] * 2,\n",
    "                            [immediate_mean + 2 * immediate_err] * 2,\n",
    "                            color='tab:cyan', alpha=0.2, zorder=0)\n",
    "            if row == 0 and col == 0:  # one proxy artist is enough for the shared legend\n",
    "                ax.plot([], [], color='tab:cyan', linestyle='--', linewidth=2,\n",
    "                        alpha=0.6, label='Immediate')\n",
    "\n",
    "        ax.set_xticks(range(len(acq_order)))\n",
    "        ax.set_xticklabels(acq_order)\n",
    "        ax.grid(True, linestyle='--', alpha=0.4)\n",
    "        style_spines(ax)\n",
    "\n",
    "        if row == 0:\n",
    "            ax.set_title(d if d == \"cifar10-valid\" else d.capitalize())\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(\"Cost-Adjusted Regret\")\n",
    "        if col == n_cols - 1:\n",
    "            ax.text(1.05, 0.5, runtime_label,\n",
    "                    transform=ax.transAxes,\n",
    "                    rotation=-90,\n",
    "                    va='center',\n",
    "                    ha='left',\n",
    "                    fontsize=12,\n",
    "                    color='black')\n",
    "\n",
    "# --- Shared legend ---\n",
    "# Each rule is drawn once per acquisition function, so deduplicate by label; Immediate goes last.\n",
    "handles, labels = axes[0][0].get_legend_handles_labels()\n",
    "legend_hls = dict(zip(labels, handles))\n",
    "if 'Immediate' in legend_hls:\n",
    "    legend_hls['Immediate'] = legend_hls.pop('Immediate')\n",
    "\n",
    "fig.legend(legend_hls.values(), legend_hls.keys(),\n",
    "           loc='lower center',\n",
    "           bbox_to_anchor=(0.5, -0.02),\n",
    "           ncol=ceil(len(legend_hls) / 2),\n",
    "           fontsize=12)\n",
    "\n",
    "fig.tight_layout(rect=[0, 0.05, 1, 0.95])\n",
    "fig.savefig('../plots/BarPlot_NATS_cost_mismatch.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from math import ceil\n",
    "\n",
    "# Set style\n",
    "sns.set_style('whitegrid', {\n",
    "    'grid.linestyle': '--',\n",
    "    'grid.alpha': 0.4\n",
    "})\n",
    "\n",
    "plt.style.use('seaborn-v0_8-bright')\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = 'Times New Roman'\n",
    "plt.rcParams.update({\n",
    "    'font.size': 12,\n",
    "    'axes.titlesize': 14,\n",
    "    'axes.labelsize': 12,\n",
    "    'legend.fontsize': 10,\n",
    "    'xtick.labelsize': 10,\n",
    "    'ytick.labelsize': 10,\n",
    "    'figure.autolayout': False,  # tight_layout() is called explicitly below\n",
    "})\n",
    "\n",
    "\n",
    "def style_spines(ax, color='black', linewidth=1):\n",
    "    \"\"\"Show all four spines of `ax` in a uniform color and line width.\"\"\"\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "        spine.set_edgecolor(color)\n",
    "        spine.set_linewidth(linewidth)\n",
    "\n",
    "# Config\n",
    "# Cost-adjusted regret at each rule's stopping point, scored with the proxy runtime and\n",
    "# restricted to iterations 0..max_iter of every run.\n",
    "dataset_names = ['cifar10-valid', 'cifar100', 'ImageNet16-120']\n",
    "lam = 1e-5\n",
    "lam_str = f\"1e-{int(round(-np.log10(lam)))}\"  # derived from lam so the two cannot disagree\n",
    "pbgi_key = f\"PBGI({lam_str})\"  # PBGI run and rule set that belong to this lambda\n",
    "acq_order = ['LogEIPC', 'PBGI', 'LCB', 'TS']\n",
    "init = 0\n",
    "max_iter = 200  # every run is truncated to its first max_iter + 1 entries\n",
    "cost_key = \"estimated cumulative cost\"\n",
    "n_cols = len(acq_order)\n",
    "\n",
    "# Rules built for this lambda (the lookup is loop-invariant, so do it once).\n",
    "rules_for_lam = [rule for rule in stopping_rules if rule['acq_key'] == pbgi_key]\n",
    "has_immediate = any(rule['stp_key'] == 'Immediate' for rule in rules_for_lam)\n",
    "\n",
    "# Set up plot: one panel per dataset (squeeze=False keeps indexing valid for a single dataset).\n",
    "fig, axes = plt.subplots(1, len(dataset_names), figsize=(3 * len(dataset_names), 4 * 1), sharey=False, squeeze=False)\n",
    "axes = axes[0]\n",
    "\n",
    "for col, d in enumerate(dataset_names):\n",
    "    ax = axes[col]\n",
    "    best_acc = best_acc_per_dataset[d]\n",
    "    best_error = 100. - best_acc\n",
    "\n",
    "    # Immediate stops at iteration 0; its cost-adjusted regret is pooled over every\n",
    "    # seed of every acquisition function and drawn as one reference band below.\n",
    "    immediate_vals = []\n",
    "    if has_immediate:\n",
    "        for acq in acq_order:\n",
    "            fixed_acq = pbgi_key if acq == 'PBGI' else acq\n",
    "            runs = metrics_per_acq[d][fixed_acq]\n",
    "            for seed in range(len(runs[cost_key])):\n",
    "                errs = runs[\"final test error\"][seed][:max_iter + 1]\n",
    "                costs = runs[cost_key][seed][:max_iter + 1]\n",
    "                immediate_vals.append((errs[0] - best_error) + lam * costs[0])\n",
    "\n",
    "    for j, acq in enumerate(acq_order):\n",
    "        fixed_acq = pbgi_key if acq == 'PBGI' else acq\n",
    "        runs = metrics_per_acq[d][fixed_acq]\n",
    "        num_seeds = len(runs[cost_key])\n",
    "\n",
    "        for rule in rules_for_lam:\n",
    "            if rule['stp_key'] == 'Immediate':\n",
    "                continue  # plotted separately as a horizontal band\n",
    "\n",
    "            stop_vals = []\n",
    "            non_stopping_count = 0  # seeds on which the rule never fired\n",
    "            for seed in range(num_seeds):\n",
    "                errs = runs[\"final test error\"][seed][:max_iter + 1]\n",
    "                costs = runs[cost_key][seed][:max_iter + 1]\n",
    "                if rule['is_hindsight']:\n",
    "                    # Oracle: the iteration with the lowest cost-adjusted regret.\n",
    "                    regs = np.array(errs) - best_error + lam * np.array(costs)\n",
    "                    idx = np.argmin(regs)\n",
    "                else:\n",
    "                    # NOTE(review): condition_fn reads metrics_per_acq[d][pbgi_key] (the PBGI\n",
    "                    # run), while errs/costs come from `fixed_acq`'s run -- confirm that this\n",
    "                    # pairing is intended for the non-PBGI columns.\n",
    "                    hit = next((k for k in range(init, len(errs)) if rule['condition_fn'](k, seed, d)), None)\n",
    "                    if hit is None:\n",
    "                        # Never fired: fall back to the last iteration. A rule that fires exactly\n",
    "                        # on the last iteration did stop and is therefore not counted here.\n",
    "                        non_stopping_count += 1\n",
    "                        idx = len(errs) - 1\n",
    "                    else:\n",
    "                        idx = hit\n",
    "                stop_vals.append((errs[idx] - best_error) + lam * costs[idx])\n",
    "\n",
    "            mean_val = np.mean(stop_vals)\n",
    "            err_val = np.std(stop_vals) / np.sqrt(len(stop_vals))  # standard error of the mean\n",
    "\n",
    "            alpha = 0.6 if (rule['stp_key'] == 'PBGI' or rule['is_hindsight']) else 0.3\n",
    "            zorder = 3 if rule['stp_key'] == 'PBGI' else 2\n",
    "            marker = \"^\" if rule['is_hindsight'] else rule['marker']\n",
    "            ax.errorbar(j, mean_val, yerr=2 * err_val,  # +/- 2 standard errors\n",
    "                        fmt=marker,\n",
    "                        markersize=8,\n",
    "                        linewidth=2,\n",
    "                        color=rule['color'],\n",
    "                        capsize=0,\n",
    "                        alpha=alpha,\n",
    "                        zorder=zorder,\n",
    "                        label=rule['label'])\n",
    "\n",
    "            # --- Output non-stopping stats ---\n",
    "            print(f\"[{d}] Acquisition: {acq}, Rule: {rule['label']}, Non-stopping seeds: {non_stopping_count} / {num_seeds}\")\n",
    "\n",
    "    # Immediate: dashed mean line plus a +/- 2 standard-error band.\n",
    "    if immediate_vals:\n",
    "        immediate_mean = np.mean(immediate_vals)\n",
    "        immediate_err = np.std(immediate_vals) / np.sqrt(len(immediate_vals))\n",
    "        ax.axhline(y=immediate_mean, xmin=0, xmax=1,\n",
    "                   color='tab:cyan', linestyle='--', linewidth=2, alpha=0.6, zorder=1)\n",
    "        ax.fill_between([0, len(acq_order) - 1],\n",
    "                        [immediate_mean - 2 * immediate_err] * 2,\n",
    "                        [immediate_mean + 2 * immediate_err] * 2,\n",
    "                        color='tab:cyan', alpha=0.2, zorder=0)\n",
    "        if col == 0:  # one proxy artist is enough for the shared legend\n",
    "            ax.plot([], [], color='tab:cyan', linestyle='--', linewidth=2,\n",
    "                    alpha=0.6, label='Immediate')\n",
    "\n",
    "    ax.set_xticks(range(n_cols))\n",
    "    ax.set_xticklabels(acq_order)\n",
    "    ax.set_title(d if d == \"cifar10-valid\" else d.capitalize())\n",
    "    ax.grid(True, linestyle='--', alpha=0.4)\n",
    "    ax.set_ylabel(\"Cost-Adjusted Regret\" if col == 0 else \"\")\n",
    "    style_spines(ax)\n",
    "\n",
    "# Shared legend: deduplicate by label and put Immediate last.\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "legend_hls = dict(zip(labels, handles))\n",
    "if 'Immediate' in legend_hls:\n",
    "    legend_hls['Immediate'] = legend_hls.pop('Immediate')\n",
    "\n",
    "fig.legend(legend_hls.values(), legend_hls.keys(),\n",
    "           loc='lower center',\n",
    "           bbox_to_anchor=(0.5, -0.15),\n",
    "           ncol=ceil(len(legend_hls) / 2),\n",
    "           fontsize=12)\n",
    "\n",
    "fig.tight_layout(rect=[0, 0.0, 1, 0.95])\n",
    "fig.savefig('../plots/BarPlot_NATS.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import ceil\n",
    "\n",
    "# === Plotting ===\n",
    "# For each dataset, two grids with one row per lambda and one column per acquisition function:\n",
    "#   1) mean test error vs. estimated cumulative cost (displayed only);\n",
    "#   2) mean cost-adjusted regret vs. iteration (displayed and saved to ../plots/{d}.pdf).\n",
    "# Curves show mean +/- 2 standard errors over seeds; each stopping rule is drawn at its mean\n",
    "# stopping point with +/- 1 standard-error bars.\n",
    "# NOTE(review): rule['condition_fn'] reads the PBGI run's metrics even in non-PBGI columns.\n",
    "n_lam, n_acq = len(lambdas), len(acq_order)\n",
    "lambda_labels = [rf'$\\lambda=10^{{-{int(round(-np.log10(lv)))}}}$' for lv in lambdas]\n",
    "\n",
    "for d in dataset_names:\n",
    "    best_acc   = best_acc_per_dataset[d]\n",
    "    best_error = 100. - best_acc\n",
    "\n",
    "    # === 1) Test error vs estimated cumulative cost ===\n",
    "    fig_err, axes = plt.subplots(n_lam, n_acq, figsize=(6 * n_acq, 6 * n_lam), sharey=\"row\", squeeze=False)\n",
    "    fig_err.suptitle(f\"{d}\", fontsize=48)\n",
    "    for i, lam in enumerate(lambdas):\n",
    "        lam_str = f\"1e-{int(round(-np.log10(lam)))}\"\n",
    "\n",
    "        for j, acq in enumerate(acq_order):\n",
    "            fixed_acq = f\"PBGI({lam_str})\" if acq == 'PBGI' else acq\n",
    "            ax = axes[i, j]\n",
    "\n",
    "            # === Mean test error curve ===\n",
    "            # Interpolate each seed onto a shared cost grid before averaging.\n",
    "            cost_grid = np.linspace(0, cost_limit_per_dataset[d], 801)\n",
    "            curves = []\n",
    "            for seed in range(len(metrics_per_acq[d][fixed_acq][\"estimated cumulative cost\"])):\n",
    "                c = np.array(metrics_per_acq[d][fixed_acq][\"estimated cumulative cost\"][seed])\n",
    "                e = np.array(metrics_per_acq[d][fixed_acq][\"final test error\"][seed])\n",
    "                order = np.argsort(c)  # np.interp expects increasing x values\n",
    "                curves.append(np.interp(cost_grid, c[order], e[order]))\n",
    "            curves = np.vstack(curves)\n",
    "            mean_c = curves.mean(axis=0)\n",
    "            sem_c  = curves.std(axis=0) / np.sqrt(curves.shape[0])\n",
    "\n",
    "            ax.plot(cost_grid, mean_c, linestyle='--', color=color_dict[fixed_acq])\n",
    "            ax.fill_between(cost_grid, mean_c - 2 * sem_c, mean_c + 2 * sem_c, color=color_dict[fixed_acq], alpha=0.2)\n",
    "\n",
    "            # === Stopping rule overlays ===\n",
    "            for rule in stopping_rules:\n",
    "                if rule['acq_key'] != f\"PBGI({lam_str})\":\n",
    "                    continue  # only rules built for this lambda\n",
    "                if rule['stp_key'] == 'Immediate':\n",
    "                    continue  # Immediate is not drawn in these plots\n",
    "\n",
    "                stop_vals = []\n",
    "                stop_costs = []\n",
    "                for seed in range(curves.shape[0]):\n",
    "                    errs = metrics_per_acq[d][fixed_acq][\"final test error\"][seed]\n",
    "                    costs = metrics_per_acq[d][fixed_acq][\"estimated cumulative cost\"][seed]\n",
    "                    if rule['is_hindsight']:\n",
    "                        regs = np.array(errs) - best_error + lam * np.array(costs)\n",
    "                        idx = np.argmin(regs)\n",
    "                    else:\n",
    "                        idx = next((k for k in range(init, len(errs)) if rule['condition_fn'](k, seed, d)), len(errs) - 1)\n",
    "                    stop_vals.append(errs[idx])\n",
    "                    stop_costs.append(costs[idx])\n",
    "\n",
    "                mx = np.mean(stop_costs)\n",
    "                my = np.mean(stop_vals)\n",
    "                sx = np.std(stop_costs) / np.sqrt(len(stop_costs))\n",
    "                sy = np.std(stop_vals) / np.sqrt(len(stop_vals))\n",
    "\n",
    "                alpha = 1.0 if rule['stp_key'] in ('PBGI', 'Hindsight') else 0.4\n",
    "                ax.errorbar(mx, my, xerr=sx, yerr=sy,\n",
    "                            fmt=rule['marker'],\n",
    "                            markersize=16,\n",
    "                            linewidth=6,\n",
    "                            color=rule['color'],\n",
    "                            capsize=10,\n",
    "                            alpha=alpha,\n",
    "                            label=rule['label'])\n",
    "\n",
    "            if i == 0:\n",
    "                ax.set_title(acq, fontsize=36)  # acquisition names on the first row only\n",
    "            ax.tick_params(axis='both', which='major', labelsize=28)\n",
    "            if i != n_lam - 1:  # x tick labels on the bottom row only\n",
    "                ax.set_xticklabels([])\n",
    "                ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)\n",
    "            ax.grid(True, linestyle='--', alpha=0.4)\n",
    "\n",
    "    fig_err.subplots_adjust(left=0.08, right=0.95, bottom=0.12, top=0.87, hspace=0.1)\n",
    "\n",
    "    fig_err.text(0.5, 0.035, 'Cumulative Runtime (Proxy)', ha='center', fontsize=36)\n",
    "    fig_err.text(0.015, 0.5, 'Test Error', va='center', rotation='vertical', fontsize=36)\n",
    "\n",
    "    # Two-row legend\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    fig_err.legend(\n",
    "        handles, labels,\n",
    "        loc='upper center',\n",
    "        bbox_to_anchor=(0.5, 0.0),\n",
    "        ncol=ceil(len(handles) / 2),\n",
    "        fontsize=32,\n",
    "        frameon=False\n",
    "    )\n",
    "\n",
    "    # === 2) Regret + lambda * cost vs iteration ===\n",
    "    fig_reg, axes = plt.subplots(n_lam, n_acq, figsize=(6 * n_acq, 6 * n_lam), sharey='row', squeeze=False)\n",
    "    fig_reg.suptitle(f\"{d}\", fontsize=48)\n",
    "    for i, lam in enumerate(lambdas):\n",
    "        lam_str = f\"1e-{int(round(-np.log10(lam)))}\"\n",
    "\n",
    "        for j, acq in enumerate(acq_order):\n",
    "            fixed_acq = f\"PBGI({lam_str})\" if acq == 'PBGI' else acq\n",
    "            ax = axes[i, j]\n",
    "\n",
    "            regs_all = []\n",
    "            for seed in range(len(metrics_per_acq[d][fixed_acq][\"estimated cumulative cost\"])):\n",
    "                errs = np.array(metrics_per_acq[d][fixed_acq][\"final test error\"][seed])\n",
    "                costs = np.array(metrics_per_acq[d][fixed_acq][\"estimated cumulative cost\"][seed])\n",
    "                regs_all.append((errs - best_error) + lam * costs)\n",
    "            regs_all = np.vstack(regs_all)  # shape (n_seeds, n_iterations)\n",
    "            mean_r = regs_all.mean(axis=0)\n",
    "            sem_r  = regs_all.std(axis=0) / np.sqrt(regs_all.shape[0])\n",
    "            iters  = np.arange(mean_r.shape[0])\n",
    "\n",
    "            ax.plot(iters, mean_r, linestyle='--', color=color_dict[fixed_acq])\n",
    "            ax.fill_between(iters, mean_r - 2 * sem_r, mean_r + 2 * sem_r, color=color_dict[fixed_acq], alpha=0.2)\n",
    "\n",
    "            # === Stopping rule overlays ===\n",
    "            for rule in stopping_rules:\n",
    "                if rule['acq_key'] != f\"PBGI({lam_str})\":\n",
    "                    continue  # only rules built for this lambda\n",
    "                if rule['stp_key'] == 'Immediate':\n",
    "                    continue  # Immediate is not drawn in these plots\n",
    "\n",
    "                stop_vals = []\n",
    "                stop_iters = []\n",
    "                for seed in range(regs_all.shape[0]):\n",
    "                    seq = regs_all[seed]\n",
    "                    if rule['is_hindsight']:\n",
    "                        idx = np.argmin(seq)\n",
    "                    else:\n",
    "                        idx = next((k for k in range(init, len(seq)) if rule['condition_fn'](k, seed, d)), len(seq) - 1)\n",
    "                    stop_vals.append(seq[idx])\n",
    "                    stop_iters.append(idx)\n",
    "\n",
    "                mx = np.mean(stop_iters)\n",
    "                my = np.mean(stop_vals)\n",
    "                sx = np.std(stop_iters) / np.sqrt(len(stop_iters))\n",
    "                sy = np.std(stop_vals) / np.sqrt(len(stop_vals))\n",
    "\n",
    "                alpha = 1.0 if rule['stp_key'] in ('PBGI', 'Hindsight') else 0.4\n",
    "                ax.errorbar(mx, my, xerr=sx, yerr=sy,\n",
    "                            fmt=rule['marker'],\n",
    "                            markersize=16,\n",
    "                            linewidth=6,\n",
    "                            color=rule['color'],\n",
    "                            capsize=10,\n",
    "                            alpha=alpha,\n",
    "                            label=rule['label'])\n",
    "\n",
    "            if i == 0:\n",
    "                ax.set_title(acq, fontsize=36)\n",
    "            ax.tick_params(axis='both', which='major', labelsize=28)\n",
    "            if i != n_lam - 1:\n",
    "                ax.set_xticklabels([])\n",
    "                ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)\n",
    "            ax.grid(True, linestyle='--', alpha=0.4)\n",
    "\n",
    "    # Rotated lambda labels on the right, derived from `lambdas`\n",
    "    for i, label in enumerate(lambda_labels):\n",
    "        ax = axes[i, -1]\n",
    "        ax.text(1.05, 0.5, label, transform=ax.transAxes, va='center', ha='left', fontsize=30, rotation=270)\n",
    "\n",
    "    fig_reg.subplots_adjust(left=0.08, right=0.95, bottom=0.12, top=0.87, hspace=0.1)\n",
    "\n",
    "    fig_reg.text(0.5, 0.035, 'Iteration', ha='center', fontsize=36)\n",
    "    fig_reg.text(0.015, 0.5, 'Cost-adjusted Regret', va='center', rotation='vertical', fontsize=36)\n",
    "\n",
    "    # Two-row legend\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    fig_reg.legend(\n",
    "        handles, labels,\n",
    "        loc='upper center',\n",
    "        bbox_to_anchor=(0.5, 0.0),\n",
    "        ncol=ceil(len(handles) / 2),\n",
    "        fontsize=32\n",
    "    )\n",
    "    fig_reg.savefig(f'../plots/{d}.pdf', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "automl_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
