{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synthetic experiments in the 8-dimensional setting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, load the saved experiment metrics from the pickle file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "\n",
    "# Which GP kernel / sweep the pickled results belong to.\n",
    "kernel = \"Matern52\"\n",
    "sweep_id = \"1ycj0lxz\"\n",
    "\n",
    "# NOTE: dill (like pickle) can execute arbitrary code while loading; only open trusted result files.\n",
    "save_path = f'synthetic_results/synthetic_8d_{kernel}_{sweep_id}.pkl'\n",
    "with open(save_path, \"rb\") as results_file:\n",
    "    metrics_per_acq = dill.load(results_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper functions for computing statistics (mean and standard error) of stopping times and cost-adjusted simple regret."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def mean_and_std(data, num_samples=None, axis=0):\n",
    "    \"\"\"Return the mean of `data` and its standard error along `axis`.\n",
    "\n",
    "    Despite the name, the second value is the standard error of the mean,\n",
    "    np.std(data, axis=axis) / sqrt(num_samples), where np.std uses ddof=0.\n",
    "\n",
    "    Parameters:\n",
    "    - data: array-like of samples, e.g. one value per seed along `axis`.\n",
    "    - num_samples: divisor for the standard error; defaults to data.shape[axis].\n",
    "    - axis: axis to reduce over.\n",
    "\n",
    "    Returns:\n",
    "    - (mean, standard_error) along `axis`.\n",
    "    \"\"\"\n",
    "    data = np.asarray(data)\n",
    "    if num_samples is None:\n",
    "        num_samples = data.shape[axis]\n",
    "    return np.mean(data, axis=axis), np.std(data, axis=axis) / np.sqrt(num_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import math\n",
    "\n",
    "def polyak_average(arr, window):\n",
    "    \"\"\"Trailing-window (Polyak) average of a 1-D array, skipping index 0.\n",
    "\n",
    "    avg[0] is NaN; for i >= 1, avg[i] is the mean of arr[max(1, i - window + 1) : i + 1].\n",
    "    NaN entries are summed as 0 (np.nancumsum) but still counted in the divisor.\n",
    "    NOTE(review): confirm NaNs only occur at index 0, otherwise they bias the average toward 0.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError if window < 1.\n",
    "    \"\"\"\n",
    "    if window < 1:\n",
    "        raise ValueError(f'window must be >= 1, got {window}')\n",
    "    arr = np.asarray(arr, dtype=float)\n",
    "    avg = np.full(arr.shape, np.nan)\n",
    "    if arr.size < 2:\n",
    "        return avg\n",
    "    cumsum = np.nancumsum(arr)\n",
    "    ends = np.arange(1, arr.size)\n",
    "    starts = np.maximum(1, ends - window + 1)\n",
    "    # Window sum arr[start..end] = cumsum[end] - cumsum[start - 1]; start >= 1, so always valid.\n",
    "    avg[1:] = (cumsum[ends] - cumsum[starts - 1]) / (ends - starts + 1)\n",
    "    return avg\n",
    "\n",
    "def return_stopping_indices(\n",
    "    best_obs_seed: np.ndarray,\n",
    "    LogEIC_seed: np.ndarray,\n",
    "    UCB_LCB_seed: np.ndarray,\n",
    "    PRB_seed: np.ndarray,\n",
    "    RegretGap_seed: np.ndarray,\n",
    "    lmbda: float,\n",
    "    max_iter: int,\n",
    "    ucb_lcb_threshold: float = 0.01,\n",
    "    delta: float = 0.05,\n",
    "    convergence_window: int = 5,\n",
    "    gss_window: int = 5,\n",
    "    gss_frac: float = 0.01,\n",
    "    eta: float = 0.01,\n",
    "    T_init: int = 20,\n",
    "    polyak_window: int = 20,\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Compute the stopping index of each stopping rule for a single seed run.\n",
    "\n",
    "    Parameters:\n",
    "    - best_obs_seed: 1D array of best observed values per iteration.\n",
    "    - LogEIC_seed: 1D array of log(EIC) acquisition values.\n",
    "    - UCB_LCB_seed: 1D array of (UCB - LCB) values.\n",
    "    - PRB_seed: 1D array of PRB values.\n",
    "    - RegretGap_seed: 1D array of Regret-Gap values.\n",
    "    - lmbda: the LogEIC rule fires when the Polyak-averaged EIC drops to <= lmbda.\n",
    "    - max_iter: maximum number of iterations (iteration cutoff).\n",
    "    - ucb_lcb_threshold: the UCB-LCB rule fires when the averaged (UCB - LCB) is <= this value.\n",
    "    - delta: the PRB rule fires when PRB > 1 - delta.\n",
    "    - convergence_window: the Convergence rule fires once best_obs_seed has been constant over\n",
    "      this many consecutive iterations (must be >= 2).\n",
    "    - gss_window: number of past trials to consider for the GSS rule.\n",
    "    - gss_frac: fraction of the IQR used as improvement threshold (e.g., 0.01 for 1%).\n",
    "    - eta: relative threshold of the '-med' rules, as a fraction of the median over\n",
    "      iterations 1..T_init.\n",
    "    - T_init: number of initial iterations (from index 1) used for the '-med' medians.\n",
    "    - polyak_window: trailing window of the Polyak averages; the averaged rules can only\n",
    "      fire at index >= polyak_window.\n",
    "\n",
    "    Returns:\n",
    "    - dict mapping each rule name ('LogEIC', 'UCB-LCB', 'SRGap-med', 'LogEIC-med',\n",
    "      'Convergence', 'GSS', 'PRB_0.1') to the first index at which it fires, or to the\n",
    "      last available index, min(max_iter, len(best_obs_seed)) - 1, if it never fires.\n",
    "    \"\"\"\n",
    "    # Recover EIC from its logarithm so it is averaged on its natural scale.\n",
    "    EIC_seed = np.exp(LogEIC_seed)\n",
    "    valid = {}\n",
    "\n",
    "    # --- Polyak averages of the noisy acquisition signals ---\n",
    "    polyak_avg = {\n",
    "        'LogEIC': polyak_average(EIC_seed, polyak_window),\n",
    "        'UCB': polyak_average(UCB_LCB_seed, polyak_window),\n",
    "        'SRGap-med': polyak_average(RegretGap_seed, polyak_window),\n",
    "    }\n",
    "    log_polyak_eic = np.log(polyak_avg['LogEIC'])\n",
    "\n",
    "    # --- LogEIC: stop when log(averaged EIC) <= log(lambda) ---\n",
    "    valid['LogEIC'] = np.where(log_polyak_eic[polyak_window:] <= np.log(lmbda))[0] + polyak_window\n",
    "\n",
    "    # --- UCB-LCB: stop when averaged (UCB - LCB) <= threshold ---\n",
    "    polyak_ucb = polyak_avg['UCB'][polyak_window:]\n",
    "    valid['UCB-LCB'] = np.where(polyak_ucb <= ucb_lcb_threshold)[0] + polyak_window\n",
    "\n",
    "    # --- SRGap-med: stop when averaged regret gap < eta * median of the initial gaps ---\n",
    "    if len(RegretGap_seed) > T_init:\n",
    "        median_initial = np.nanmedian(RegretGap_seed[1:T_init+1])\n",
    "        threshold = eta * median_initial\n",
    "        polyak_srgap = polyak_avg['SRGap-med'][polyak_window:]\n",
    "        valid['SRGap-med'] = np.where(polyak_srgap < threshold)[0] + polyak_window\n",
    "    else:\n",
    "        valid['SRGap-med'] = np.array([], dtype=int)\n",
    "\n",
    "    # --- LogEIC-med: stop when log(averaged EIC) < log(eta) + median initial log(EIC) ---\n",
    "    # Both sides are in log space, i.e. averaged EIC < eta * (initial EIC level). Adding\n",
    "    # log(eta) to a raw EIC median and comparing with un-logged EIC would mix scales.\n",
    "    if len(LogEIC_seed) > T_init:\n",
    "        median_initial_log = np.nanmedian(LogEIC_seed[1:T_init+1])\n",
    "        threshold = np.log(eta) + median_initial_log\n",
    "        polyak_logeic_med = log_polyak_eic[polyak_window:]\n",
    "        valid['LogEIC-med'] = np.where(polyak_logeic_med < threshold)[0] + polyak_window\n",
    "    else:\n",
    "        valid['LogEIC-med'] = np.array([], dtype=int)\n",
    "\n",
    "    # --- Convergence: best observed unchanged for `convergence_window` iterations ---\n",
    "    valid['Convergence'] = np.array([], dtype=int)\n",
    "    if len(best_obs_seed) > convergence_window:\n",
    "        unchanged = (np.diff(best_obs_seed) == 0).astype(int)\n",
    "        # Sliding sum over convergence_window - 1 consecutive diffs; own name so the\n",
    "        # Polyak window is not shadowed.\n",
    "        run_kernel = np.ones(convergence_window - 1, dtype=int)\n",
    "        run_lengths = np.convolve(unchanged, run_kernel, mode='valid')\n",
    "        conv_candidates = np.where(run_lengths == (convergence_window - 1))[0]\n",
    "        if conv_candidates.size > 0:\n",
    "            # Report the last iteration of the first constant run.\n",
    "            valid['Convergence'] = np.array([conv_candidates[0] + convergence_window - 1])\n",
    "\n",
    "    # --- GSS: stop if improvement over the past `gss_window` iters < gss_frac * current IQR ---\n",
    "    valid['GSS'] = np.array([], dtype=int)\n",
    "    for i in range(gss_window, len(best_obs_seed)):\n",
    "        improvement = best_obs_seed[i] - best_obs_seed[i - gss_window]  # assumes maximization\n",
    "        # IQR of best_obs_seed up to i (inclusive)\n",
    "        q75, q25 = np.percentile(best_obs_seed[:i+1], [75, 25])\n",
    "        iqr = q75 - q25\n",
    "        if iqr <= 0:\n",
    "            continue\n",
    "        if improvement < gss_frac * iqr:\n",
    "            valid['GSS'] = np.array([i])\n",
    "            break\n",
    "\n",
    "    # --- PRB: stop when PRB > 1 - delta (the result key stays 'PRB_0.1') ---\n",
    "    valid['PRB_0.1'] = np.where(np.asarray(PRB_seed) > 1 - delta)[0]\n",
    "\n",
    "    # First firing index per rule. Rules that never fire fall back to the last available\n",
    "    # iteration, clamped to the run length so callers can index their arrays with it.\n",
    "    last_index = min(max_iter, len(best_obs_seed)) - 1\n",
    "    return {\n",
    "        rule: int(inds[0]) if inds.size > 0 else last_index\n",
    "        for rule, inds in valid.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the style and fonts of the plots, then define the acquisition functions, display names, colors, and stopping-rule parameters used below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Seaborn white grid with faint dashed grid lines, then matplotlib's bright color cycle.\n",
    "grid_style = {'grid.linestyle': '--', 'grid.alpha': 0.4}\n",
    "sns.set_style('whitegrid', grid_style)\n",
    "plt.style.use('seaborn-v0_8-bright')\n",
    "\n",
    "# Large serif fonts for the multi-panel regret-curve figures.\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': 'Times New Roman',\n",
    "    'font.size': 18,\n",
    "    'axes.titlesize': 20,\n",
    "    'axes.labelsize': 18,\n",
    "    'legend.fontsize': 18,\n",
    "    'xtick.labelsize': 18,\n",
    "    'ytick.labelsize': 18,\n",
    "    'figure.autolayout': False,  # no automatic layout; figures are saved with bbox_inches='tight'\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import itertools\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "# Acquisition functions whose results are stored in `metrics_per_acq`.\n",
    "acquisition_functions = [\n",
    "    'ThompsonSampling',\n",
    "    'Stable_Gittins_Lambda_1',\n",
    "    'Stable_Gittins_Lambda_01',\n",
    "    'Stable_Gittins_Lambda_001',\n",
    "    'LogEIWithCost',\n",
    "    'UpperConfidenceBound',\n",
    "]\n",
    "# Subplot column / x position of each acquisition. The three PBGI variants share slot 0\n",
    "# because only the variant matching the current lambda is plotted.\n",
    "acq_col_idxs = {\n",
    "    'Stable_Gittins_Lambda_1': 0,\n",
    "    'Stable_Gittins_Lambda_01': 0,\n",
    "    'Stable_Gittins_Lambda_001': 0,\n",
    "    'LogEIWithCost': 1,\n",
    "    'UpperConfidenceBound': 2,\n",
    "    'ThompsonSampling': 3,\n",
    "}\n",
    "\n",
    "# Display names for acquisition functions, stopping rules and cost types.\n",
    "name_dict = {\n",
    "    'Stable_Gittins_Lambda_1': 'PBGI(0.1)',\n",
    "    'Stable_Gittins_Lambda_01': 'PBGI(0.01)',\n",
    "    'Stable_Gittins_Lambda_001': 'PBGI(0.001)',\n",
    "    'LogEIWithCost': 'LogEIC',\n",
    "    'UpperConfidenceBound': 'LCB',\n",
    "    'ThompsonSampling': 'TS',\n",
    "    'PRB_0.1': 'PRB',\n",
    "    'LogEIC': 'PBGI/LogEIC',\n",
    "    'LogEIC-med': 'LogEIC-med',\n",
    "    'UCB-LCB': 'UCB-LCB',\n",
    "    'Convergence': 'Convergence',\n",
    "    'GSS': 'GSS',\n",
    "    'SRGap-med': 'SRGap-med',\n",
    "    'Immediate': 'Immediate',\n",
    "    'uniform': 'Uniform Cost',\n",
    "    'linear': 'Linear Cost',\n",
    "    'periodic': 'Periodic Cost',\n",
    "}\n",
    "# Suffix of the 'Stable_Gittins_Lambda_*' key that corresponds to each lambda.\n",
    "lmbda_to_str = {\n",
    "    0.1: '1',\n",
    "    0.01: '01',\n",
    "    0.001: '001',\n",
    "}\n",
    "\n",
    "color_dict = {\n",
    "    'LogEIC':                    'tab:orange',\n",
    "    'LogEIWithCost':             'tab:blue',\n",
    "    'LogEIC-med':                'tab:blue',\n",
    "    'Stable_Gittins_Lambda_1':   'tab:orange',\n",
    "    'Stable_Gittins_Lambda_01':  'tab:orange',\n",
    "    'Stable_Gittins_Lambda_001': 'tab:orange',\n",
    "    'UpperConfidenceBound':      'tab:purple',\n",
    "    'UCB-LCB':                   'tab:purple',\n",
    "    'SRGap-med':                 'tab:pink',\n",
    "    'ThompsonSampling':          'tab:brown',\n",
    "    'PRB_0.1':                   'tab:brown',\n",
    "    'GSS':                       'tab:olive',\n",
    "    'Convergence':               'tab:gray',\n",
    "    'Immediate':                 'tab:cyan',\n",
    "    'Hindsight':                 'tab:red',\n",
    "}\n",
    "marker_dict = {\n",
    "    0.1: 'o',\n",
    "    0.01: '^',\n",
    "    0.001: 's',\n",
    "}\n",
    "\n",
    "lengthscales = [0.1]\n",
    "cost_type = ['uniform', 'linear', 'periodic']\n",
    "lambda_per_acq = [0.1, 0.01, 0.001]\n",
    "\n",
    "# --- stopping-rule parameters ---\n",
    "delta = 0.05               # PRB: stop when P(success) > 1 - delta (rule key is 'PRB_0.1')\n",
    "ucb_lcb_threshold = 0.01   # UCB-LCB: stop when averaged (UCB - LCB) <= threshold\n",
    "convergence_window = 20    # Convergence: best observed constant over this many iterations\n",
    "gss_window = 20            # GSS: look-back window, in iterations\n",
    "gss_frac = 0.01            # GSS: improvement threshold as a fraction of the IQR\n",
    "\n",
    "# Stopping rules compared in the plots (keys returned by return_stopping_indices).\n",
    "stopping_rules = [\n",
    "    'LogEIC',\n",
    "    'LogEIC-med',\n",
    "    'SRGap-med',\n",
    "    'UCB-LCB',\n",
    "    'PRB_0.1',\n",
    "    'GSS',\n",
    "    'Convergence',\n",
    "]\n",
    "\n",
    "# Iteration cutoff (maximum number of iterations analysed) for each lambda.\n",
    "max_iterations = 400\n",
    "iteration_cutoff = {lmbda: max_iterations for lmbda in lambda_per_acq}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plots: iteration vs. cost-adjusted regret, with the mean stopping point of each stopping rule and the hindsight-optimal stopping point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import MaxNLocator\n",
    "\n",
    "# Rules that have an index from return_stopping_indices. 'Immediate' is filtered out in case\n",
    "# another cell has already added it to the shared `stopping_rules` list.\n",
    "curve_rules = [rule for rule in stopping_rules if rule != 'Immediate']\n",
    "\n",
    "l = 0.1  # lengthscale key into metrics_per_acq\n",
    "n_rows = len(lambda_per_acq)              # one row per lambda\n",
    "n_cols = max(acq_col_idxs.values()) + 1   # one column per acquisition slot (PBGI variants share one)\n",
    "\n",
    "for cost in cost_type:\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(4 * n_cols, 3 * n_rows),  # adjust per-subplot size\n",
    "        sharex=True, sharey=False,\n",
    "        squeeze=False,  # always a 2-D (n_rows, n_cols) array of Axes\n",
    "    )\n",
    "\n",
    "    for row_idx, lmbda in enumerate(lambda_per_acq):\n",
    "        cutoff = iteration_cutoff[lmbda]\n",
    "        max_y = -np.inf\n",
    "        min_y = np.inf\n",
    "        for a in acquisition_functions:\n",
    "            # Only the PBGI variant whose lambda matches this row is shown.\n",
    "            if a.split('_')[0] == 'Stable' and a.split('_')[-1] != lmbda_to_str[lmbda]:\n",
    "                continue\n",
    "\n",
    "            ax = axes[row_idx, acq_col_idxs[a]]\n",
    "            metrics = metrics_per_acq[(a, l, cost)]\n",
    "\n",
    "            regret = np.array(metrics['regret'])[:, :cutoff]\n",
    "            cumulative_cost = np.array(metrics['cumulative cost'])[:, :cutoff]\n",
    "            LogEIC_acq = np.array(metrics['LogEIC acq'])[:, :cutoff]\n",
    "            UCB_LCB_acq = np.array(metrics['UCB-LCB acq'])[:, :cutoff]\n",
    "            PRB_0_1_acq = np.array(metrics['PRB_0.1'])[:, :cutoff]\n",
    "            RegretGap_acq = np.array(metrics['Regret-Gap acq'])[:, :cutoff]\n",
    "            best_observed = np.array(metrics['best observed'])[:, :cutoff]\n",
    "\n",
    "            # Cost-adjusted regret curve, one row per seed.\n",
    "            cost_adjusted_regret = regret + lmbda * cumulative_cost\n",
    "\n",
    "            # Per seed: stopping iteration of every rule and the cost-adjusted regret there.\n",
    "            stopping_times = {rule: [] for rule in curve_rules}\n",
    "            stopping_cost_adjusted_regrets = {rule: [] for rule in curve_rules}\n",
    "            hindsight_best_times = []\n",
    "            hindsight_best_regrets = []\n",
    "\n",
    "            n_seeds = regret.shape[0]\n",
    "            for seed_idx in range(n_seeds):\n",
    "                stopping_indices = return_stopping_indices(\n",
    "                    best_obs_seed=best_observed[seed_idx],\n",
    "                    LogEIC_seed=LogEIC_acq[seed_idx],\n",
    "                    UCB_LCB_seed=UCB_LCB_acq[seed_idx],\n",
    "                    PRB_seed=PRB_0_1_acq[seed_idx],\n",
    "                    RegretGap_seed=RegretGap_acq[seed_idx],\n",
    "                    lmbda=lmbda,\n",
    "                    max_iter=cutoff,\n",
    "                    ucb_lcb_threshold=ucb_lcb_threshold,\n",
    "                    delta=delta,\n",
    "                    convergence_window=convergence_window,\n",
    "                    gss_window=gss_window,\n",
    "                    gss_frac=gss_frac\n",
    "                )\n",
    "                for rule in curve_rules:\n",
    "                    i_stop = stopping_indices[rule]\n",
    "                    stopping_times[rule].append(i_stop)\n",
    "                    stopping_cost_adjusted_regrets[rule].append(cost_adjusted_regret[seed_idx, i_stop])\n",
    "\n",
    "                # Oracle baseline: the iteration with the lowest cost-adjusted regret in hindsight.\n",
    "                i_best = np.argmin(cost_adjusted_regret[seed_idx])\n",
    "                hindsight_best_times.append(i_best)\n",
    "                hindsight_best_regrets.append(cost_adjusted_regret[seed_idx, i_best])\n",
    "\n",
    "            # Mean curve over seeds with a +-2 standard-error band.\n",
    "            mean_curve, stderr_curve = mean_and_std(cost_adjusted_regret, n_seeds)\n",
    "            x = np.arange(len(mean_curve))\n",
    "            ax.plot(x, mean_curve, linestyle='--', color=color_dict[a], zorder=1)\n",
    "            ax.fill_between(x,\n",
    "                            mean_curve - 2*stderr_curve,\n",
    "                            mean_curve + 2*stderr_curve,\n",
    "                            color=color_dict[a],\n",
    "                            alpha=0.2, zorder=1)\n",
    "\n",
    "            # One marker per rule at (mean stopping time, mean regret) with +-2 standard errors,\n",
    "            # consistent with the band above and the hindsight marker below.\n",
    "            for rule in curve_rules:\n",
    "                mean_t, se_t = mean_and_std(stopping_times[rule], n_seeds)\n",
    "                mean_r, se_r = mean_and_std(stopping_cost_adjusted_regrets[rule], n_seeds)\n",
    "                ax.errorbar(mean_t, mean_r, xerr=2*se_t, yerr=2*se_r,\n",
    "                            fmt='s', color=color_dict[rule], ms=8, capsize=3, label=name_dict[rule],\n",
    "                            zorder=3 if rule == 'LogEIC' else 2)\n",
    "\n",
    "            mean_hindsight_time, stderr_hindsight_time = mean_and_std(hindsight_best_times, n_seeds)\n",
    "            mean_hindsight_regret, stderr_hindsight_regret = mean_and_std(hindsight_best_regrets, n_seeds)\n",
    "            ax.errorbar(mean_hindsight_time,\n",
    "                        mean_hindsight_regret,\n",
    "                        xerr=2*stderr_hindsight_time,\n",
    "                        yerr=2*stderr_hindsight_regret,\n",
    "                        fmt='x', color=color_dict['Hindsight'], ms=8, capsize=3, label='Hindsight')\n",
    "\n",
    "            # Title on the top row, x-label on the bottom row.\n",
    "            if row_idx == 0:\n",
    "                ax.set_title(name_dict[a])\n",
    "            if row_idx == n_rows - 1:\n",
    "                ax.set_xlabel('Iteration')\n",
    "\n",
    "            ax.grid(True, linestyle='--', alpha=0.4)\n",
    "\n",
    "            max_y = max(max_y, mean_curve.max())\n",
    "            min_y = min(min_y, mean_curve.min(), mean_hindsight_regret)\n",
    "\n",
    "        # Shared y-range and ticks for every subplot of this row.\n",
    "        print(f'limits: {min_y}, {max_y}')\n",
    "        for col_idx in range(n_cols):\n",
    "            row_ax = axes[row_idx, col_idx]\n",
    "            row_ax.set_ylim(min_y*0.8, max_y*1.1)\n",
    "            row_ax.set_xticks(range(100, cutoff + 1, 100))\n",
    "            row_ax.yaxis.set_major_locator(MaxNLocator(nbins=3))\n",
    "\n",
    "    # Rotated lambda labels on the right of each row, e.g. lambda=10^{-2}.\n",
    "    for row_idx, lmbda in enumerate(lambda_per_acq):\n",
    "        exponent = int(round(np.log10(lmbda)))\n",
    "        label_ax = axes[row_idx, -1]\n",
    "        label_ax.text(1.05, 0.5, rf'$\\lambda=10^{{{exponent}}}$', transform=label_ax.transAxes,\n",
    "                      va='center', ha='left', rotation=270)\n",
    "\n",
    "    fig.text(0.07, 0.5, 'Cost-adjusted Regret', va='center', rotation='vertical', fontsize=22)\n",
    "\n",
    "    # One legend for the entire figure.\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels,\n",
    "               loc='lower center',\n",
    "               bbox_to_anchor=(0.5, -0.08),\n",
    "               ncol=4)\n",
    "\n",
    "    fig.suptitle(f'{kernel} 8D Iteration vs Cost-Adjusted Regret ({cost} cost)')\n",
    "    # sharex=True, so this x-range applies to every subplot.\n",
    "    axes[0, 0].set_xlim(-1, max(iteration_cutoff[lam] for lam in lambda_per_acq))\n",
    "    fig.savefig(f'../plots/{kernel}_8D_{cost}.pdf', bbox_inches='tight')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plots: cost-adjusted regret at the stopping point of each rule only (one panel per cost type, one figure per $\\lambda$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same grid and color style, with smaller fonts for the compact single-row bar-plot figures.\n",
    "sns.set_style('whitegrid', {'grid.linestyle': '--', 'grid.alpha': 0.4})\n",
    "plt.style.use('seaborn-v0_8-bright')\n",
    "\n",
    "bar_plot_rc = {\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': 'Times New Roman',\n",
    "    'font.size': 12,\n",
    "    'axes.titlesize': 14,\n",
    "    'axes.labelsize': 12,\n",
    "    'legend.fontsize': 10,\n",
    "    'xtick.labelsize': 10,\n",
    "    'ytick.labelsize': 10,\n",
    "    'figure.autolayout': False,  # the bar-plot cell calls tight_layout() itself\n",
    "}\n",
    "plt.rcParams.update(bar_plot_rc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import ScalarFormatter\n",
    "l = 0.1  # lengthscale key into metrics_per_acq\n",
    "enable_immediate = True  # also show the stop-at-iteration-0 baseline\n",
    "\n",
    "# Rules for this figure, built locally: appending 'Immediate' to the shared `stopping_rules`\n",
    "# made every re-run add it again (doubling its per-seed samples) and broke other cells.\n",
    "bar_rules = [rule for rule in stopping_rules if rule != 'Immediate']\n",
    "if enable_immediate:\n",
    "    bar_rules.append('Immediate')\n",
    "\n",
    "\n",
    "def style_spines(ax, color='black', linewidth=1):\n",
    "    \"\"\"Make all spines of `ax` visible with the given color and line width.\"\"\"\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "        spine.set_edgecolor(color)\n",
    "        spine.set_linewidth(linewidth)\n",
    "\n",
    "\n",
    "n_rows = 1\n",
    "n_cols = len(cost_type)                   # one panel per cost type\n",
    "n_slots = max(acq_col_idxs.values()) + 1  # x positions; PBGI variants share slot 0\n",
    "\n",
    "for lmbda in lambda_per_acq:\n",
    "    cutoff = iteration_cutoff[lmbda]\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols,\n",
    "        figsize=(3 * n_cols, 3.6 * n_rows),  # adjust per-subplot size\n",
    "        sharex=False, sharey=False,\n",
    "        squeeze=False,  # always a 2-D (n_rows, n_cols) array of Axes\n",
    "    )\n",
    "\n",
    "    for col_idx, cost in enumerate(cost_type):\n",
    "        ax = axes[0, col_idx]\n",
    "        max_y = -np.inf\n",
    "        min_y = np.inf\n",
    "        x_labels = [''] * n_slots\n",
    "        for a in acquisition_functions:\n",
    "            # Only the PBGI variant whose lambda matches this figure is shown.\n",
    "            if a.split('_')[0] == 'Stable' and a.split('_')[-1] != lmbda_to_str[lmbda]:\n",
    "                continue\n",
    "\n",
    "            acq_function_idx = acq_col_idxs[a]\n",
    "            x_labels[acq_function_idx] = name_dict[a]\n",
    "\n",
    "            metrics = metrics_per_acq[(a, l, cost)]\n",
    "            regret = np.array(metrics['regret'])[:, :cutoff]\n",
    "            cumulative_cost = np.array(metrics['cumulative cost'])[:, :cutoff]\n",
    "            LogEIC_acq = np.array(metrics['LogEIC acq'])[:, :cutoff]\n",
    "            UCB_LCB_acq = np.array(metrics['UCB-LCB acq'])[:, :cutoff]\n",
    "            PRB_0_1_acq = np.array(metrics['PRB_0.1'])[:, :cutoff]\n",
    "            RegretGap_acq = np.array(metrics['Regret-Gap acq'])[:, :cutoff]\n",
    "            best_observed = np.array(metrics['best observed'])[:, :cutoff]\n",
    "\n",
    "            # Cost-adjusted regret curve, one row per seed.\n",
    "            cost_adjusted_regret = regret + lmbda * cumulative_cost\n",
    "\n",
    "            stopping_cost_adjusted_regrets = {rule: [] for rule in bar_rules}\n",
    "            hindsight_best_regrets = []\n",
    "\n",
    "            n_seeds = regret.shape[0]\n",
    "            for seed_idx in range(n_seeds):\n",
    "                stopping_indices = return_stopping_indices(\n",
    "                    best_obs_seed=best_observed[seed_idx],\n",
    "                    LogEIC_seed=LogEIC_acq[seed_idx],\n",
    "                    UCB_LCB_seed=UCB_LCB_acq[seed_idx],\n",
    "                    PRB_seed=PRB_0_1_acq[seed_idx],\n",
    "                    RegretGap_seed=RegretGap_acq[seed_idx],\n",
    "                    lmbda=lmbda,\n",
    "                    max_iter=cutoff,\n",
    "                    ucb_lcb_threshold=ucb_lcb_threshold,\n",
    "                    delta=delta,\n",
    "                    convergence_window=convergence_window,\n",
    "                    gss_window=gss_window,\n",
    "                    gss_frac=gss_frac\n",
    "                )\n",
    "                for rule in bar_rules:\n",
    "                    # 'Immediate' stops at iteration 0; other rules come from return_stopping_indices.\n",
    "                    i_stop = 0 if rule == 'Immediate' else stopping_indices[rule]\n",
    "                    stopping_cost_adjusted_regrets[rule].append(cost_adjusted_regret[seed_idx, i_stop])\n",
    "\n",
    "                # Oracle baseline: lowest cost-adjusted regret of the run in hindsight.\n",
    "                hindsight_best_regrets.append(np.min(cost_adjusted_regret[seed_idx]))\n",
    "\n",
    "            # One marker per rule at this acquisition's x slot, +-2 standard errors.\n",
    "            # The y-range tracks the drawn error bars so they are not clipped.\n",
    "            for rule in bar_rules:\n",
    "                et, es = mean_and_std(stopping_cost_adjusted_regrets[rule], n_seeds)\n",
    "                max_y = max(max_y, et + 2*es)\n",
    "                min_y = min(min_y, et - 2*es)\n",
    "                if rule == 'Immediate':\n",
    "                    continue  # drawn as a horizontal reference line below\n",
    "                highlight = rule == 'LogEIC'\n",
    "                ax.errorbar(acq_function_idx, et, yerr=2*es,\n",
    "                            markersize=8,\n",
    "                            linewidth=2,\n",
    "                            fmt='s',\n",
    "                            color=color_dict[rule],\n",
    "                            capsize=0,\n",
    "                            label=name_dict[rule],\n",
    "                            zorder=3 if highlight else 2,\n",
    "                            alpha=0.8 if highlight else 0.3)\n",
    "\n",
    "            mean_hindsight_regret, stderr_hindsight_regret = mean_and_std(hindsight_best_regrets, n_seeds)\n",
    "            max_y = max(max_y, mean_hindsight_regret + 2*stderr_hindsight_regret)\n",
    "            min_y = min(min_y, mean_hindsight_regret - 2*stderr_hindsight_regret)\n",
    "            ax.errorbar(acq_function_idx, mean_hindsight_regret, yerr=2*stderr_hindsight_regret,\n",
    "                        fmt='^',\n",
    "                        markersize=8,\n",
    "                        linewidth=2,\n",
    "                        color=color_dict['Hindsight'],\n",
    "                        capsize=0,\n",
    "                        label='Hindsight')\n",
    "\n",
    "            if enable_immediate:\n",
    "                # Immediate stopping: dashed line plus a +-2 standard-error band across the full width.\n",
    "                mean_immediate, stderr_immediate = mean_and_std(stopping_cost_adjusted_regrets['Immediate'], n_seeds)\n",
    "                ax.axhline(y=mean_immediate, color='tab:cyan', linestyle='--',\n",
    "                           linewidth=2, alpha=0.4, zorder=1)\n",
    "                ax.axhspan(mean_immediate - 2*stderr_immediate, mean_immediate + 2*stderr_immediate,\n",
    "                           color='tab:cyan', alpha=0.04, zorder=0)\n",
    "\n",
    "        print(f'limits: {min_y}, {max_y}')\n",
    "\n",
    "        if enable_immediate:\n",
    "            # Single legend entry for the Immediate reference line.\n",
    "            ax.plot([], [], color='tab:cyan', linestyle='--', linewidth=2,\n",
    "                    alpha=0.6, label='Immediate')\n",
    "\n",
    "        ax.set_title(name_dict[cost], fontsize=14)\n",
    "        ax.set_ylim(min_y*0.6, max_y*1.1)\n",
    "        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))\n",
    "        ax.set_xticks(range(n_slots))\n",
    "        ax.set_xticklabels([name.split('(')[0] for name in x_labels])  # e.g. 'PBGI(0.1)' -> 'PBGI'\n",
    "        ax.grid(which='major', linestyle='--', alpha=0.4)\n",
    "        ax.grid(which='minor', visible=False)\n",
    "        ax.set_ylabel('Cost-Adjusted Regret' if col_idx == 0 else '')\n",
    "        style_spines(ax)\n",
    "\n",
    "    # One legend for the figure: deduplicate labels (each acquisition adds one entry per\n",
    "    # rule) and list 'Immediate' last.\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    legend_hls = dict(zip(labels, handles))\n",
    "    if 'Immediate' in legend_hls:\n",
    "        legend_hls['Immediate'] = legend_hls.pop('Immediate')\n",
    "    fig.legend(list(legend_hls.values()), list(legend_hls.keys()),\n",
    "               loc='lower center',\n",
    "               bbox_to_anchor=(0.5, -0.1),\n",
    "               ncol=5,\n",
    "               fontsize=12)\n",
    "\n",
    "    fig.tight_layout(rect=[0, 0.05, 1, 0.95])\n",
    "    fig.savefig(f'../plots/BarPlot_{kernel}_8D_{lmbda}.pdf', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pandora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
