{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "from pprint import pprint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import importlib.util\n",
    "\n",
    "# Construct the path relative to the current working directory\n",
    "api_path = os.path.join(os.getcwd(), \"LCBench\", \"api.py\")\n",
    "\n",
    "print(\"Resolved API path:\", api_path)\n",
    "\n",
    "spec = importlib.util.spec_from_file_location(\"LCBench.api\", api_path)\n",
    "api = importlib.util.module_from_spec(spec)\n",
    "spec.loader.exec_module(api)\n",
    "\n",
    "Benchmark = api.Benchmark\n",
    "\n",
    "os.makedirs(\"LCBench/cached\", exist_ok=True)\n",
    "bench_dir = \"LCBench/cached/six_datasets_lw.json\"\n",
    "bench = Benchmark(bench_dir, cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Set default tensor type to float64\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_config(config):\n",
    "    # Convert each value to a torch tensor (ensuring float type for calculations)\n",
    "    batch = torch.tensor(config[\"batch_size\"])\n",
    "    lr = torch.tensor(config[\"learning_rate\"])\n",
    "    units = torch.tensor(config[\"max_units\"])\n",
    "    momentum = torch.tensor(config[\"momentum\"])\n",
    "    weight_decay = torch.tensor(config[\"weight_decay\"])\n",
    "    layers = torch.tensor(float(config[\"num_layers\"]))\n",
    "    dropout = torch.tensor(config[\"max_dropout\"])\n",
    "    \n",
    "    # For log-scaled parameters: batch size, learning rate, and max units.\n",
    "    batch_norm = (torch.log(batch) - torch.log(torch.tensor(16.0))) / (torch.log(torch.tensor(512.0)) - torch.log(torch.tensor(16.0)))\n",
    "    lr_norm = (torch.log(lr) - torch.log(torch.tensor(1e-4))) / (torch.log(torch.tensor(1e-1)) - torch.log(torch.tensor(1e-4)))\n",
    "    units_norm = (torch.log(units) - torch.log(torch.tensor(64.0))) / (torch.log(torch.tensor(1024.0)) - torch.log(torch.tensor(64.0)))\n",
    "    \n",
    "    # For linearly scaled parameters.\n",
    "    momentum_norm = (momentum - 0.1) / (0.99 - 0.1)\n",
    "    weight_decay_norm = (weight_decay - 1e-5) / (1e-1 - 1e-5)\n",
    "    layers_norm = (layers - 1) / (4 - 1)\n",
    "    \n",
    "    # Dropout is already between 0 and 1.\n",
    "    dropout_norm = dropout\n",
    "\n",
    "    # Combine into a 7-dimensional tensor.\n",
    "    normalized_vector = torch.stack([\n",
    "        batch_norm, \n",
    "        lr_norm, \n",
    "        momentum_norm, \n",
    "        weight_decay_norm, \n",
    "        layers_norm, \n",
    "        units_norm, \n",
    "        dropout_norm\n",
    "    ])\n",
    "    \n",
    "    return normalized_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_x = []\n",
    "all_y = []\n",
    "all_c = []\n",
    "x2id = {}\n",
    "dataset_name = \"higgs\"\n",
    "for config_id in bench.data[dataset_name].keys():\n",
    "    config = bench.query(dataset_name, \"config\", config_id)\n",
    "    x = normalize_config(config)\n",
    "    all_x.append(x)\n",
    "    val_ce = bench.query(dataset_name, \"final_val_cross_entropy\", config_id)\n",
    "    all_y.append(val_ce)\n",
    "    runtime = bench.query(dataset_name, \"time\", config_id)[-1]\n",
    "    all_c.append(runtime)\n",
    "    x2id[x.numpy().tobytes()] = config_id\n",
    "\n",
    "all_x = torch.stack(all_x)\n",
    "all_y = torch.tensor(all_y).unsqueeze(1)\n",
    "all_c = torch.tensor(all_c).unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.4299, 0.4204, 0.1272,  ..., 0.6667, 0.5487, 0.0259],\n",
       "         [0.9672, 0.6977, 0.0720,  ..., 1.0000, 0.9729, 0.5472],\n",
       "         [0.8919, 0.1077, 0.3272,  ..., 0.0000, 0.8208, 0.3320],\n",
       "         ...,\n",
       "         [0.6750, 0.8598, 0.4454,  ..., 0.6667, 0.4707, 0.3635],\n",
       "         [0.9691, 0.3290, 0.0093,  ..., 0.3333, 0.8684, 0.0437],\n",
       "         [0.3666, 0.9906, 0.2041,  ..., 1.0000, 0.6681, 0.4045]]),\n",
       " tensor([[0.6380],\n",
       "         [0.6931],\n",
       "         [0.7014],\n",
       "         ...,\n",
       "         [0.6090],\n",
       "         [0.6511],\n",
       "         [0.6931]]),\n",
       " tensor([[215.9746],\n",
       "         [876.6520],\n",
       "         [126.5395],\n",
       "         ...,\n",
       "         [186.1553],\n",
       "         [165.5783],\n",
       "         [877.6618]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_x, all_y, all_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Initial config id: [1542, 67, 876, 414, 26, 1335, 620, 1924, 950, 1113, 1378, 1014, 1210, 954, 231, 1572]\n"
     ]
    }
   ],
   "source": [
    "from pandora_automl.utils import fit_gp_model\n",
    "seed = 42\n",
    "dim = 7\n",
    "output_standardize = True\n",
    "torch.manual_seed(seed)\n",
    "init_config_id = torch.randint(low=0, high=2000, size=(2*(dim+1),))\n",
    "config_id_history = init_config_id.tolist()\n",
    "print(f\"  Initial config id: {config_id_history}\")\n",
    "x = all_x[init_config_id]\n",
    "y = all_y[init_config_id]\n",
    "c = all_c[init_config_id]\n",
    "\n",
    "model = fit_gp_model(X=x, objective_X=y, cost_X=c, unknown_cost=True, output_standardize=output_standardize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.6447, 5.5940],\n",
       "         [0.6567, 5.4204],\n",
       "         [0.6523, 5.0305],\n",
       "         ...,\n",
       "         [0.6336, 5.4245],\n",
       "         [0.6498, 5.2734],\n",
       "         [0.6641, 5.6044]], grad_fn=<CloneBackward0>),\n",
       " tensor([[0.0006, 0.1223],\n",
       "         [0.0006, 0.1398],\n",
       "         [0.0005, 0.0986],\n",
       "         ...,\n",
       "         [0.0005, 0.1192],\n",
       "         [0.0006, 0.1262],\n",
       "         [0.0006, 0.1172]], grad_fn=<ClampMinBackward0>))"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "posterior = model.posterior(all_x)\n",
    "means = posterior.mean  # (b) x 2\n",
    "vars = posterior.variance.clamp_min(1e-6)  # (b) x 2\n",
    "stds = vars.sqrt()\n",
    "mean_obj = means[..., 0]\n",
    "std_obj = stds[..., 0]\n",
    "means, vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([285.7591, 242.3268, 160.7396,  ..., 240.8298, 207.7808, 288.0104],\n",
       "       grad_fn=<ExpBackward0>)"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mgf = torch.exp(means[..., 1] + 0.5 * vars[..., 1])\n",
    "mgf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "\n",
    "def expected_improvement_minimize(g, mu, sigma):\n",
    "    z = (g - mu) / sigma\n",
    "    return (g - mu) * norm.cdf(z) + sigma * norm.pdf(z)\n",
    "\n",
    "def batch_gittins_indices(mu, sigma, cost, tol=1e-6, max_iter=100):\n",
    "    mu = np.asarray(mu)\n",
    "    sigma = np.asarray(sigma)\n",
    "    cost = np.asarray(cost)\n",
    "\n",
    "    l = mu - 10 * sigma\n",
    "    h = mu + 10 * sigma\n",
    "    m = (l + h) / 2\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        ei = expected_improvement_minimize(m, mu, sigma)\n",
    "        sgn = np.sign(ei - cost)\n",
    "        l = np.where(sgn <= 0, m, l)\n",
    "        h = np.where(sgn >= 0, m, h)\n",
    "        m = (l + h) / 2\n",
    "        if np.max(np.abs(ei - cost)) < tol:\n",
    "            break\n",
    "    return m\n",
    "\n",
    "def prob_f_leq_g(mu, sigma, g):\n",
    "    z = (g - np.asarray(mu)) / np.asarray(sigma)\n",
    "    return norm.cdf(z)\n",
    "\n",
    "def compute_expected_cumulative_costs(gittins, probs, costs):\n",
    "    idx = np.argsort(gittins)\n",
    "    sorted_probs = probs[idx]\n",
    "    sorted_costs = costs[idx]\n",
    "    E = 0.0\n",
    "    for p, c in zip(reversed(sorted_probs), reversed(sorted_costs)):\n",
    "        E = c + (1 - p) * E\n",
    "    return E\n",
    "\n",
    "def cost_scaling_objective(scaling, exp_budget, means, stds, costs):\n",
    "    # Scale the costs\n",
    "    scaled_costs = scaling * costs\n",
    "    # Compute Gittins indices with the scaled costs\n",
    "    gittins = batch_gittins_indices(means, stds, scaled_costs)\n",
    "    # Compute stopping probabilities\n",
    "    stopping_probs = prob_f_leq_g(means, stds, gittins)\n",
    "    # Compute the expected cumulative cost\n",
    "    exp_cost = compute_expected_cumulative_costs(gittins, stopping_probs, costs)\n",
    "    return exp_cost - exp_budget\n",
    "\n",
    "def find_cost_scaling_custom(exp_budget, means, stds, costs, lower=0.0, upper=1.0, fixed_iters=100):\n",
    "    \"\"\"\n",
    "    Custom bisection search that stops after a fixed number of iterations or\n",
    "    when the absolute objective function value is less than ytol.\n",
    "\n",
    "    Parameters:\n",
    "      exp_budget : target expected cumulative cost.\n",
    "      means, stds : arrays for the Gaussian parameters.\n",
    "      costs     : array of evaluation costs.\n",
    "      lower, upper : initial bracketing interval for the scaling factor.\n",
    "      fixed_iters: fixed number of iterations to run.\n",
    "      ytol       : tolerance on the function value f(mid).\n",
    "\n",
    "    Returns:\n",
    "      A scaling factor such that the cost_scaling_objective is nearly zero.\n",
    "    \"\"\"\n",
    "    # First, update the upper bound to ensure f(upper) < 0.\n",
    "    while cost_scaling_objective(upper, exp_budget, means, stds, costs) > 0:\n",
    "        upper *= 2.0\n",
    "\n",
    "    a, b = lower, upper\n",
    "    for _ in range(fixed_iters):\n",
    "        mid = (a + b) / 2.0\n",
    "        f_mid = cost_scaling_objective(mid, exp_budget, means, stds, costs)\n",
    "        # Check which subinterval contains the root\n",
    "        if cost_scaling_objective(a, exp_budget, means, stds, costs) * f_mid < 0:\n",
    "            b = mid\n",
    "        else:\n",
    "            a = mid\n",
    "\n",
    "    return (a + b) / 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computed cost scaling: 2.029167158567526e-07\n",
      "Expected cumulative costs: 36000.02671412786\n"
     ]
    }
   ],
   "source": [
    "exp_budget = 36000\n",
    "\n",
    "cost_scaling = find_cost_scaling_custom(exp_budget, mean_obj.detach().numpy(), std_obj.detach().numpy(), mgf.detach().numpy())\n",
    "print(\"Computed cost scaling:\", cost_scaling)\n",
    "\n",
    "gittins = batch_gittins_indices(mean_obj.detach().numpy(), std_obj.detach().numpy(), cost_scaling * mgf.detach().numpy())\n",
    "stopping_probs = prob_f_leq_g(mean_obj.detach().numpy(), std_obj.detach().numpy(), gittins)\n",
    "exp_cum_costs = compute_expected_cumulative_costs(gittins, stopping_probs, mgf.detach().numpy())\n",
    "print(\"Expected cumulative costs:\", exp_cum_costs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_x = []\n",
    "all_y = []\n",
    "all_c = []\n",
    "dataset_name = \"higgs\"\n",
    "for config_id in bench.data[dataset_name].keys():\n",
    "    config = bench.query(dataset_name, \"config\", config_id)\n",
    "    all_x.append(normalize_config(config))\n",
    "    val_ce = bench.query(dataset_name, \"final_val_cross_entropy\", config_id)\n",
    "    all_y.append(val_ce)\n",
    "    cost = bench.query(dataset_name, \"model_parameters\", config_id)\n",
    "    all_c.append(cost)\n",
    "\n",
    "all_x = torch.stack(all_x)\n",
    "all_y = torch.tensor(all_y).unsqueeze(1)\n",
    "all_c = torch.tensor(all_c).unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Initial config id: [1542, 67, 876, 414, 26, 1335, 620, 1924, 950, 1113, 1378, 1014, 1210, 954, 231, 1572]\n"
     ]
    }
   ],
   "source": [
    "from pandora_automl.utils import fit_gp_model\n",
    "seed = 42\n",
    "dim = 7\n",
    "output_standardize = True\n",
    "torch.manual_seed(seed)\n",
    "init_config_id = torch.randint(low=0, high=2000, size=(2*(dim+1),))\n",
    "config_id_history = init_config_id.tolist()\n",
    "print(f\"  Initial config id: {config_id_history}\")\n",
    "x = all_x[init_config_id]\n",
    "y = all_y[init_config_id]\n",
    "c = all_c[init_config_id]\n",
    "\n",
    "model = fit_gp_model(X=x, objective_X=y, output_standardize=output_standardize)\n",
    "\n",
    "posterior = model.posterior(all_x)\n",
    "means = posterior.mean  # (b) x 2\n",
    "vars = posterior.variance.clamp_min(1e-6)  # (b) x 2\n",
    "stds = vars.sqrt()\n",
    "\n",
    "means = means[..., 0].detach().numpy()\n",
    "stds = stds[..., 0].detach().numpy()\n",
    "costs = all_c[..., 0].detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computed cost scaling: 2.010496568154024e-07\n",
      "Expected cumulative costs: 35999.999999999985\n"
     ]
    }
   ],
   "source": [
    "exp_budget = 36000\n",
    "\n",
    "cost_scaling = find_cost_scaling_custom(exp_budget, means, stds, costs)\n",
    "print(\"Computed cost scaling:\", cost_scaling)\n",
    "\n",
    "gittins = batch_gittins_indices(means, stds, cost_scaling * costs)\n",
    "stopping_probs = prob_f_leq_g(means, stds, gittins)\n",
    "exp_cum_costs = compute_expected_cumulative_costs(gittins, stopping_probs, costs)\n",
    "print(\"Expected cumulative costs:\", exp_cum_costs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandora_automl.utils import fit_gp_model\n",
    "import numpy as np\n",
    "import math\n",
    "from botorch.acquisition import LogExpectedImprovement\n",
    "from pandora_automl.acquisition.log_ei_puc import LogExpectedImprovementWithCost\n",
    "from botorch.acquisition import UpperConfidenceBound\n",
    "from pandora_automl.acquisition.lcb import LowerConfidenceBound\n",
    "from pandora_automl.acquisition.gittins import GittinsIndex\n",
    "from pandora_automl.acquisition.stable_gittins import StableGittinsIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Initial config id: [1411, 814, 767, 1885, 6, 70, 1792, 233, 88, 404, 1879, 750, 1211, 1312, 318, 288]\n",
      "Iteration 1:\n",
      "  Selected config_id: 290\n",
      "  Acquisition value: 0.6067\n",
      "  Objective (final_val_cross_entropy): 0.5870\n",
      "  Cost (time): 340.8998\n",
      "  Current best observed: 0.5994\n",
      "\n",
      "Iteration 2:\n",
      "  Selected config_id: 1239\n",
      "  Acquisition value: 0.5942\n",
      "  Objective (final_val_cross_entropy): 0.6277\n",
      "  Cost (time): 299.1622\n",
      "  Current best observed: 0.5870\n",
      "\n",
      "Iteration 3:\n",
      "  Selected config_id: 1318\n",
      "  Acquisition value: 0.5957\n",
      "  Objective (final_val_cross_entropy): 0.5906\n",
      "  Cost (time): 194.1711\n",
      "  Current best observed: 0.5870\n",
      "\n",
      "Iteration 4:\n",
      "  Selected config_id: 1310\n",
      "  Acquisition value: 0.5877\n",
      "  Objective (final_val_cross_entropy): 0.5938\n",
      "  Cost (time): 180.1677\n",
      "  Current best observed: 0.5870\n",
      "\n",
      "Iteration 5:\n",
      "  Selected config_id: 282\n",
      "  Acquisition value: 0.5931\n",
      "  Objective (final_val_cross_entropy): 0.6309\n",
      "  Cost (time): 201.7360\n",
      "  Current best observed: 0.5870\n",
      "\n"
     ]
    }
   ],
   "source": [
    "dim = 7\n",
    "n_iter = 5\n",
    "maximize = False\n",
    "output_standardize = True\n",
    "acq = \"StablePBGI(1e-5)\"\n",
    "seed = 5\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "init_config_id = torch.randint(low=0, high=2000, size=(2*(dim+1),))\n",
    "config_id_history = init_config_id.tolist()\n",
    "print(f\"  Initial config id: {config_id_history}\")\n",
    "x = all_x[init_config_id]\n",
    "y = all_y[init_config_id]\n",
    "c = all_c[init_config_id]\n",
    "best_y_history = [y.min().item()]\n",
    "best_id_history = [config_id_history[y.argmin().item()]]\n",
    "cost_history = [0]\n",
    "\n",
    "# Instead of several separate lists, we initialize a dictionary to store all acquisition histories.\n",
    "acq_history = {\n",
    "    'StablePBGI(1e-5)': [np.nan],\n",
    "    'StablePBGI(1e-6)': [np.nan],\n",
    "    'StablePBGI(1e-7)': [np.nan],\n",
    "    'LogEIC-inv': [np.nan],\n",
    "    'LogEIC-exp': [np.nan],\n",
    "    'regret upper bound': [np.nan]\n",
    "}\n",
    "\n",
    "for i in range(n_iter):\n",
    "    # 1. Fit a GP model on the current data.\n",
    "    model = fit_gp_model(X=x, objective_X=y, cost_X=c, unknown_cost=True, output_standardize=output_standardize)\n",
    "    \n",
    "    # 2. Determine the best observed objective value.\n",
    "    best_f = y.min()\n",
    "        \n",
    "    # 3. Define the acquisition function.\n",
    "    StablePBGI_1e_5 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-5, unknown_cost=True)\n",
    "    StablePBGI_1e_6 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-6, unknown_cost=True)\n",
    "    StablePBGI_1e_7 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-7, unknown_cost=True)\n",
    "    LogEIC_inv = LogExpectedImprovementWithCost(model=model, best_f=best_f, maximize=maximize, unknown_cost=True, inverse_cost=True)\n",
    "    LogEIC_exp = LogExpectedImprovementWithCost(model=model, best_f=best_f, maximize=maximize, unknown_cost=True, inverse_cost=False)\n",
    "    single_outcome_model = fit_gp_model(X=x, objective_X=y, output_standardize=output_standardize)\n",
    "    UCB = UpperConfidenceBound(model=single_outcome_model, maximize=maximize, beta=2 * np.log(dim * ((i + 1) ** 2) * (math.pi ** 2) / (6 * 0.1)) / 5)\n",
    "    LCB = LowerConfidenceBound(model=single_outcome_model, maximize=maximize, beta=2 * np.log(dim * ((i + 1) ** 2) * (math.pi ** 2) / (6 * 0.1)) / 5)\n",
    "\n",
    "    # 4. Evaluate the acquisition function on all candidate x's.\n",
    "    StablePBGI_1e_5_acq = StablePBGI_1e_5.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_6_acq = StablePBGI_1e_6.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_6_acq[config_id_history] = y.squeeze(-1)\n",
    "    StablePBGI_1e_7_acq = StablePBGI_1e_7.forward(all_x.unsqueeze(1))\n",
    "    LogEIC_inv_acq = LogEIC_inv.forward(all_x.unsqueeze(1))\n",
    "    LogEIC_exp_acq = LogEIC_exp.forward(all_x.unsqueeze(1))\n",
    "    UCB_acq = UCB.forward(all_x.unsqueeze(1))\n",
    "    LCB_acq = LCB.forward(all_x.unsqueeze(1))\n",
    "\n",
    "    # 5. Record information for stopping.\n",
    "    num_configs = 2000\n",
    "    all_ids = torch.arange(num_configs)\n",
    "    mask = torch.ones(num_configs, dtype=torch.bool)\n",
    "    mask[config_id_history] = False\n",
    "\n",
    "    acq_history['StablePBGI(1e-5)'].append(torch.min(StablePBGI_1e_5_acq[mask]).item())\n",
    "    acq_history['StablePBGI(1e-6)'].append(torch.min(StablePBGI_1e_6_acq[mask]).item())\n",
    "    acq_history['StablePBGI(1e-7)'].append(torch.min(StablePBGI_1e_7_acq[mask]).item())\n",
    "    acq_history['LogEIC-inv'].append(torch.max(LogEIC_inv_acq[mask]).item())\n",
    "    acq_history['LogEIC-exp'].append(torch.max(LogEIC_exp_acq[mask]).item())\n",
    "    acq_history['regret upper bound'].append(torch.min(UCB_acq[~mask]).item() - torch.min(LCB_acq).item())\n",
    "\n",
    "    # 6. Select the candidate with the optimal acquisition value.\n",
    "    candidate_ids = all_ids[mask]\n",
    "    \n",
    "    if acq == \"StablePBGI(1e-5)\":\n",
    "        candidate_acqs = StablePBGI_1e_5_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"StablePBGI(1e-6)\":\n",
    "        candidate_acqs = StablePBGI_1e_6_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"StablePBGI(1e-7)\":\n",
    "        candidate_acqs = StablePBGI_1e_7_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"LogEIC-inv\":\n",
    "        candidate_acqs = LogEIC_inv_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmax(candidate_acqs)]\n",
    "        new_config_acq = torch.max(candidate_acqs)\n",
    "    if acq == \"LogEIC-exp\":\n",
    "        candidate_acqs = LogEIC_exp_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmax(candidate_acqs)]\n",
    "        new_config_acq = torch.max(candidate_acqs)\n",
    "    if acq == \"LCB\":\n",
    "        candidate_acqs = LCB_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "\n",
    "    new_config_x = all_x[new_config_id]\n",
    "    \n",
    "    # 7. Query the objective for the new configuration.\n",
    "    new_config_y = all_y[new_config_id]\n",
    "    new_config_c = all_c[new_config_id]\n",
    "    \n",
    "    # 8. Append the new data to our training set.\n",
    "    x = torch.cat([x, new_config_x.unsqueeze(0)], dim=0)\n",
    "    y = torch.cat([y, new_config_y.unsqueeze(0)], dim=0)\n",
    "    c = torch.cat([c, new_config_c.unsqueeze(0)], dim=0)\n",
    "    config_id_history.append(new_config_id.item())\n",
    "    best_y_history.append(best_f.item())\n",
    "    best_id_history.append(config_id_history[y.argmin().item()])\n",
    "    cost_history.append(new_config_c.item())\n",
    "\n",
    "    print(f\"Iteration {i + 1}:\")\n",
    "    print(f\"  Selected config_id: {new_config_id}\")\n",
    "    print(f\"  Acquisition value: {new_config_acq.item():.4f}\")\n",
    "    print(f\"  Objective (final_val_cross_entropy): {new_config_y.item():.4f}\")\n",
    "    print(f\"  Cost (time): {new_config_c.item():.4f}\")\n",
    "    print(f\"  Current best observed: {best_f.item():.4f}\")\n",
    "    print()\n",
    "\n",
    "best_y_history.append(y.min().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Initial config id: [418, 1152, 1754, 976, 1318, 1426, 1236, 329, 1964, 1498, 1667, 1626, 1460, 1806, 994, 1726]\n",
      "sigma: 0.47894790294582484\n",
      "tensor([0.5814, 0.6808, 0.1071, 0.3232, 0.0000, 0.3438, 0.6232]) tensor([0.9324, 0.1103, 0.5538, 0.7870, 0.6667, 0.1023, 0.3960])\n",
      "delta mu: 0.3827099745140336\n",
      "kappa: 0.1470150784031663\n",
      "kl: 0.0015110223788405985\n",
      "ei diff: 0.00010345388051661284\n",
      "exp min regret gap: 0.3868543677263436\n",
      "\n",
      "Iteration 1:\n",
      "  Selected config_id: 1360\n",
      "  Acquisition value: -7.5494\n",
      "  Objective (final_val_cross_entropy): 0.4736\n",
      "  Cost (time): 209.5318\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.4630851457894287\n",
      "tensor([0.2496, 0.5328, 0.3502, 0.0309, 0.0000, 0.2584, 0.4609]) tensor([0.5814, 0.6808, 0.1071, 0.3232, 0.0000, 0.3438, 0.6232])\n",
      "delta mu: 0.026101114573286233\n",
      "kappa: 0.1913104529816585\n",
      "kl: 0.006238687001181686\n",
      "ei diff: 0.16316874203753642\n",
      "exp min regret gap: 0.19995475264756857\n",
      "\n",
      "Iteration 2:\n",
      "  Selected config_id: 1261\n",
      "  Acquisition value: -7.8306\n",
      "  Objective (final_val_cross_entropy): 0.4400\n",
      "  Cost (time): 232.6493\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.448301737691829\n",
      "tensor([0.7691, 0.5768, 0.1814, 0.0083, 0.0000, 0.2999, 0.4462]) tensor([0.2496, 0.5328, 0.3502, 0.0309, 0.0000, 0.2584, 0.4609])\n",
      "delta mu: 0.05011179828522505\n",
      "kappa: 0.20354299034964107\n",
      "kl: 0.004021742235328918\n",
      "ei diff: 0.20552512112191695\n",
      "exp min regret gap: 0.26476434431870777\n",
      "\n",
      "Iteration 3:\n",
      "  Selected config_id: 542\n",
      "  Acquisition value: -8.0429\n",
      "  Objective (final_val_cross_entropy): 0.4597\n",
      "  Cost (time): 206.2745\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.43616547680440154\n",
      "tensor([0.5078, 0.6852, 0.2385, 0.1161, 0.0000, 0.5905, 0.0822]) tensor([0.7691, 0.5768, 0.1814, 0.0083, 0.0000, 0.2999, 0.4462])\n",
      "delta mu: 0.019432244570317003\n",
      "kappa: 0.2131987657210681\n",
      "kl: 0.0060247078322340775\n",
      "ei diff: 0.17683445187971641\n",
      "exp min regret gap: 0.20796809262661337\n",
      "\n",
      "Iteration 4:\n",
      "  Selected config_id: 1589\n",
      "  Acquisition value: -8.1025\n",
      "  Objective (final_val_cross_entropy): 0.4481\n",
      "  Cost (time): 217.0765\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.4235023712623851\n",
      "tensor([0.5483, 0.8408, 0.4133, 0.2055, 0.0000, 0.8120, 0.7150]) tensor([0.5078, 0.6852, 0.2385, 0.1161, 0.0000, 0.5905, 0.0822])\n",
      "delta mu: 0.012467124389676454\n",
      "kappa: 0.21321744552638755\n",
      "kl: 0.004780769214868585\n",
      "ei diff: 0.19196076000746942\n",
      "exp min regret gap: 0.21485241785880715\n",
      "\n",
      "Iteration 5:\n",
      "  Selected config_id: 212\n",
      "  Acquisition value: -8.1650\n",
      "  Objective (final_val_cross_entropy): 0.4593\n",
      "  Cost (time): 214.0045\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.41076841482375054\n",
      "tensor([0.2853, 0.6656, 0.0386, 0.1613, 0.0000, 0.7638, 0.3580]) tensor([0.5483, 0.8408, 0.4133, 0.2055, 0.0000, 0.8120, 0.7150])\n",
      "delta mu: 0.07415226892634025\n",
      "kappa: 0.21533916089119653\n",
      "kl: 0.00866573813906002\n",
      "ei diff: 0.1161146028608499\n",
      "exp min regret gap: 0.2044414713665612\n",
      "\n",
      "Iteration 6:\n",
      "  Selected config_id: 546\n",
      "  Acquisition value: -8.2431\n",
      "  Objective (final_val_cross_entropy): 0.4536\n",
      "  Cost (time): 230.0875\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3998668837965718\n",
      "tensor([0.6836, 0.9940, 0.0189, 0.2740, 0.0000, 0.6854, 0.3322]) tensor([0.2853, 0.6656, 0.0386, 0.1613, 0.0000, 0.7638, 0.3580])\n",
      "delta mu: 0.043662608044771734\n",
      "kappa: 0.21661723886425355\n",
      "kl: 0.0036444130785479034\n",
      "ei diff: 0.14632775198246797\n",
      "exp min regret gap: 0.1992371675184602\n",
      "\n",
      "Iteration 7:\n",
      "  Selected config_id: 1914\n",
      "  Acquisition value: -8.3460\n",
      "  Objective (final_val_cross_entropy): 0.4677\n",
      "  Cost (time): 189.5788\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3907373364889091\n",
      "tensor([0.3666, 0.9879, 0.0306, 0.0776, 0.0000, 0.2745, 0.7178]) tensor([0.6836, 0.9940, 0.0189, 0.2740, 0.0000, 0.6854, 0.3322])\n",
      "delta mu: 0.020912272481536653\n",
      "kappa: 0.21269066000727932\n",
      "kl: 0.0024273734509266953\n",
      "ei diff: 0.1244541329988062\n",
      "exp min regret gap: 0.15277612416337544\n",
      "\n",
      "Iteration 8:\n",
      "  Selected config_id: 1238\n",
      "  Acquisition value: -8.4154\n",
      "  Objective (final_val_cross_entropy): 0.4424\n",
      "  Cost (time): 227.3685\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.38191748468327086\n",
      "tensor([0.0785, 0.7111, 0.2477, 0.4590, 0.0000, 0.4752, 0.8909]) tensor([0.3666, 0.9879, 0.0306, 0.0776, 0.0000, 0.2745, 0.7178])\n",
      "delta mu: 0.04908915805126152\n",
      "kappa: 0.20962649410971812\n",
      "kl: 0.0017442468690016533\n",
      "ei diff: 0.17937542056772623\n",
      "exp min regret gap: 0.23465521291345007\n",
      "\n",
      "Iteration 9:\n",
      "  Selected config_id: 106\n",
      "  Acquisition value: -8.5452\n",
      "  Objective (final_val_cross_entropy): 0.4872\n",
      "  Cost (time): 247.3439\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3734192051130871\n",
      "tensor([0.4045, 0.3757, 0.0352, 0.0711, 0.0000, 0.3649, 0.9929]) tensor([0.0785, 0.7111, 0.2477, 0.4590, 0.0000, 0.4752, 0.8909])\n",
      "delta mu: 0.02240710183313399\n",
      "kappa: 0.21203719895337692\n",
      "kl: 0.0016082334480366223\n",
      "ei diff: 0.12318290076090348\n",
      "exp min regret gap: 0.15160273125683346\n",
      "\n",
      "Iteration 10:\n",
      "  Selected config_id: 1843\n",
      "  Acquisition value: -8.5853\n",
      "  Objective (final_val_cross_entropy): 0.4709\n",
      "  Cost (time): 222.8751\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3659596000724164\n",
      "tensor([0.7850, 0.8355, 0.0853, 0.0129, 0.0000, 0.9333, 0.4964]) tensor([0.4045, 0.3757, 0.0352, 0.0711, 0.0000, 0.3649, 0.9929])\n",
      "delta mu: 0.01354241409119028\n",
      "kappa: 0.20595049578807378\n",
      "kl: 0.001342332130157109\n",
      "ei diff: 0.1105365564077424\n",
      "exp min regret gap: 0.12941450382345265\n",
      "\n",
      "Iteration 11:\n",
      "  Selected config_id: 1488\n",
      "  Acquisition value: -8.6943\n",
      "  Objective (final_val_cross_entropy): 0.4401\n",
      "  Cost (time): 206.3288\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3589031631736165\n",
      "tensor([0.0175, 0.3225, 0.3931, 0.2925, 0.0000, 0.8018, 0.9301]) tensor([0.7850, 0.8355, 0.0853, 0.0129, 0.0000, 0.9333, 0.4964])\n",
      "delta mu: 0.03601312408505153\n",
      "kappa: 0.2193681494862676\n",
      "kl: 0.0017714177431816802\n",
      "ei diff: 0.16524977550067105\n",
      "exp min regret gap: 0.2077914846793403\n",
      "\n",
      "Iteration 12:\n",
      "  Selected config_id: 1714\n",
      "  Acquisition value: -8.7924\n",
      "  Objective (final_val_cross_entropy): 0.4704\n",
      "  Cost (time): 262.8905\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3527018999067106\n",
      "tensor([0.2644, 0.8300, 0.4610, 0.0070, 0.0000, 0.1610, 0.9549]) tensor([0.0175, 0.3225, 0.3931, 0.2925, 0.0000, 0.8018, 0.9301])\n",
      "delta mu: 0.03098322145655985\n",
      "kappa: 0.22419372537463467\n",
      "kl: 0.001209279179206879\n",
      "ei diff: 0.11340875063768757\n",
      "exp min regret gap: 0.1499047658311535\n",
      "\n",
      "Iteration 13:\n",
      "  Selected config_id: 1058\n",
      "  Acquisition value: -8.8484\n",
      "  Objective (final_val_cross_entropy): 0.4619\n",
      "  Cost (time): 223.3379\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3463831070631085\n",
      "tensor([6.1106e-01, 7.2887e-01, 6.6225e-01, 3.9360e-04, 0.0000e+00, 1.6455e-01,\n",
      "        5.7831e-01]) tensor([0.2644, 0.8300, 0.4610, 0.0070, 0.0000, 0.1610, 0.9549])\n",
      "delta mu: 0.0004557950265051125\n",
      "kappa: 0.224859738673441\n",
      "kl: 0.01117606860029563\n",
      "ei diff: 0.1102848830758253\n",
      "exp min regret gap: 0.12754965317765887\n",
      "\n",
      "Iteration 14:\n",
      "  Selected config_id: 176\n",
      "  Acquisition value: -8.9350\n",
      "  Objective (final_val_cross_entropy): 0.4478\n",
      "  Cost (time): 203.8270\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.34124971762882267\n",
      "tensor([0.9463, 0.8344, 0.3370, 0.1491, 0.0000, 0.7407, 0.1734]) tensor([6.1106e-01, 7.2887e-01, 6.6225e-01, 3.9360e-04, 0.0000e+00, 1.6455e-01,\n",
      "        5.7831e-01])\n",
      "delta mu: 0.004574729902618646\n",
      "kappa: 0.22438025851723786\n",
      "kl: 0.0049796275909471666\n",
      "ei diff: 0.11645143241801328\n",
      "exp min regret gap: 0.1322222960854754\n",
      "\n",
      "Iteration 15:\n",
      "  Selected config_id: 1581\n",
      "  Acquisition value: -9.0174\n",
      "  Objective (final_val_cross_entropy): 0.4530\n",
      "  Cost (time): 202.1708\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3378234730861439\n",
      "tensor([0.6662, 0.4755, 0.8464, 0.1513, 0.0000, 0.6672, 0.9076]) tensor([0.9463, 0.8344, 0.3370, 0.1491, 0.0000, 0.7407, 0.1734])\n",
      "delta mu: 0.0396907292206401\n",
      "kappa: 0.23235590277975948\n",
      "kl: 0.001839071045332541\n",
      "ei diff: 0.15377432740168495\n",
      "exp min regret gap: 0.20051098080031288\n",
      "\n",
      "Iteration 16:\n",
      "  Selected config_id: 123\n",
      "  Acquisition value: -9.0198\n",
      "  Objective (final_val_cross_entropy): 0.4536\n",
      "  Cost (time): 211.1778\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3344808358780559\n",
      "tensor([0.2853, 0.4420, 0.7316, 0.3687, 0.0000, 0.0323, 0.9273]) tensor([0.6662, 0.4755, 0.8464, 0.1513, 0.0000, 0.6672, 0.9076])\n",
      "delta mu: 0.02291788385497573\n",
      "kappa: 0.22995061860983723\n",
      "kl: 0.010370102773065715\n",
      "ei diff: 0.13238292946359426\n",
      "exp min regret gap: 0.1718589367107453\n",
      "\n",
      "Iteration 17:\n",
      "  Selected config_id: 1044\n",
      "  Acquisition value: -9.0151\n",
      "  Objective (final_val_cross_entropy): 0.4782\n",
      "  Cost (time): 233.9185\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.32946009057707465\n",
      "tensor([0.2089, 0.4793, 0.9687, 0.1981, 0.0000, 0.9767, 0.4130]) tensor([0.2853, 0.4420, 0.7316, 0.3687, 0.0000, 0.0323, 0.9273])\n",
      "delta mu: 0.0017314134846653229\n",
      "kappa: 0.23084215493022775\n",
      "kl: 0.0011729830796988194\n",
      "ei diff: 0.12456084513913898\n",
      "exp min regret gap: 0.13188269868820807\n",
      "\n",
      "Iteration 18:\n",
      "  Selected config_id: 1461\n",
      "  Acquisition value: -9.0723\n",
      "  Objective (final_val_cross_entropy): 0.4582\n",
      "  Cost (time): 228.1142\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3240626813733575\n",
      "tensor([0.4919, 0.7520, 0.8425, 0.6312, 0.0000, 0.8592, 0.1838]) tensor([0.2089, 0.4793, 0.9687, 0.1981, 0.0000, 0.9767, 0.4130])\n",
      "delta mu: 0.04420776629861811\n",
      "kappa: 0.22566046554514285\n",
      "kl: 0.008578122949132183\n",
      "ei diff: 0.16631739841140955\n",
      "exp min regret gap: 0.22530387770872806\n",
      "\n",
      "Iteration 19:\n",
      "  Selected config_id: 245\n",
      "  Acquisition value: -9.1203\n",
      "  Objective (final_val_cross_entropy): 0.5023\n",
      "  Cost (time): 199.8534\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3155090073505586\n",
      "tensor([0.4089, 0.3027, 0.7942, 0.7593, 0.0000, 0.8720, 0.7997]) tensor([0.4919, 0.7520, 0.8425, 0.6312, 0.0000, 0.8592, 0.1838])\n",
      "delta mu: 0.005609885779905177\n",
      "kappa: 0.22719815070753357\n",
      "kl: 0.011110422644870499\n",
      "ei diff: 0.134246645168612\n",
      "exp min regret gap: 0.15679035660908386\n",
      "\n",
      "Iteration 20:\n",
      "  Selected config_id: 1311\n",
      "  Acquisition value: -9.1592\n",
      "  Objective (final_val_cross_entropy): 0.5126\n",
      "  Cost (time): 212.9730\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.30842380400374103\n",
      "tensor([0.6869, 0.3728, 0.9578, 0.3632, 0.0000, 0.3237, 0.0064]) tensor([0.4089, 0.3027, 0.7942, 0.7593, 0.0000, 0.8720, 0.7997])\n",
      "delta mu: 0.007413307394365742\n",
      "kappa: 0.2283522519843464\n",
      "kl: 0.0008971800629705706\n",
      "ei diff: 0.12885944833328\n",
      "exp min regret gap: 0.14110924365658206\n",
      "\n",
      "Iteration 21:\n",
      "  Selected config_id: 696\n",
      "  Acquisition value: -9.2608\n",
      "  Objective (final_val_cross_entropy): 0.4777\n",
      "  Cost (time): 211.3651\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.3043602152380548\n",
      "tensor([0.2853, 0.8040, 0.9804, 0.0237, 0.0000, 0.0620, 0.0750]) tensor([0.6869, 0.3728, 0.9578, 0.3632, 0.0000, 0.3237, 0.0064])\n",
      "delta mu: 0.056088529293233824\n",
      "kappa: 0.22583924743484385\n",
      "kl: 0.004428853285321277\n",
      "ei diff: 0.10653465317235443\n",
      "exp min regret gap: 0.1732506568518874\n",
      "\n",
      "Iteration 22:\n",
      "  Selected config_id: 1574\n",
      "  Acquisition value: -9.2184\n",
      "  Objective (final_val_cross_entropy): 0.4448\n",
      "  Cost (time): 234.6392\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.30033235388946533\n",
      "tensor([0.0644, 0.5133, 0.8075, 0.7713, 0.0000, 0.0620, 0.0649]) tensor([0.2853, 0.8040, 0.9804, 0.0237, 0.0000, 0.0620, 0.0750])\n",
      "delta mu: 0.040045577101159524\n",
      "kappa: 0.22343594792788712\n",
      "kl: 0.0017820044707553828\n",
      "ei diff: 0.16199471984396296\n",
      "exp min regret gap: 0.20870978414703725\n",
      "\n",
      "Iteration 23:\n",
      "  Selected config_id: 200\n",
      "  Acquisition value: -9.2241\n",
      "  Objective (final_val_cross_entropy): 0.5134\n",
      "  Cost (time): 257.3999\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.29525776169975776\n",
      "tensor([0.1615, 0.6161, 0.3629, 0.1529, 0.3333, 0.8470, 0.9651]) tensor([0.0644, 0.5133, 0.8075, 0.7713, 0.0000, 0.0620, 0.0649])\n",
      "delta mu: 0.17644648999199314\n",
      "kappa: 0.2221284323156003\n",
      "kl: 9.801600519410147e-05\n",
      "ei diff: 0.039013988855026334\n",
      "exp min regret gap: 0.21701550483978327\n",
      "\n",
      "Iteration 24:\n",
      "  Selected config_id: 322\n",
      "  Acquisition value: -9.3065\n",
      "  Objective (final_val_cross_entropy): 0.4509\n",
      "  Cost (time): 837.7438\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.29173535118808425\n",
      "tensor([0.0340, 0.8739, 0.4443, 0.0080, 0.0000, 0.7738, 0.1854]) tensor([0.1615, 0.6161, 0.3629, 0.1529, 0.3333, 0.8470, 0.9651])\n",
      "delta mu: 0.09527870639077785\n",
      "kappa: 0.22253416660729153\n",
      "kl: 4.9190800482268315e-05\n",
      "ei diff: 0.08303029039336789\n",
      "exp min regret gap: 0.17941262716330986\n",
      "\n",
      "Iteration 25:\n",
      "  Selected config_id: 1863\n",
      "  Acquisition value: -9.3983\n",
      "  Objective (final_val_cross_entropy): 0.4674\n",
      "  Cost (time): 256.3366\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.2877217798288093\n",
      "tensor([0.2496, 0.8723, 0.2791, 0.0597, 1.0000, 0.3214, 0.2219]) tensor([0.0340, 0.8739, 0.4443, 0.0080, 0.0000, 0.7738, 0.1854])\n",
      "delta mu: 0.1989432607672399\n",
      "kappa: 0.22117709606076374\n",
      "kl: 0.001709839756890874\n",
      "ei diff: 0.013521197465343915\n",
      "exp min regret gap: 0.21893145820913476\n",
      "\n",
      "Iteration 26:\n",
      "  Selected config_id: 1400\n",
      "  Acquisition value: -9.3768\n",
      "  Objective (final_val_cross_entropy): 0.3096\n",
      "  Cost (time): 1449.1275\n",
      "  Current best observed: 0.3268\n",
      "\n",
      "sigma: 0.2833401168800212\n",
      "tensor([0.9640, 0.9877, 0.7216, 0.9534, 0.0000, 0.2237, 0.9772]) tensor([0.2496, 0.8723, 0.2791, 0.0597, 1.0000, 0.3214, 0.2219])\n",
      "delta mu: 0.34865623815572616\n",
      "kappa: 0.2172119423993677\n",
      "kl: 0.00011972448860908536\n",
      "ei diff: 0.32725995482323383\n",
      "exp min regret gap: 0.6775967768715225\n",
      "\n",
      "Iteration 27:\n",
      "  Selected config_id: 202\n",
      "  Acquisition value: -9.6276\n",
      "  Objective (final_val_cross_entropy): 0.5262\n",
      "  Cost (time): 209.6998\n",
      "  Current best observed: 0.3096\n",
      "\n",
      "sigma: 0.2796254869785077\n",
      "tensor([0.7802, 0.9801, 0.1104, 0.9641, 0.0000, 0.9699, 0.6928]) tensor([0.9640, 0.9877, 0.7216, 0.9534, 0.0000, 0.2237, 0.9772])\n",
      "delta mu: 0.04964717687917031\n",
      "kappa: 0.2138885225435505\n",
      "kl: 0.024770820743324706\n",
      "ei diff: 0.11497588265043886\n",
      "exp min regret gap: 0.18842666157311327\n",
      "\n",
      "Iteration 28:\n",
      "  Selected config_id: 155\n",
      "  Acquisition value: -9.6350\n",
      "  Objective (final_val_cross_entropy): 0.5269\n",
      "  Cost (time): 210.8690\n",
      "  Current best observed: 0.3096\n",
      "\n",
      "sigma: 0.2750978444966516\n",
      "tensor([0.9063, 0.9425, 0.6751, 0.9518, 0.0000, 0.0981, 0.1125]) tensor([0.7802, 0.9801, 0.1104, 0.9641, 0.0000, 0.9699, 0.6928])\n",
      "delta mu: 0.004264598749018411\n",
      "kappa: 0.2111746387257809\n",
      "kl: 0.0011782282588216342\n",
      "ei diff: 0.10740783508519473\n",
      "exp min regret gap: 0.11679799569474753\n",
      "\n",
      "Iteration 29:\n",
      "  Selected config_id: 180\n",
      "  Acquisition value: -9.7282\n",
      "  Objective (final_val_cross_entropy): 0.5261\n",
      "  Cost (time): 194.0689\n",
      "  Current best observed: 0.3096\n",
      "\n",
      "sigma: 0.2708253369837333\n",
      "tensor([0.9516, 0.4426, 0.7568, 0.9372, 0.0000, 0.2207, 0.6935]) tensor([0.9063, 0.9425, 0.6751, 0.9518, 0.0000, 0.0981, 0.1125])\n",
      "delta mu: 0.013626832963079427\n",
      "kappa: 0.21034124882367955\n",
      "kl: 0.001988948960083836\n",
      "ei diff: 0.0941934625959752\n",
      "exp min regret gap: 0.11445346772180487\n",
      "\n",
      "Iteration 30:\n",
      "  Selected config_id: 1753\n",
      "  Acquisition value: -9.8481\n",
      "  Objective (final_val_cross_entropy): 0.5279\n",
      "  Cost (time): 198.0716\n",
      "  Current best observed: 0.3096\n",
      "\n"
     ]
    }
   ],
   "source": [
    "dim = 7\n",
    "n_iter = 30\n",
    "maximize = False\n",
    "output_standardize = True\n",
    "acq = \"LogEIC-inv\"\n",
    "seed = 13\n",
    "\n",
    "# Sample initial configurations\n",
    "torch.manual_seed(seed)\n",
    "init_config_id = torch.randint(low=0, high=2000, size=(2*(dim+1),))\n",
    "config_id_history = init_config_id.tolist()\n",
    "print(f\"  Initial config id: {config_id_history}\")\n",
    "x = all_x[init_config_id]\n",
    "y = all_y[init_config_id]\n",
    "c = all_c[init_config_id]\n",
    "best_y_history = [y.min().item()]\n",
    "best_id_history = [config_id_history[y.argmin().item()]]\n",
    "cost_history = [0]\n",
    "\n",
    "old_model = fit_gp_model(X=x[:-1], objective_X=y[:-1], output_standardize=output_standardize)\n",
    "old_config_x = x[-1]\n",
    "\n",
    "acq_history = {\n",
    "    'exp min regret gap': [np.nan],\n",
    "    'delta_mu': [np.nan],\n",
    "    'kappa': [np.nan],\n",
    "    'kl': [np.nan],\n",
    "    'ei_diff': [np.nan]\n",
    "}\n",
    "\n",
    "for i in range(n_iter):\n",
    "    # 1. Fit a GP model on the current data.\n",
    "    model = fit_gp_model(X=x, objective_X=y, cost_X=c, unknown_cost=True, output_standardize=output_standardize)\n",
    "    \n",
    "    # 2. Determine the best observed objective value.\n",
    "    best_f = y.min()\n",
    "        \n",
    "    # 3. Define the acquisition function.\n",
    "    StablePBGI_1e_5 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-5, unknown_cost=True)\n",
    "    StablePBGI_1e_6 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-6, unknown_cost=True)\n",
    "    StablePBGI_1e_7 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-7, unknown_cost=True)\n",
    "    LogEIC_inv = LogExpectedImprovementWithCost(model=model, best_f=best_f, maximize=maximize, unknown_cost=True, inverse_cost=True)\n",
    "    LogEIC_exp = LogExpectedImprovementWithCost(model=model, best_f=best_f, maximize=maximize, unknown_cost=True, inverse_cost=False)\n",
    "    single_outcome_model = fit_gp_model(X=x, objective_X=y, output_standardize=output_standardize)\n",
    "    print(\"sigma:\", single_outcome_model.posterior(all_x.unsqueeze(1)).variance.sqrt().squeeze(-1).max().item())\n",
    "    beta = 2 * np.log(dim * ((i + 1) ** 2) * (math.pi ** 2) / (6 * 0.1)) / 5\n",
    "    UCB = UpperConfidenceBound(model=single_outcome_model, maximize=maximize, beta=beta)\n",
    "    LCB = LowerConfidenceBound(model=single_outcome_model, maximize=maximize, beta=beta)\n",
    "\n",
    "    # 4. Evaluate the acquisition function on all candidate x's.\n",
    "    StablePBGI_1e_5_acq = StablePBGI_1e_5.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_5_acq[config_id_history] = y.squeeze(-1)\n",
    "    StablePBGI_1e_6_acq = StablePBGI_1e_6.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_6_acq[config_id_history] = y.squeeze(-1)\n",
    "    StablePBGI_1e_7_acq = StablePBGI_1e_7.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_7_acq[config_id_history] = y.squeeze(-1)\n",
    "    LogEIC_inv_acq = LogEIC_inv.forward(all_x.unsqueeze(1))\n",
    "    LogEIC_exp_acq = LogEIC_exp.forward(all_x.unsqueeze(1))\n",
    "    UCB_acq = UCB.forward(all_x.unsqueeze(1))\n",
    "    LCB_acq = LCB.forward(all_x.unsqueeze(1))\n",
    "\n",
    "    # 5. Select the candidate with the optimal acquisition value.\n",
    "    num_configs = 2000\n",
    "    all_ids = torch.arange(num_configs)\n",
    "    mask = torch.ones(num_configs, dtype=torch.bool)\n",
    "    mask[config_id_history] = False\n",
    "    candidate_ids = all_ids[mask]\n",
    "    \n",
    "    if acq == \"StablePBGI(1e-5)\":\n",
    "        candidate_acqs = StablePBGI_1e_5_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"StablePBGI(1e-6)\":\n",
    "        candidate_acqs = StablePBGI_1e_6_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"StablePBGI(1e-7)\":\n",
    "        candidate_acqs = StablePBGI_1e_7_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"LogEIC-inv\":\n",
    "        candidate_acqs = LogEIC_inv_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmax(candidate_acqs)]\n",
    "        new_config_acq = torch.max(candidate_acqs)\n",
    "    if acq == \"LogEIC-exp\":\n",
    "        candidate_acqs = LogEIC_exp_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmax(candidate_acqs)]\n",
    "        new_config_acq = torch.max(candidate_acqs)\n",
    "    if acq == \"LCB\":\n",
    "        candidate_acqs = LCB_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "\n",
    "    new_config_x = all_x[new_config_id]\n",
    "    \n",
    "    # 6. Query the objective for the new configuration.\n",
    "    new_config_y = all_y[new_config_id]\n",
    "    new_config_c = all_c[new_config_id]\n",
    "\n",
    "    # 7. Record information for stopping.\n",
    "\n",
    "    # 7.1. Get the posterior mean for old and new GPs at the new and old best points.\n",
    "    # new_config_x and old_config_x should be the configurations corresponding to the current\n",
    "    # and previous best indices, respectively.\n",
    "    x_pair = torch.stack([new_config_x, old_config_x])\n",
    "    print(new_config_x, old_config_x)\n",
    "\n",
    "    # 7.2. Get posterior mean and covariance from the new model.\n",
    "    new_posterior = single_outcome_model.posterior(x_pair)\n",
    "    new_mean = new_posterior.mean         # Shape: [2]\n",
    "    new_covar = new_posterior.mvn.covariance_matrix     # Shape: [2, 2]\n",
    "\n",
    "    # 7.3. Get posterior mean and covariance from the old model.\n",
    "    old_posterior = old_model.posterior(x_pair)\n",
    "    old_mean = old_posterior.mean           # Shape: [2]\n",
    "    old_covar = old_posterior.mvn.covariance_matrix       # Shape: [2, 2]\n",
    "\n",
    "    # 7.4. Compute delta_mu (the absolute change in best posterior mean)\n",
    "    # Here, we assume that new_config_x corresponds to the current best (new point)\n",
    "    # and old_config_x corresponds to the previous best.\n",
    "    delta_mu = abs(old_mean[1].item() - new_mean[0].item())\n",
    "    acq_history['delta_mu'].append(delta_mu)\n",
    "\n",
    "    # 7.5. Compute κ_{t−1} = UCB - LCB gap.\n",
    "    kappa = torch.min(UCB_acq[~mask]) - torch.min(LCB_acq)\n",
    "    acq_history['kappa'].append(kappa)\n",
    "\n",
    "    # 7.6. Compute KL divergence between old and new posteriors at the new point.\n",
    "    old_var = old_covar[0, 0].clamp(min=1e-12)\n",
    "    new_var = new_covar[0, 0].clamp(min=1e-12)\n",
    "    old_mu_val = old_mean[0]\n",
    "    new_mu_val = new_mean[0]\n",
    "    kl = 0.5 * (torch.log(new_var / old_var) +\n",
    "                (old_var + (old_mu_val - new_mu_val).pow(2)) / new_var - 1).item()\n",
    "    acq_history['kl'].append(kl)\n",
    "\n",
    "    # 7.7. Compute ei_diff, the expected-improvement gap difference.\n",
    "    # If new_config_x and old_config_x are (approximately) equal, we set ei_diff to zero.\n",
    "    if not torch.allclose(new_config_x, old_config_x, atol=1e-6):\n",
    "        # We use the new model's posterior for these two points.\n",
    "        # new_mean and new_covar already contain the predictions.\n",
    "        # Compute the difference in means:\n",
    "        g = (new_mean[0] - new_mean[1]).item()\n",
    "        # Compute the effective variance difference: sigma_new[0,0] - 2*sigma_new[0,1] + sigma_new[1,1]\n",
    "        diff_var = (new_covar[0, 0] - 2 * new_covar[0, 1] + new_covar[1, 1]).item()\n",
    "        if diff_var < 0:\n",
    "            beta_val = 0.0\n",
    "            pdf_val = np.sqrt(1.0 / (2 * np.pi))\n",
    "            cdf_val = 1.0\n",
    "        else:\n",
    "            beta_val = np.sqrt(diff_var)\n",
    "            u = g / beta_val if beta_val > 0 else 0.0\n",
    "            pdf_val = norm.pdf(u)\n",
    "            cdf_val = norm.cdf(u)\n",
    "        ei_diff = beta_val * pdf_val + g * cdf_val\n",
    "    else:\n",
    "        ei_diff = 0.0\n",
    "    acq_history['ei_diff'].append(ei_diff)\n",
    "\n",
    "    print(\"delta mu:\", delta_mu)\n",
    "    print(\"kappa:\", kappa.item())\n",
    "    print(\"kl:\", kl)\n",
    "    print(\"ei diff:\", ei_diff)\n",
    "\n",
    "    # 7.8. Final expression for ΔR̃_t (the expected minimal regret gap).\n",
    "    exp_min_regret_gap = delta_mu + ei_diff + kappa.item() * np.sqrt(0.5 * kl)\n",
    "    print(\"exp min regret gap:\", exp_min_regret_gap)\n",
    "    print()\n",
    "    acq_history['exp min regret gap'].append(exp_min_regret_gap)\n",
    "\n",
    "    # 7.9. Reassign old_model and old_config_x for the next iteration.\n",
    "    old_model = single_outcome_model\n",
    "    old_config_x = new_config_x\n",
    "    \n",
    "    # 8. Append the new data to our training set.\n",
    "    x = torch.cat([x, new_config_x.unsqueeze(0)], dim=0)\n",
    "    y = torch.cat([y, new_config_y.unsqueeze(0)], dim=0)\n",
    "    c = torch.cat([c, new_config_c.unsqueeze(0)], dim=0)\n",
    "    config_id_history.append(new_config_id.item())\n",
    "    best_y_history.append(best_f.item())\n",
    "    best_id_history.append(config_id_history[y.argmin().item()])\n",
    "    cost_history.append(new_config_c.item())\n",
    "\n",
    "    print(f\"Iteration {i + 1}:\")\n",
    "    print(f\"  Selected config_id: {new_config_id}\")\n",
    "    print(f\"  Acquisition value: {new_config_acq.item():.4f}\")\n",
    "    print(f\"  Objective (final_val_cross_entropy): {new_config_y.item():.4f}\")\n",
    "    print(f\"  Cost (time): {new_config_c.item():.4f}\")\n",
    "    print(f\"  Current best observed: {best_f.item():.4f}\")\n",
    "    print()\n",
    "\n",
    "best_y_history.append(y.min().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cost_function(tensor):\n",
    "    costs = []\n",
    "    if tensor.dim() == 1:\n",
    "        tensor = tensor.unsqueeze(0)\n",
    "    for x in tensor:\n",
    "        model_param = bench.query(dataset_name, \"model_parameters\", x2id[x.numpy().tobytes()])\n",
    "        costs.append(0.001*model_param)\n",
    "    return torch.tensor(costs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 7\n",
    "n_iter = 50\n",
    "maximize = False\n",
    "output_standardize = True\n",
    "acq = \"LogEIC\"\n",
    "seed = 5\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "init_config_id = torch.randint(low=0, high=2000, size=(2*(dim+1),))\n",
    "config_id_history = init_config_id.tolist()\n",
    "print(f\"  Initial config id: {config_id_history}\")\n",
    "x = all_x[init_config_id]\n",
    "y = all_y[init_config_id]\n",
    "c = all_c[init_config_id]\n",
    "best_y_history = [y.min().item()]\n",
    "best_id_history = [config_id_history[y.argmin().item()]]\n",
    "cost_history = [0]\n",
    "LogEIC_history = []\n",
    "\n",
    "# Instead of several separate lists, we initialize a dictionary to store all acquisition histories.\n",
    "acq_history = {\n",
    "    'StablePBGI(1e-5)': [np.nan],\n",
    "    'StablePBGI(1e-6)': [np.nan],\n",
    "    'StablePBGI(1e-7)': [np.nan],\n",
    "    'LogEIC': [np.nan],\n",
    "    'regret upper bound': [np.nan]\n",
    "}\n",
    "\n",
    "for i in range(n_iter):\n",
    "    # 1. Fit a GP model on the current data.\n",
    "    model = fit_gp_model(X=x, objective_X=y, output_standardize=output_standardize)\n",
    "    \n",
    "    # 2. Determine the best observed objective value.\n",
    "    best_f = y.min()\n",
    "        \n",
    "    # 3. Define the acquisition function.\n",
    "    StablePBGI_1e_5 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-5, cost=cost_function)\n",
    "    StablePBGI_1e_6 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-6, cost=cost_function)\n",
    "    StablePBGI_1e_7 = StableGittinsIndex(model=model, maximize=maximize, lmbda=1e-7, cost=cost_function)\n",
    "    LogEIC = LogExpectedImprovementWithCost(model=model, best_f=best_f, maximize=maximize, cost=cost_function)\n",
    "    UCB = UpperConfidenceBound(model=model, maximize=maximize, beta=2 * np.log(dim * ((i + 1) ** 2) * (math.pi ** 2) / (6 * 0.1)) / 5)\n",
    "    LCB = LowerConfidenceBound(model=model, maximize=maximize, beta=2 * np.log(dim * ((i + 1) ** 2) * (math.pi ** 2) / (6 * 0.1)) / 5)\n",
    "\n",
    "    # 4. Evaluate the acquisition function on all candidate x's.\n",
    "    StablePBGI_1e_5_acq = StablePBGI_1e_5.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_5_acq[config_id_history] = y.squeeze(-1)\n",
    "    StablePBGI_1e_6_acq = StablePBGI_1e_6.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_6_acq[config_id_history] = y.squeeze(-1)\n",
    "    StablePBGI_1e_7_acq = StablePBGI_1e_7.forward(all_x.unsqueeze(1))\n",
    "    StablePBGI_1e_7_acq[config_id_history] = y.squeeze(-1)\n",
    "    LogEIC_acq = LogEIC.forward(all_x.unsqueeze(1))\n",
    "    UCB_acq = UCB.forward(all_x.unsqueeze(1))\n",
    "    LCB_acq = LCB.forward(all_x.unsqueeze(1))\n",
    "\n",
    "    # 5. Record information for stopping.\n",
    "    all_ids = torch.arange(num_configs)\n",
    "    mask = torch.ones(num_configs, dtype=torch.bool)\n",
    "    mask[config_id_history] = False\n",
    "\n",
    "    acq_history['StablePBGI(1e-5)'].append(torch.min(StablePBGI_1e_5_acq[mask]).item())\n",
    "    acq_history['StablePBGI(1e-6)'].append(torch.min(StablePBGI_1e_6_acq[mask]).item())\n",
    "    acq_history['StablePBGI(1e-7)'].append(torch.min(StablePBGI_1e_7_acq[mask]).item())\n",
    "    acq_history['LogEIC'].append(torch.max(LogEIC_acq[mask]).item())\n",
    "    acq_history['regret upper bound'].append(torch.min(UCB_acq[~mask]).item() - torch.min(LCB_acq).item())\n",
    "\n",
    "    # 6. Select the candidate with the optimal acquisition value.\n",
    "    candidate_ids = all_ids[mask]\n",
    "    \n",
    "    if acq == \"StablePBGI(1e-5)\":\n",
    "        candidate_acqs = StablePBGI_1e_5_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"StablePBGI(1e-6)\":\n",
    "        candidate_acqs = StablePBGI_1e_6_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"StablePBGI(1e-7)\":\n",
    "        candidate_acqs = StablePBGI_1e_7_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "    if acq == \"LogEIC\":\n",
    "        candidate_acqs = LogEIC_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmax(candidate_acqs)]\n",
    "        new_config_acq = torch.max(candidate_acqs)\n",
    "    if acq == \"LCB\":\n",
    "        candidate_acqs = LCB_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "\n",
    "    new_config_x = all_x[new_config_id]\n",
    "    \n",
    "    # 7. Query the objective for the new configuration.\n",
    "    new_config_y = all_y[new_config_id]\n",
    "    new_config_c = all_c[new_config_id]\n",
    "    \n",
    "    # 8. Append the new data to our training set.\n",
    "    x = torch.cat([x, new_config_x.unsqueeze(0)], dim=0)\n",
    "    y = torch.cat([y, new_config_y.unsqueeze(0)], dim=0)\n",
    "    c = torch.cat([c, new_config_c.unsqueeze(0)], dim=0)\n",
    "    config_id_history.append(new_config_id.item())\n",
    "    best_y_history.append(best_f.item())\n",
    "    best_id_history.append(config_id_history[y.argmin().item()])\n",
    "    cost_history.append(new_config_c.item())\n",
    "\n",
    "    print(f\"Iteration {i + 1}:\")\n",
    "    print(f\"  Selected config_id: {new_config_id}\")\n",
    "    print(f\"  Acquisition value: {new_config_acq.item():.4f}\")\n",
    "    print(f\"  Objective (final_val_cross_entropy): {new_config_y.item():.4f}\")\n",
    "    print(f\"  Cost (time): {new_config_c.item():.4f}\")\n",
    "    print(f\"  Current best observed: {best_f.item():.4f}\")\n",
    "    print()\n",
    "\n",
    "best_y_history.append(y.min().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.sampling.pathwise import draw_matheron_paths\n",
    "import math\n",
    "\n",
    "dim = 7\n",
    "bounds = torch.stack([torch.zeros(dim), torch.ones(dim)])\n",
    "n_iter = 50\n",
    "maximize = False\n",
    "output_standardize = True\n",
    "acq = \"TS\"\n",
    "seed = 5\n",
    "num_configs = 2000\n",
    "epsilon = 0.005\n",
    "num_samples = 64\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "init_config_id = torch.randint(low=0, high=num_configs, size=(2*(dim+1),))\n",
    "config_id_history = init_config_id.tolist()\n",
    "print(f\"  Initial config id: {config_id_history}\")\n",
    "x = all_x[init_config_id]\n",
    "y = all_y[init_config_id]\n",
    "c = all_c[init_config_id]\n",
    "best_y_history = [y.min().item()]\n",
    "best_id_history = [config_id_history[y.argmin().item()]]\n",
    "cost_history = [0]\n",
    "LogEIC_history = []\n",
    "\n",
    "# Instead of several separate lists, we initialize a dictionary to store all acquisition histories.\n",
    "acq_history = {\n",
    "    'PRB': [np.nan]\n",
    "}\n",
    "\n",
    "# Independent seed for Thompson sampling\n",
    "ts_seed  = seed + 1\n",
    "\n",
    "for i in range(n_iter):\n",
    "    # 1. Fit a GP model on the current data.\n",
    "    model = fit_gp_model(X=x, objective_X=y, output_standardize=output_standardize)\n",
    "    \n",
    "    # 2. Determine the best observed objective value.\n",
    "    best_f = y.min()\n",
    "        \n",
    "    # 3. Define the acquisition function.\n",
    "    LCB = LowerConfidenceBound(model=model, maximize=maximize, beta=2 * np.log(dim * ((i + 1) ** 2) * (math.pi ** 2) / (6 * 0.1)) / 5)\n",
    "\n",
    "    # 4. Evaluate the acquisition function on all candidate x's.\n",
    "    LCB_acq = LCB.forward(all_x.unsqueeze(1))\n",
    "\n",
    "    # 5. Record information for stopping.\n",
    "    all_ids = torch.arange(num_configs)\n",
    "    mask = torch.ones(num_configs, dtype=torch.bool)\n",
    "    mask[config_id_history] = False\n",
    "\n",
    "    # Probabilistic regret bound\n",
    "    paths = draw_matheron_paths(model, sample_shape=torch.Size([num_samples]))\n",
    "    best_x = all_x[config_id_history[y.argmin().item()]]\n",
    "    regrets = paths(best_x.unsqueeze(0)).squeeze(-1) - paths(all_x).min(dim=1).values\n",
    "    prb_estimate = (regrets <= epsilon).float().mean().item()\n",
    "    acq_history['PRB'].append(prb_estimate)\n",
    "    num_samples = min(math.ceil(num_samples * 1.5), 1000)\n",
    "\n",
    "    # 6. Select the candidate with the optimal acquisition value.\n",
    "    candidate_ids = all_ids[mask]\n",
    "    \n",
    "    if acq == \"LCB\":\n",
    "        candidate_acqs = LCB_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "\n",
    "    if acq == \"TS\":\n",
    "        prev_state = torch.get_rng_state()\n",
    "        torch.manual_seed(ts_seed)\n",
    "        sample_path = draw_matheron_paths(model, sample_shape=torch.Size([1]))\n",
    "        torch.set_rng_state(prev_state)\n",
    "        TS_acq = sample_path(all_x).squeeze()\n",
    "        candidate_acqs = TS_acq[mask]\n",
    "        new_config_id = candidate_ids[torch.argmin(candidate_acqs)]\n",
    "        new_config_acq = torch.min(candidate_acqs)\n",
    "\n",
    "    new_config_x = all_x[new_config_id]\n",
    "    \n",
    "    # 7. Query the objective for the new configuration.\n",
    "    new_config_y = all_y[new_config_id]\n",
    "    new_config_c = all_c[new_config_id]\n",
    "    \n",
    "    # 8. Append the new data to our training set.\n",
    "    x = torch.cat([x, new_config_x.unsqueeze(0)], dim=0)\n",
    "    y = torch.cat([y, new_config_y.unsqueeze(0)], dim=0)\n",
    "    c = torch.cat([c, new_config_c.unsqueeze(0)], dim=0)\n",
    "    config_id_history.append(new_config_id.item())\n",
    "    best_y_history.append(best_f.item())\n",
    "    best_id_history.append(config_id_history[y.argmin().item()])\n",
    "    cost_history.append(new_config_c.item())\n",
    "\n",
    "    print(f\"Iteration {i + 1}:\")\n",
    "    print(f\"  Selected config_id: {new_config_id}\")\n",
    "    print(f\"  Acquisition value: {new_config_acq.item():.4f}\")\n",
    "    print(f\"  Objective (final_val_cross_entropy): {new_config_y.item():.4f}\")\n",
    "    print(f\"  Cost (time): {new_config_c.item():.4f}\")\n",
    "    print(f\"  Current best observed: {best_f.item():.4f}\")\n",
    "    print()\n",
    "\n",
    "best_y_history.append(y.min().item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "automl_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
