{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "52dfe122",
   "metadata": {},
   "source": [
    "# PBGI: this‑round vs next‑round\n",
    "Illustration notebook – produces the figure showing when the classic *this‑round* rule stops too early while our *next‑round* look‑ahead avoids that. Feel free to play around with different seeds. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e624ddb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, numpy as np, matplotlib.pyplot as plt\n",
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.likelihoods import FixedNoiseGaussianLikelihood\n",
    "from gpytorch.kernels import ScaleKernel, MaternKernel\n",
    "\n",
    "torch.set_default_dtype(torch.double)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfb048d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- sample a 1‑D function we will **minimize** ------------------------------\n",
    "seed = 20         # try a few seeds until the two rules differ nicely\n",
    "torch.manual_seed(seed); np.random.seed(seed)\n",
    "granularity = 10001 \n",
    "\n",
    "noise = 1e-6\n",
    "kernel = MaternKernel(nu=2.5).double()\n",
    "kernel.lengthscale = torch.tensor([[0.1]])\n",
    "kernel = ScaleKernel(kernel).double()\n",
    "kernel.outputscale = torch.tensor([[1]])\n",
    "\n",
    "model0 = SingleTaskGP(torch.zeros(1,1), torch.zeros(1,1),\n",
    "                      likelihood=FixedNoiseGaussianLikelihood(noise=torch.tensor([noise])), covar_module=kernel)\n",
    "\n",
    "from botorch.sampling.pathwise import draw_kernel_feature_paths\n",
    "feature_path = draw_kernel_feature_paths(model0, sample_shape=torch.Size([1]))\n",
    "\n",
    "def f_true(x):            # ground‑truth objective (unknown to optimiser)\n",
    "    return feature_path(x.unsqueeze(0)).squeeze(0).detach()\n",
    "\n",
    "xs = torch.linspace(0,1,granularity).unsqueeze(-1)\n",
    "with torch.no_grad():\n",
    "    plt.figure(figsize=(7,3))\n",
    "    plt.plot(xs, f_true(xs), lw=1.2, label='f(x)')\n",
    "    plt.title('Sampled function to minimize'); plt.legend(); plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e28ef451",
   "metadata": {},
   "outputs": [],
   "source": [
    "# uniform cost helper\n",
    "def cost_uniform(x):\n",
    "    return torch.ones(x.shape[:-1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af0a9ed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandora_automl.acquisition.stable_gittins import StableGittinsIndex\n",
    "from pandora_automl.acquisition.ei_puc import ExpectedImprovementWithCost\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8704d081",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_this_round(max_iter=50, lam=0.01):\n",
    "    # two‑point init\n",
    "    X = torch.tensor([[0.25],[0.75]])\n",
    "    Y = f_true(X)\n",
    "    for t in range(max_iter):\n",
    "        gp = SingleTaskGP(X, Y.unsqueeze(-1),\n",
    "                          likelihood=FixedNoiseGaussianLikelihood(noise=torch.ones_like(Y)*noise))\n",
    "        acq = StableGittinsIndex(model=gp, maximize=False, lmbda=lam, cost=cost_uniform)\n",
    "        grid = torch.linspace(0,1,granularity).unsqueeze(-1)\n",
    "        vals = acq(grid.unsqueeze(1))\n",
    "        x_next = grid[vals.argmin()].unsqueeze(0)     # minimise ⇒ pick *smallest* index\n",
    "        y_next = f_true(x_next)\n",
    "        # this‑round stop: value < index of next box\n",
    "        if y_next.item() < vals.min().item():\n",
    "            X = torch.cat([X, x_next])\n",
    "            Y = torch.cat([Y, y_next])\n",
    "            break\n",
    "        X, Y = torch.cat([X, x_next]), torch.cat([Y, y_next])\n",
    "    best_idx = Y.argmin()\n",
    "    return X, Y, best_idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7a3779",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_next_round(max_iter=50, lam=0.01):\n",
    "    X = torch.tensor([[0.25],[0.75]])\n",
    "    Y = f_true(X)\n",
    "    for t in range(max_iter):\n",
    "        gp = SingleTaskGP(X, Y.unsqueeze(-1),\n",
    "                          likelihood=FixedNoiseGaussianLikelihood(noise=torch.ones_like(Y)*noise))\n",
    "        acq = StableGittinsIndex(model=gp, maximize=False, lmbda=lam, cost=cost_uniform)\n",
    "        grid = torch.linspace(0,1,granularity).unsqueeze(-1)\n",
    "        vals = acq(grid.unsqueeze(1))\n",
    "        # next‑round stop: no expected improvement per cost < 0\n",
    "        if (vals.min().item() > Y[-1].item()):\n",
    "            break\n",
    "        x_next = grid[vals.argmin()].unsqueeze(0)     # pick best EI per cost (minimisation)\n",
    "        y_next = f_true(x_next)\n",
    "        X, Y = torch.cat([X, x_next]), torch.cat([Y, y_next])\n",
    "    best_idx = Y.argmin()\n",
    "    return X, Y, best_idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9cd7584",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sns.set_style('whitegrid', {\n",
    "    'grid.linestyle': '--',\n",
    "    'grid.alpha': 0.4\n",
    "})\n",
    "\n",
    "plt.style.use('seaborn-v0_8-bright')\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = 'Times New Roman'\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams.update({\n",
    "    'font.size': 14,\n",
    "    'axes.titlesize': 14,\n",
    "    'axes.labelsize': 16,\n",
    "    'legend.fontsize': 12,\n",
    "    # 'xtick.rotation': 45,\n",
    "    'xtick.labelsize': 12,\n",
    "    'ytick.labelsize': 12,\n",
    "    'figure.autolayout': False,  # we’ll call tight_layout() explicitly\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73c995df",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "lam = 0.01\n",
    "# 1) compute true min and cost-adjusted regrets\n",
    "with torch.no_grad():\n",
    "    true_min = f_true(xs).min().item()\n",
    "\n",
    "X1, Y1, best1 = run_this_round(lam=0.01)\n",
    "X2, Y2, best2 = run_next_round(lam=0.01)\n",
    "\n",
    "costs1 = cost_uniform(X1).sum().item()\n",
    "costs2 = cost_uniform(X2).sum().item()\n",
    "\n",
    "regret1 = (Y1.min().item() - true_min) + lam * costs1\n",
    "regret2 = (Y2.min().item() - true_min) + lam * costs2\n",
    "\n",
    "# 2) side-by-side plots, left panel 4× wider than right\n",
    "fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9, 4), gridspec_kw={'width_ratios': [4, 1.2]}) \n",
    "\n",
    "# Left: minimisation illustration\n",
    "ax0.plot(xs, f_true(xs), lw=1.2, color='grey')\n",
    "ax0.scatter(X1, f_true(X1), marker='o', c='blue', s=25, label='This-round queries', zorder=3)\n",
    "ax0.scatter(X2, f_true(X2), marker='x', c='orange', s=25, label='Next-round queries', zorder=3)\n",
    "#ax0.scatter(X1[best1], Y1[best1], c='blue', marker='o', s=40, label='best this-round', zorder=4)\n",
    "# ax0.scatter(X2[best2], Y2[best2], c='orange',  marker='o', s=40, label='best next-round', zorder=4)\n",
    "ax0.axvline(X1[best1].item(), ls=':', c='blue', lw=1.5, label='Best This-round')\n",
    "ax0.axvline(X2[best2].item(), ls='--', c='orange', lw=2, label='Best Next-round')\n",
    "# ax0.set_title('Minimisation: best value at stop')\n",
    "# ax0.set_xlabel('Query Location $x$', fontsize=14); \n",
    "ax0.set_ylabel('Objective Value  $f(x)$')\n",
    "ax0.legend(loc='lower right')\n",
    "\n",
    "# Right: cost-adjusted regret\n",
    "ax1.bar(\n",
    "    ['This-round', 'Next-round'],\n",
    "    [regret1, regret2],\n",
    "    width=0.5, \n",
    "    alpha=0.7,\n",
    "    color=['blue', 'orange'], \n",
    ")\n",
    "# ax1.set_title('Cost-adjusted simple regret')\n",
    "ax1.set_ylabel('Regret + λ·Cost')\n",
    "ax1.set_xticklabels(['This-round','Next-round'])\n",
    "ax1.set_ylim(0, max(regret1, regret2) * 1.2)\n",
    "# fig.suptitle('Best Value and Cost-Adjusted Simple Regret at Stop')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../plots/PBGI_this_round_vs_next_round.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "stopping_rule",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
