{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b5831947-283e-4682-aae4-bd19bcce03e0",
   "metadata": {},
   "source": [
    "# Tutorial on Feasible trust Region Bayesian Optimization (FuRBO) without restarts\r\n",
    "\r\n",
    "This tutorial shows how to implement Feasible trust Region Bayesian Optimization (FuRBO) without restarts in a closed loop using BoTorch.\r\n",
    "\r\n",
    "In this tutorial, we optimize the 10D Ackley function on the domain $[-5,10]^{10}$ subject to two constraint functions $c_1$ and $c_2$. The problem maximizes the negated Ackley function (i.e., it minimizes Ackley), and a point is feasible when $c_1(x) \\leq 0$ and $c_2(x) \\leq 0$.\r\n",
    "\r\n",
    "Since FuRBO is based on Scalable Constrained Bayesian Optimization (SCBO), this tutorial shares part of its code with the SCBO tutorial (https://botorch.org/docs/tutorials/scalable_constrained_bo/).\n",
    "\n",
    "*Note that although the code supports several trust regions, this feature has not been tested yet.*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "762be478-50e4-4af4-aa6d-d3b566649c5e",
   "metadata": {},
   "source": [
    "### Objective function\r\n",
    "\r\n",
    "Start by defining the 10D Ackley function for evaluation during the optimization loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "890f1a54-b6cf-4af4-9bfb-1835b2f737a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from botorch.test_functions import Ackley\n",
    "from botorch.utils.transforms import unnormalize\n",
    "\n",
    "class ack():\n",
    "    \n",
    "    def __init__(self, dim, negate, **tkwargs):\n",
    "        \n",
    "        self.fun = Ackley(dim = dim, negate = negate).to(**tkwargs)\n",
    "        self.fun.bounds[0, :].fill_(-5)\n",
    "        self.fun.bounds[1, :].fill_(10)\n",
    "        self.dim = self.fun.dim\n",
    "        self.lb, self.ub = self.fun.bounds\n",
    "        \n",
    "    def eval_(self, x):\n",
    "        \"\"\"This is a helper function we use to unnormalize and evaluate a point\"\"\"\n",
    "        return self.fun(unnormalize(x, [self.lb, self.ub]))"
   ]
  },
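  {
   "cell_type": "markdown",
   "id": "furbo-sketch-ackley-md",
   "metadata": {},
   "source": [
    "The optimizer works in the unit cube $[0,1]^{10}$, and `eval_` maps each point back to $[-5,10]^{10}$ before calling Ackley. The following is a minimal, dependency-free sketch of that mapping. It assumes the textbook Ackley parameters $a=20$, $b=0.2$, $c=2\\pi$, and the helper names are hypothetical and not part of BoTorch. Under these assumptions, the normalized point $x_i = 1/3$ lands on the origin, where Ackley attains its global minimum of $0$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "furbo-sketch-ackley-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def unnormalize_point(x, lb, ub):\n",
    "    '''Hypothetical helper: map a point from [0, 1]^d to the box [lb, ub]^d.'''\n",
    "    return [lb + (ub - lb) * xi for xi in x]\n",
    "\n",
    "def ackley(x, a=20.0, b=0.2, c=2.0 * math.pi):\n",
    "    '''Textbook Ackley function (assumed parameters); global minimum 0 at the origin.'''\n",
    "    d = len(x)\n",
    "    mean_sq = sum(xi ** 2 for xi in x) / d\n",
    "    mean_cos = sum(math.cos(c * xi) for xi in x) / d\n",
    "    return -a * math.exp(-b * math.sqrt(mean_sq)) - math.exp(mean_cos) + a + math.e\n",
    "\n",
    "# -5 + 15 * (1/3) = 0, so this normalized point maps to the origin\n",
    "x_box = unnormalize_point([1.0 / 3.0] * 10, lb=-5.0, ub=10.0)\n",
    "print(ackley(x_box))   # approximately 0\n",
    "print(-ackley(x_box))  # with negate=True, the optimizer maximizes -Ackley"
   ]
  },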
  {
   "cell_type": "markdown",
   "id": "6b710672-51d0-4fc5-a3e2-ffb3da6f5649",
   "metadata": {},
   "source": [
    "### Constraint functions\r\n",
    "\r\n",
    "Define two constraint functions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64122b23-fc01-4e1a-94de-2fba4839fe72",
   "metadata": {},
   "source": [
    "\n",
    "1. Enforce sum(x) $\\leq 0$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "228d816a-6452-4078-b3fd-6a42569237c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class sum_():\n",
    "    def __init__(self, threshold, lb, ub):\n",
    "        \n",
    "        self.lb = lb\n",
    "        self.ub = ub\n",
    "        self.threshold = threshold\n",
    "        return \n",
    "    \n",
    "    def c(self, x):\n",
    "        return x.sum() - self.threshold\n",
    "    \n",
    "    def eval_(self, x):\n",
    "        \"\"\"This is a helper function we use to unnormalize and evaluate a point\"\"\"\n",
    "        return self.c(unnormalize(x, [self.lb, self.ub]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08472c8a-67de-4206-bf0f-871e40edfe7c",
   "metadata": {},
   "source": [
    "2. Enforce l2norm(x) $\\leq 0.5$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e0f4e8d-657f-49d8-bb59-eb8c2ff4a9f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class norm_():\n",
    "    def __init__(self, threshold, lb, ub):\n",
    "        \n",
    "        self.lb = lb\n",
    "        self.ub = ub\n",
    "        self.threshold = threshold\n",
    "        return \n",
    "    \n",
    "    def c(self, x):\n",
    "        return torch.norm(x, p=2) - self.threshold\n",
    "    \n",
    "    def eval_(self, x):\n",
    "        \"\"\"This is a helper function we use to unnormalize and evaluate a point\"\"\"\n",
    "        return self.c(unnormalize(x, [self.lb, self.ub]))"
   ]
  },
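  {
   "cell_type": "markdown",
   "id": "furbo-sketch-constraints-md",
   "metadata": {},
   "source": [
    "Both constraints follow the SCBO convention that a point is feasible when every $c_j(x) \\leq 0$. When no feasible sample exists yet, `get_best_index_for_batch` (defined below) selects the sample with the smallest total violation $\\sum_j \\max(c_j(x), 0)$. The following stdlib-only sketch uses hypothetical helper names and evaluates the two constraints on points given directly in the original domain, that is, after `unnormalize`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "furbo-sketch-constraints-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def c_sum(x, threshold=0.0):\n",
    "    '''c1(x) = sum(x) - threshold, feasible when c1(x) <= 0.'''\n",
    "    return sum(x) - threshold\n",
    "\n",
    "def c_norm(x, threshold=0.5):\n",
    "    '''c2(x) = ||x||_2 - threshold, feasible when c2(x) <= 0.'''\n",
    "    return math.sqrt(sum(xi ** 2 for xi in x)) - threshold\n",
    "\n",
    "def total_violation(x):\n",
    "    '''Sum of the positive parts of the constraint values; 0 means feasible.'''\n",
    "    return sum(max(c(x), 0.0) for c in (c_sum, c_norm))\n",
    "\n",
    "for point in ([0.0, 0.0], [-0.3, 0.1], [1.0, 1.0]):\n",
    "    feasible = c_sum(point) <= 0 and c_norm(point) <= 0\n",
    "    print(point, 'feasible' if feasible else 'infeasible', f'total violation = {total_violation(point):.3f}')"
   ]
  },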
  {
   "cell_type": "markdown",
   "id": "e3c6e1cd-15de-4985-ac6c-bf96c6eafc24",
   "metadata": {},
   "source": [
    "### Define FuRBO Class\r\n",
    "Define a class to hold the information needed for the optimization. The state is updated with the evaluated samples, which are then used to update the trust region at each iteration. Before the class, we define two utility functions: one to identify the current best sample and one to fit a GPR. The class also features a function to reset the status when restarting. Note that before resetting, the previously evaluated samples need to be extracted and saved outside the state class (see the main optimization loop)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "020338e2-9eaf-49cf-9e5c-fd00e1a46835",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gpytorch\n",
    "import numpy as np\n",
    "\n",
    "from botorch.fit import fit_gpytorch_mll\n",
    "from botorch.models import SingleTaskGP\n",
    "from botorch.models.transforms.outcome import Standardize\n",
    "\n",
    "from gpytorch.constraints import Interval\n",
    "from gpytorch.kernels import MaternKernel, ScaleKernel\n",
    "from gpytorch.likelihoods import GaussianLikelihood\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "\n",
    "from torch import Tensor\n",
    "\n",
    "def get_best_index_for_batch(n_tr, Y: Tensor, C: Tensor):\n",
    "    \"\"\"Return the index for the best point. One for each trust region.\n",
    "    For reference, see https://botorch.org/docs/tutorials/scalable_constrained_bo/\"\"\"\n",
    "    is_feas = (C <= 0).all(dim=-1)\n",
    "    if is_feas.any():  # Choose best feasible candidate\n",
    "        score = Y.clone()\n",
    "        score[~is_feas] = -float(\"inf\")\n",
    "        return torch.topk(score.reshape(-1), k=n_tr).indices\n",
    "    return torch.topk(C.clamp(min=0).sum(dim=-1), k=n_tr, largest=False).indices # Return smallest violation\n",
    "\n",
    "def get_fitted_model(X,\n",
    "                     Y,\n",
    "                     dim,\n",
    "                     max_cholesky_size):\n",
    "    '''Function to fit a GPR to a given set of data.\n",
    "    For reference, see https://botorch.org/docs/tutorials/scalable_constrained_bo/'''\n",
    "    likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))\n",
    "    covar_module = ScaleKernel(  # Use the same lengthscale prior as in the TuRBO paper\n",
    "        MaternKernel(nu=2.5, ard_num_dims=dim, lengthscale_constraint=Interval(0.005, 4.0))\n",
    "    )\n",
    "    model = SingleTaskGP(\n",
    "        X,\n",
    "        Y,\n",
    "        covar_module=covar_module,\n",
    "        likelihood=likelihood,\n",
    "        outcome_transform=Standardize(m=1),\n",
    "    )\n",
    "    mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "\n",
    "    with gpytorch.settings.max_cholesky_size(max_cholesky_size):\n",
    "        fit_gpytorch_mll(mll, \n",
    "                         optimizer_kwargs={'method': 'L-BFGS-B'})\n",
    "\n",
    "    return model\n",
    "\n",
    "from botorch.models.model_list_gp_regression import ModelListGP\n",
    "from torch.quasirandom import SobolEngine\n",
    "\n",
    "class Furbo_state_restart():\n",
    "    '''Class to track optimization status with restart'''\n",
    "    # Initialization of the status\n",
    "    def __init__(self,              #\n",
    "                 obj,               # Objective function\n",
    "                 cons,              # Constraints function\n",
    "                 batch_size,        # Batch size of each iteration\n",
    "                 n_init,            # Number of initial points to evaluate\n",
    "                 n_iteration,       # Number of total iterations\n",
    "                 tr_number,         # number of Trust regions\n",
    "                 **tkwargs):\n",
    "        \n",
    "        # Objective function handle\n",
    "        self.obj = obj\n",
    "        \n",
    "        # Constraints function handle\n",
    "        self.cons = cons\n",
    "        \n",
    "        # Domain bounds\n",
    "        self.lb = obj.lb\n",
    "        self.ub = obj.ub\n",
    "        \n",
    "        # Problem dimensions\n",
    "        self.batch_size: int = batch_size      # Dimension of the batch at each iteration\n",
    "        self.n_init: int = n_init              # Number of initial samples\n",
    "        self.dim: int = obj.dim                # Dimension of the problem\n",
    "        \n",
    "        # Trust regions information\n",
    "        self.tr_number: int = tr_number                                            # Number of trust regions to use during evolution\n",
    "        self.tr_ub: float = torch.ones((self.tr_number, self.dim), **tkwargs)      # Upper bounds of trust region\n",
    "        self.tr_lb: float = torch.zeros((self.tr_number, self.dim), **tkwargs)     # Lower bounds of trust region\n",
    "        self.tr_vol: float = torch.prod(self.tr_ub - self.tr_lb, dim=1)            # Volume of trust region\n",
    "        self.radius: float = 1.0                                                   # Percentage around which the trust region is built\n",
    "        self.radius_min: float = 0.5**7                                            # Minimum percentage for trust region\n",
    "\n",
    "        # Trust region updating \n",
    "        self.failure_counter: int = 0       # Counter of consecutive failures, used to assess progress\n",
    "        self.success_counter: int = 0       # Counter of consecutive successes, used to assess progress\n",
    "        self.success_tolerance: int = 2     # Successes needed before the trust region is expanded\n",
    "        self.failure_tolerance: int = 3     # Failures needed before the trust region is shrunk\n",
    "        \n",
    "        # Tensor to save current batch information\n",
    "        self.batch_X: Tensor        # Current batch to evaluate: X values\n",
    "        self.batch_Y: Tensor        # Current batch to evaluate: Y value\n",
    "        self.batch_C: Tensor        # Current batch to evaluate: C values\n",
    "            \n",
    "        # Stopping criteria information\n",
    "        self.n_iteration: int = n_iteration     # Maximum number of iterations allowed\n",
    "        self.it_counter: int = 0  # Counter of iterations for stopping\n",
    "        self.finish_trigger: bool = False       # Trigger to stop optimization\n",
    "        self.failed_GP : bool = False           # Flag read by failed_GP() to stop the optimization\n",
    "        \n",
    "        # Restart criteria information\n",
    "        self.restart_trigger: bool = False\n",
    "        \n",
    "        # Sobol sampler engine\n",
    "        self.sobol = SobolEngine(dimension=self.dim, scramble=True)\n",
    "        \n",
    "    # Update the status\n",
    "    def update(self,\n",
    "               X_next,          # Samples X (input values) to update the status\n",
    "               Y_next,          # Samples Y (objective value) to update the status\n",
    "               C_next,          # Samples C (constraints values) to update the status\n",
    "               **tkwargs):\n",
    "        '''Function to update optimization status'''\n",
    "        \n",
    "        # Merge current batch with previously evaluated samples\n",
    "        if not hasattr(self, 'X'):\n",
    "            # If there are no previous samples, declare the Tensors\n",
    "            self.X = X_next\n",
    "            self.Y = Y_next\n",
    "            self.C = C_next\n",
    "        else:\n",
    "            # Else, concatenate the new batch to the previous samples\n",
    "            self.X = torch.cat((self.X, X_next), dim=0)\n",
    "            self.Y = torch.cat((self.Y, Y_next), dim=0)\n",
    "            self.C = torch.cat((self.C, C_next), dim=0)\n",
    "\n",
    "        # update GPR surrogates\n",
    "        try:\n",
    "            self.Y_model = get_fitted_model(self.X, self.Y, self.dim, max_cholesky_size = float(\"inf\"))\n",
    "            self.C_model = ModelListGP(*[get_fitted_model(self.X, C.reshape([C.shape[0],1]), self.dim, max_cholesky_size = float(\"inf\")) for C in self.C.t()])\n",
    "        except Exception:\n",
    "            # If the update fails, flag to stop the entire optimization\n",
    "            self.failed_GP = True\n",
    "        \n",
    "        # Update batch information \n",
    "        self.batch_X = X_next\n",
    "        self.batch_Y = Y_next\n",
    "        self.batch_C = C_next\n",
    "            \n",
    "        # Update best value\n",
    "        # Find the best value among the candidates\n",
    "        best_id = get_best_index_for_batch(n_tr=self.tr_number, Y=self.Y, C=self.C)\n",
    "            \n",
    "        # Update success and failure counters for trust region update\n",
    "        # If attribute 'best_X' does not exist, DoE was just evaluated -> no update on counters\n",
    "        if hasattr(self, 'best_X'):\n",
    "            if (self.C[best_id] <= 0).all():\n",
    "                # The best sample so far is feasible\n",
    "                if (self.Y[best_id] > self.best_Y).any() or (self.best_C > 0).any():\n",
    "                    self.success_counter += 1\n",
    "                    self.failure_counter = 0                \n",
    "                else:\n",
    "                    self.success_counter = 0\n",
    "                    self.failure_counter += 1\n",
    "            else:\n",
    "                # No feasible sample has been found yet\n",
    "                total_violation_next = self.C[best_id].clamp(min=0).sum(dim=-1)\n",
    "                total_violation_center = self.best_C.clamp(min=0).sum(dim=-1)\n",
    "                if total_violation_next < total_violation_center:\n",
    "                    self.success_counter += 1\n",
    "                    self.failure_counter = 0\n",
    "                else:\n",
    "                    self.success_counter = 0\n",
    "                    self.failure_counter += 1\n",
    "        \n",
    "        # Update best values\n",
    "        self.best_X = self.X[best_id]\n",
    "        self.best_Y = self.Y[best_id]\n",
    "        self.best_C = self.C[best_id]\n",
    "        \n",
    "        # Update iteration counter\n",
    "        self.it_counter += 1\n",
    "        \n",
    "    def reset_status(self,\n",
    "                     **tkwargs):\n",
    "        '''Function to reset the status for the restart'''\n",
    "        \n",
    "        # Reset trust regions size\n",
    "        self.tr_ub: float = torch.ones((self.tr_number, self.dim), **tkwargs)      # Upper bounds of trust region\n",
    "        self.tr_lb: float = torch.zeros((self.tr_number, self.dim), **tkwargs)     # Lower bounds of trust region\n",
    "        self.tr_vol: float = torch.prod(self.tr_ub - self.tr_lb, dim=1)            # Volume of trust region\n",
    "        self.radius: float = 1.0                                                   # Percentage around which the trust region is built\n",
    "        self.radius_min: float = 0.5**7                                            # Minimum percentage for trust region\n",
    "\n",
    "        # Reset counters to change trust region size \n",
    "        self.failure_counter: int = 0    # Counter of consecutive failures, used to assess progress\n",
    "        self.success_counter: int = 0    # Counter of consecutive successes, used to assess progress\n",
    "        \n",
    "        # Reset restart criteria trigger\n",
    "        self.restart_trigger: bool = False      # Trigger to restart optimization\n",
    "        self.failed_GP: bool = False            # Reset GPR failure trigger\n",
    "        \n",
    "        # Delete tensors with samples for training GPRs\n",
    "        if hasattr(self, 'X'):\n",
    "            del self.X\n",
    "            del self.Y\n",
    "            del self.C\n",
    "        \n",
    "        # Delete tensors with best value so far\n",
    "        if hasattr(self, 'best_X'):\n",
    "            del self.best_X\n",
    "            del self.best_Y\n",
    "            del self.best_C\n",
    "        \n",
    "        # Clear GPU memory\n",
    "        if torch.device(tkwargs.get(\"device\", \"cpu\")).type == \"cuda\":\n",
    "            torch.cuda.empty_cache()  "
   ]
  },
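  {
   "cell_type": "markdown",
   "id": "furbo-sketch-radius-md",
   "metadata": {},
   "source": [
    "The success and failure counters maintained by `update` drive the radius change applied in `multinormal_radius` (next section). The radius doubles, capped at $1$, after `success_tolerance` consecutive successes and halves after `failure_tolerance` consecutive failures. The following minimal sketch replays that bookkeeping on a hypothetical sequence of improve/no-improve outcomes with the default tolerances of 2 and 3. `record` and `update_radius` are illustrative names and are not part of the tutorial code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "furbo-sketch-radius-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def record(improved, success_counter, failure_counter):\n",
    "    '''An improvement extends the success streak and resets the failure streak, and vice versa.'''\n",
    "    if improved:\n",
    "        return success_counter + 1, 0\n",
    "    return 0, failure_counter + 1\n",
    "\n",
    "def update_radius(radius, success_counter, failure_counter, success_tolerance=2, failure_tolerance=3):\n",
    "    '''Expand or shrink the radius once a streak reaches its tolerance, then reset that counter.'''\n",
    "    if success_counter == success_tolerance:\n",
    "        return min(2.0 * radius, 1.0), 0, failure_counter\n",
    "    if failure_counter == failure_tolerance:\n",
    "        return radius / 2.0, success_counter, 0\n",
    "    return radius, success_counter, failure_counter\n",
    "\n",
    "radius, successes, failures = 1.0, 0, 0\n",
    "history = []\n",
    "for improved in [False, False, False, True, True, False, False, False]:\n",
    "    successes, failures = record(improved, successes, failures)\n",
    "    radius, successes, failures = update_radius(radius, successes, failures)\n",
    "    history.append(radius)\n",
    "print(history)  # [1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 1.0, 0.5]"
   ]
  },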
  {
   "cell_type": "markdown",
   "id": "58fee9fb-ca01-4024-8859-1fc3d9d4aea4",
   "metadata": {},
   "source": [
    "### Define trust region\n",
    "\n",
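    "Define a set of functions to update the trust region. First, draw samples inside a hypersphere of the current radius centred on the best sample (the sampling directions come from a multivariate normal distribution), and evaluate them on the GPR surrogates of both the objective and the constraints. Rank the samples by estimated constraint violation and objective value: samples predicted to be feasible come first, ordered by objective, followed by infeasible samples, ordered by violation. Take the top $10\\%$ of the ranked samples. The trust region is defined as the hyperbox enclosing the selected samples. Before sampling, the radius is doubled (up to $1$) after `success_tolerance` consecutive successes, or halved after `failure_tolerance` consecutive failures."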
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d36cf01-c5be-44ee-96bb-12564225bd7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multivariate_circular(centre,       # Centre of the multivariate distribution\n",
    "                          radius,       # Radius of the multivariate distribution\n",
    "                          n_samples,    # Number of samples to evaluate\n",
    "                          lb = None,    # Domain lower bound\n",
    "                          ub = None,    # Domain upper bound\n",
    "                          **tkwargs):\n",
    "    '''Function to generate multivariate distribution of given radius and centre within a given domain.'''\n",
    "    # Dimension of the design domain\n",
    "    dim = centre.shape[0]\n",
    "    \n",
    "    # Generate a multivariate normal distribution centered at 0\n",
    "    multivariate_normal = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(dim, **tkwargs), 0.025*torch.eye(dim, **tkwargs))\n",
    "    \n",
    "    # Draw samples\n",
    "    samples = multivariate_normal.sample(sample_shape=torch.Size([n_samples]))\n",
    "    \n",
    "    # Normalize each sample to have unit norm, then scale by the radius\n",
    "    norms = torch.norm(samples, dim=1, keepdim=True)  # Euclidean norms\n",
    "    normalized_samples = samples / norms  # Normalize to unit hypersphere\n",
    "    scaled_samples = normalized_samples * torch.rand(n_samples, 1, **tkwargs) * radius  # Scale by random factor within radius\n",
    "    \n",
    "    # Translate samples to be centered at centre\n",
    "    samples = scaled_samples + centre\n",
    "    \n",
    "    \n",
    "    # Trim samples outside domain\n",
    "    for dim in range(len(lb)):\n",
    "        samples = samples[torch.where(samples[:,dim]>=lb[dim])]\n",
    "        samples = samples[torch.where(samples[:,dim]<=ub[dim])]\n",
    "    \n",
    "    return samples\n",
    "\n",
    "def multinormal_radius(state,                # FuRBO state\n",
    "                       percentage = 0.1,     # Percentage to define trust region (default 10%)\n",
    "                       **tkwargs):\n",
    "    '''Function to sample Multinormal Distribution of GPRs and define trust region'''\n",
    "    # Update the trust regions based on the feasible region\n",
    "    n_samples = 1000 * state.dim\n",
    "    lb = torch.zeros(state.dim, **tkwargs)\n",
    "    ub = torch.ones(state.dim, **tkwargs)\n",
    "    \n",
    "    # Update radius dimension\n",
    "    if state.success_counter == state.success_tolerance:  # Expand trust region\n",
    "        state.radius = min(2.0 * state.radius, 1.0)\n",
    "        state.success_counter = 0\n",
    "    elif state.failure_counter == state.failure_tolerance:  # Shrink trust region\n",
    "        state.radius /= 2.0\n",
    "        state.failure_counter = 0\n",
    "    \n",
    "    for ind, x_candidate in enumerate(state.best_X):\n",
    "        # Generate the samples on which to evaluate the feasible area\n",
    "        radius = state.radius\n",
    "        samples = multivariate_circular(x_candidate, radius, n_samples, lb=lb, ub=ub, **tkwargs)\n",
    "    \n",
    "        # Evaluate samples on the models of the objective -> yy Tensor\n",
    "        state.Y_model.eval()\n",
    "        with torch.no_grad():\n",
    "            posterior = state.Y_model.posterior(samples)\n",
    "            samples_yy = posterior.mean.squeeze()\n",
    "        \n",
    "        # Evaluate samples on the models of the constraints -> cc Tensor\n",
    "        state.C_model.eval()\n",
    "        with torch.no_grad():\n",
    "            posterior = state.C_model.posterior(samples)\n",
    "            samples_cc = posterior.mean\n",
    "        \n",
    "        # Combine the constraint values:\n",
    "        # normalize each constraint by its largest absolute value, then keep the worst (largest) one per sample\n",
    "        samples_cc /= torch.abs(samples_cc).max(dim=0).values\n",
    "        samples_cc = torch.max(samples_cc, dim=1).values\n",
    "        \n",
    "        # Take the best X% of the drawn samples to define the trust region\n",
    "        n_samples_tr = int(n_samples * percentage)\n",
    "        \n",
    "        # Order the samples for feasibility and for best objective\n",
    "        if torch.any(samples_cc < 0):\n",
    "            \n",
    "            feasible_samples_id = torch.where(samples_cc <= 0)[0]\n",
    "            infeasible_samples_id = torch.where(samples_cc > 0)[0]\n",
    "            \n",
    "            feasible_cc = -1 * samples_yy[feasible_samples_id]\n",
    "            infeasible_cc = samples_cc[infeasible_samples_id]\n",
    "            \n",
    "            feasible_sorted, feasible_sorted_id = torch.sort(feasible_cc)\n",
    "            infeasible_sorted, infeasible_sorted_id = torch.sort(infeasible_cc)\n",
    "            \n",
    "            original_feasible_sorted_indices = feasible_samples_id[feasible_sorted_id]\n",
    "            original_infeasible_sorted_indices = infeasible_samples_id[infeasible_sorted_id]\n",
    "            \n",
    "            top_indices = torch.cat((original_feasible_sorted_indices, original_infeasible_sorted_indices))[:n_samples_tr]\n",
    "        \n",
    "        else:\n",
    "            \n",
    "            if n_samples_tr > len(samples_cc):\n",
    "                n_samples_tr = len(samples_cc)\n",
    "                \n",
    "            if n_samples_tr < 4:\n",
    "                n_samples_tr = 4\n",
    "                \n",
    "            top_values, top_indices = torch.topk(samples_cc, n_samples_tr, largest=False)\n",
    "        \n",
    "        # Set the box around the selected samples\n",
    "        state.tr_lb[ind] = torch.min(samples[top_indices], dim=0).values\n",
    "        state.tr_ub[ind] = torch.max(samples[top_indices], dim=0).values\n",
    "        \n",
    "        # Update volume of trust region\n",
    "        state.tr_vol[ind] = torch.prod(state.tr_ub[ind] - state.tr_lb[ind])\n",
    "        \n",
    "    # return updated status with new trust regions\n",
    "    return state"
   ]
  },
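  {
   "cell_type": "markdown",
   "id": "furbo-sketch-trust-region-md",
   "metadata": {},
   "source": [
    "To make the trust-region construction concrete without fitting GPs, the following sketch replaces the surrogate means with two hypothetical closed-form functions: `toy_objective`, which is maximized, and `toy_violation`, which is feasible when $\\le 0$. It mirrors `multivariate_circular` (Gaussian directions scaled by a uniform fraction of the radius and trimmed to the unit cube) and the ranking in `multinormal_radius`, in slightly simplified form. Predicted-feasible samples come first, ordered by objective, followed by infeasible samples ordered by violation. The top 10% are kept and their bounding box is returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "furbo-sketch-trust-region-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "def sample_in_ball(centre, radius, n_samples, rng):\n",
    "    '''Gaussian directions scaled by U(0, 1) * radius, keeping only points inside [0, 1]^d.'''\n",
    "    samples = []\n",
    "    for _ in range(n_samples):\n",
    "        direction = [rng.gauss(0.0, 1.0) for _ in centre]\n",
    "        norm = math.sqrt(sum(v * v for v in direction))\n",
    "        scale = rng.random() * radius\n",
    "        point = [c + scale * v / norm for c, v in zip(centre, direction)]\n",
    "        if all(0.0 <= p <= 1.0 for p in point):\n",
    "            samples.append(point)\n",
    "    return samples\n",
    "\n",
    "def trust_region_box(samples, objective, violation, percentage=0.1):\n",
    "    '''Rank samples (feasible by objective, then infeasible by violation) and box the top fraction.'''\n",
    "    feasible = sorted((s for s in samples if violation(s) <= 0), key=objective, reverse=True)\n",
    "    infeasible = sorted((s for s in samples if violation(s) > 0), key=violation)\n",
    "    top = (feasible + infeasible)[:max(1, int(len(samples) * percentage))]\n",
    "    dim = len(top[0])\n",
    "    lb = [min(s[i] for s in top) for i in range(dim)]\n",
    "    ub = [max(s[i] for s in top) for i in range(dim)]\n",
    "    return lb, ub\n",
    "\n",
    "# Hypothetical stand-ins for the GPR posterior means\n",
    "def toy_objective(x):\n",
    "    return -sum((xi - 0.3) ** 2 for xi in x)\n",
    "\n",
    "def toy_violation(x):\n",
    "    return x[0] + x[1] - 0.8\n",
    "\n",
    "rng = random.Random(0)\n",
    "samples = sample_in_ball([0.5, 0.5], radius=0.2, n_samples=2000, rng=rng)\n",
    "tr_lb, tr_ub = trust_region_box(samples, toy_objective, toy_violation)\n",
    "print(tr_lb, tr_ub)"
   ]
  },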
  {
   "cell_type": "markdown",
   "id": "f49690e5-6505-47df-89f7-1a56f9b087b2",
   "metadata": {},
   "source": [
    "### Sampling strategies\r\n",
    "\r\n",
    "Define a function to generate an initial experimental design using Sobol sampling, similar to SCBO (https://botorch.org/docs/tutorials/scalable_constrained_bo/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0116ca79-7555-4da3-bfd4-69941926eb11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_initial_points_sobol(state,\n",
    "                             **tkwargs):\n",
    "    '''Function to generate the initial experimental design'''\n",
    "    X_init = state.sobol.draw(n=state.n_init).to(**tkwargs)\n",
    "    return X_init"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55726979-92f6-488f-869c-2fa6140f6b85",
   "metadata": {},
   "source": [
    "Define a function to generate the next batch of candidate points within each trust region using constrained Thompson sampling, similar to SCBO (https://botorch.org/docs/tutorials/scalable_constrained_bo/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d969ea7-2f1e-4433-b3b4-53413546c2f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.generation.sampling import ConstrainedMaxPosteriorSampling\n",
    "\n",
    "def generate_batch_thompson_sampling(state,\n",
    "                                     n_candidates,\n",
    "                                     **tkwargs):\n",
    "    '''Function to find the next batch of candidate points'''\n",
    "    assert state.X.min() >= 0.0 and state.X.max() <= 1.0 and torch.all(torch.isfinite(state.Y))\n",
    "\n",
    "    # Initialize tensor with samples to evaluate\n",
    "    X_next = torch.ones((state.batch_size*state.tr_number, state.dim), **tkwargs)\n",
    "    \n",
    "    # Iterate over the several trust regions\n",
    "    for i in range(state.tr_number):\n",
    "        tr_lb = state.tr_lb[i]\n",
    "        tr_ub = state.tr_ub[i]\n",
    "\n",
    "        # Thompson Sampling w/ Constraints (like SCBO)\n",
    "        pert = state.sobol.draw(n_candidates).to(**tkwargs)\n",
    "        pert = tr_lb + (tr_ub - tr_lb) * pert\n",
    "\n",
    "        # Create a perturbation mask\n",
    "        prob_perturb = min(20.0 / state.dim, 1.0)\n",
    "        mask = torch.rand(n_candidates, state.dim, **tkwargs) <= prob_perturb\n",
    "        ind = torch.where(mask.sum(dim=1) == 0)[0]\n",
    "        mask[ind, torch.randint(0, state.dim, size=(len(ind),), device=tkwargs['device'])] = 1\n",
    "\n",
    "        # Create candidate points from the perturbations and the mask\n",
    "        X_cand = state.best_X[i].expand(n_candidates, state.dim).clone()\n",
    "        X_cand[mask] = pert[mask]\n",
    "        \n",
    "        # Sample on the candidate points using Constrained Max Posterior Sampling\n",
    "        constrained_thompson_sampling = ConstrainedMaxPosteriorSampling(\n",
    "            model=state.Y_model, constraint_model=state.C_model, replacement=False\n",
    "            )\n",
    "        with torch.no_grad():\n",
    "            X_next[i*state.batch_size:i*state.batch_size+state.batch_size, :] = constrained_thompson_sampling(X_cand, num_samples=state.batch_size)\n",
    "        \n",
    "    return X_next"
   ]
  },
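  {
   "cell_type": "markdown",
   "id": "furbo-sketch-mask-md",
   "metadata": {},
   "source": [
    "The candidate generation above follows the SCBO recipe. Each coordinate of the best point is replaced by a perturbation with probability $\\min(20/d, 1)$, and any row with no perturbed coordinate gets one at random. The following plain-Python sketch shows the mask and the resulting candidates. It uses hypothetical helper names, and uniform random draws stand in for the Sobol points used by `generate_batch_thompson_sampling`. For $d = 10$ every coordinate is perturbed, so the mask only matters for $d > 20$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "furbo-sketch-mask-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def perturbation_mask(n_candidates, dim, rng):\n",
    "    '''Perturb each coordinate with probability min(20 / dim, 1), at least one per row.'''\n",
    "    prob_perturb = min(20.0 / dim, 1.0)\n",
    "    mask = [[rng.random() <= prob_perturb for _ in range(dim)] for _ in range(n_candidates)]\n",
    "    for row in mask:\n",
    "        if not any(row):\n",
    "            row[rng.randrange(dim)] = True  # any coordinate, including the last, may be chosen\n",
    "    return mask\n",
    "\n",
    "def candidates_from_mask(best_x, tr_lb, tr_ub, mask, rng):\n",
    "    '''Copy best_x and overwrite the masked coordinates with uniform draws inside the trust region.'''\n",
    "    candidates = []\n",
    "    for row in mask:\n",
    "        pert = [lo + (hi - lo) * rng.random() for lo, hi in zip(tr_lb, tr_ub)]\n",
    "        candidates.append([p if m else b for p, m, b in zip(pert, row, best_x)])\n",
    "    return candidates\n",
    "\n",
    "rng = random.Random(1)\n",
    "dim = 50\n",
    "best_x = [0.5] * dim\n",
    "tr_lb, tr_ub = [0.4] * dim, [0.6] * dim\n",
    "mask = perturbation_mask(200, dim, rng)\n",
    "candidates = candidates_from_mask(best_x, tr_lb, tr_ub, mask, rng)\n",
    "print(sum(map(sum, mask)) / (200 * dim))  # roughly 20 / 50 = 0.4"
   ]
  },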
  {
   "cell_type": "markdown",
   "id": "024e10f4-4de9-433b-8827-c58a7177784f",
   "metadata": {},
   "source": [
    "### Stopping criterion\r\n",
    "\r\n",
    "Define a function to detect when the maximum number of iterations is met."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "189aae72-f033-49db-bcc8-0e791774bd76",
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_iterations(state, n_iteration):\n",
    "    '''Function to evaluate if the maximum number of allowed iterations is reached.'''\n",
    "    if state.it_counter < n_iteration:\n",
    "        return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44f15711-8503-4226-8ee0-171f26f7b8f4",
   "metadata": {},
   "source": [
    "Define a function to detect when the GPR fitting process fails, so that the optimization can be stopped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d2b342-601b-4d07-9fe3-459f28ecead2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def failed_GP(state):\n",
    "    '''Function to evaluate if a GPR failed during the optimization.'''\n",
    "    if state.failed_GP:\n",
    "        print(\"GPR failed\")\n",
    "        return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a80c7b75-a62d-46c2-aa80-28f6f29501be",
   "metadata": {},
   "source": [
    "### Main optimization loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d52417-aaa5-40b6-bf07-12c074460283",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "        \n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "dtype = torch.double\n",
    "tkwargs = {\"device\": device, \"dtype\": dtype}\n",
    "\n",
    "# Initialize FuRBO\n",
    "obj = ack(dim = 10,\n",
    "          negate=True,\n",
    "          **tkwargs)\n",
    "cons = list([sum_(threshold = 0,\n",
    "                  lb = obj.lb,\n",
    "                  ub = obj.ub), \n",
    "             norm_(threshold = 0.5, \n",
    "                   lb = obj.lb, \n",
    "                   ub = obj.ub)])\n",
    "batch_size = 1      # Samples per iteration (e.g. 3 * obj.dim for larger batches)\n",
    "n_init = 10         # Initial Sobol samples (e.g. 10 * obj.dim)\n",
    "n_iteration = 10    # Iteration budget (e.g. 10 * obj.dim)\n",
    "tr_number = 1\n",
    "N_CANDIDATES = 2000\n",
    "\n",
    "# FuRBO state initialization\n",
    "FuRBO_status = Furbo_state_restart(obj = obj,                       # Objective function\n",
    "                                  cons = cons,                      # Constraints function\n",
    "                                  batch_size = batch_size,          # Batch size of each iteration\n",
    "                                  n_init = n_init,                  # Number of initial points to evaluate\n",
    "                                  n_iteration = n_iteration,        # Number of iterations\n",
    "                                  tr_number = tr_number,            # number of Trust regions\n",
    "                                  **tkwargs)\n",
    "\n",
    "# Initiate lists to save samples over the restarts\n",
    "X_best, Y_best, C_best = [], [], []\n",
    "X_all, Y_all, C_all = [], [], []\n",
    "\n",
    "# Continue the optimization until the stopping criterion is triggered\n",
    "while not FuRBO_status.finish_trigger: \n",
    "    \n",
    "    # Reset status for restarting\n",
    "    FuRBO_status.reset_status(**tkwargs)\n",
    "    \n",
    "    # Generate initial batch of X\n",
    "    X_next = get_initial_points_sobol(FuRBO_status, **tkwargs)\n",
    "    \n",
    "    # Optimize until the stopping criterion (or a restart trigger) is met\n",
    "    while not FuRBO_status.restart_trigger and not FuRBO_status.finish_trigger:\n",
    "                \n",
    "        # Evaluate current batch (samples in X_next)\n",
    "        Y_next = []\n",
    "        C_next = []\n",
    "        for x in X_next:\n",
    "            # Evaluate batch on obj ...\n",
    "            Y_next.append(FuRBO_status.obj.eval_(x))\n",
    "            # ... and constraints\n",
    "            C_next.append([c.eval_(x) for c in FuRBO_status.cons])\n",
    "               \n",
    "        # process vector for PyTorch\n",
    "        Y_next = torch.tensor(Y_next).unsqueeze(-1).to(**tkwargs)\n",
    "        C_next = torch.tensor(C_next).to(**tkwargs)\n",
    "                \n",
    "        # Update FuRBO status with newly evaluated batch\n",
    "        FuRBO_status.update(X_next, Y_next, C_next, **tkwargs)   \n",
    "                \n",
    "        # Printing current best\n",
    "        # If a feasible sample has been evaluated -> print current optimum (feasible sample with best objective value)\n",
    "        if (FuRBO_status.best_C <= 0).all():\n",
    "            best = FuRBO_status.best_Y.amax()\n",
    "            print(f\"{FuRBO_status.it_counter-1}) Best value: {best:.2e},\"\n",
    "                  f\" MND radius: {FuRBO_status.radius}\")\n",
    "        \n",
    "        # Else, if no feasible sample has been evaluated -> print smallest violation (the sample that violates all constraints the least)\n",
    "        else:\n",
    "            violation = FuRBO_status.best_C.clamp(min=0).sum()\n",
    "            print(f\"{FuRBO_status.it_counter-1}) No feasible point yet! Smallest total violation: \"\n",
    "                  f\"{violation:.2e}, MND radius: {FuRBO_status.radius}\")\n",
    "            \n",
    "        # Update trust regions\n",
    "        FuRBO_status = update_tr(FuRBO_status,\n",
    "                                 **tkwargs)\n",
    "                \n",
    "        # Generate next batch to evaluate\n",
    "        X_next = generate_batch(FuRBO_status, N_CANDIDATES, **tkwargs)\n",
    "        \n",
    "        # Check if stopping criterion is met (budget exhausted or GP failed)\n",
    "        FuRBO_status.finish_trigger = stopping_criterion(FuRBO_status, n_iteration) \n",
    "        \n",
    "        # Check if restart criterion is met\n",
    "        FuRBO_status.restart_trigger = (restart_criterion(FuRBO_status, FuRBO_status.radius_min)\n",
    "                                        or GP_restart_criterion(FuRBO_status))\n",
    "\n",
    "    # Save samples evaluated before resetting the status\n",
    "    X_all.append(FuRBO_status.X)\n",
    "    Y_all.append(FuRBO_status.Y)\n",
    "    C_all.append(FuRBO_status.C)\n",
    "\n",
    "    # Save best sample of this run\n",
    "    X_best.append(FuRBO_status.best_X)\n",
    "    Y_best.append(FuRBO_status.best_Y)\n",
    "    C_best.append(FuRBO_status.best_C)"
   ]
  },
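  {
   "cell_type": "markdown",
   "id": "sketch-feasibility-violation",
   "metadata": {},
   "source": [
    "The progress printout in the loop above treats a sample as feasible only when every constraint value is at most zero, and otherwise reports the total violation as the sum of the positive parts of the constraint values. The standalone sketch below reproduces that bookkeeping on plain tensors; the helper names `is_feasible` and `total_violation` and the toy constraint values are hypothetical illustrations, not part of the FuRBO code.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def is_feasible(c):\n",
    "    # c: constraint values of one sample, shape (n_constraints,)\n",
    "    # Feasible only if every constraint satisfies c_i(x) <= 0\n",
    "    return bool((c <= 0).all())\n",
    "\n",
    "def total_violation(c):\n",
    "    # Satisfied constraints (c_i <= 0) are clamped to 0 and contribute nothing;\n",
    "    # violated constraints contribute their positive excess\n",
    "    return float(c.clamp(min=0).sum())\n",
    "\n",
    "c_ok = torch.tensor([-0.3, -1.2])\n",
    "c_bad = torch.tensor([0.5, -2.0])\n",
    "print(is_feasible(c_ok), total_violation(c_ok))    # → True 0.0\n",
    "print(is_feasible(c_bad), total_violation(c_bad))  # → False 0.5\n",
    "```\n",
    "\n",
    "Clamping matters: summing the raw values of `c_bad` would give 0.5 + (-2.0) = -1.5, so the slack of the satisfied constraint would hide the violated one."
   ]
  },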
  {
   "cell_type": "markdown",
   "id": "a3d28f6c-5940-4d56-a3b5-4635980db76f",
   "metadata": {},
   "source": [
    "### Printing results and plotting the convergence curve\n",
    "\n",
    "Print the best-evaluated sample (over all restarts) and the objective value (if a feasible sample was found) or the smallest violation (if no feasible sample was found)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027a9ec9-930a-48d0-b481-ffe9e9ca0d90",
   "metadata": {},
   "outputs": [],
   "source": [
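    "# Print the best result found over all restarts\n",
    "# Stack the per-run results into tensors for processing\n",
    "X_best = torch.stack(X_best).to(**tkwargs)\n",
    "Y_best = torch.stack(Y_best).to(**tkwargs)\n",
    "C_best = torch.stack(C_best).to(**tkwargs)\n",
    "\n",
    "# Flatten to one entry per stored best sample (one row of X and C, one value of Y each)\n",
    "X_flat = X_best.reshape(-1, X_best.shape[-1])\n",
    "Y_flat = Y_best.reshape(-1)\n",
    "C_flat = C_best.reshape(-1, C_best.shape[-1])\n",
    "\n",
    "# A sample is feasible only if all of its constraint values are <= 0\n",
    "feasible = (C_flat <= 0).all(dim=-1)\n",
    "\n",
    "# If a feasible sample has been evaluated -> print the optimum sample and its value\n",
    "if feasible.any():\n",
    "    # Exclude infeasible samples from the comparison of objective values\n",
    "    Y_feasible = Y_flat.masked_fill(~feasible, float('-inf'))\n",
    "    best = Y_feasible.amax()\n",
    "    best_x = X_flat[Y_feasible.argmax()]\n",
    "    print(\"Optimization finished \\n\"\n",
    "         f\"\\t Optimum: {best:.2e}, \\n\"\n",
    "         f\"\\t X: {best_x}\")\n",
    "    \n",
    "# Else, if no feasible sample has been evaluated -> print the sample with the smallest total violation\n",
    "else:\n",
    "    # Only positive constraint values count as violation (same measure as in the loop)\n",
    "    total_violation = C_flat.clamp(min=0).sum(dim=-1)\n",
    "    violation = total_violation.amin()\n",
    "    violation_x = X_flat[total_violation.argmin()]\n",
    "\n",
    "    print(\"Optimization failed \\n\"\n",
    "         f\"\\t Smallest violation: {violation:.2e}, \\n\"\n",
    "         f\"\\t X: {violation_x}\")"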
   ]
  },
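  {
   "cell_type": "markdown",
   "id": "sketch-best-run-selection",
   "metadata": {},
   "source": [
    "When several runs (restarts) have been performed, only runs whose best sample is feasible should compete on objective value; if none is feasible, the run with the smallest total violation is reported instead. A minimal sketch of that selection, assuming the per-run results have been flattened to one objective value and one row of constraint values per run (`select_best_run` and the toy tensors are hypothetical):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def select_best_run(Y_best, C_best):\n",
    "    # Y_best: best objective value per run, shape (n_runs,)\n",
    "    # C_best: constraint values of those samples, shape (n_runs, n_constraints)\n",
    "    feasible = (C_best <= 0).all(dim=-1)\n",
    "    if feasible.any():\n",
    "        # Mask infeasible runs with -inf so they cannot win on objective value\n",
    "        masked = Y_best.masked_fill(~feasible, float('-inf'))\n",
    "        idx = int(masked.argmax())\n",
    "        return idx, True, float(Y_best[idx])\n",
    "    # No feasible run: fall back to the smallest total (positive-part) violation\n",
    "    violation = C_best.clamp(min=0).sum(dim=-1)\n",
    "    idx = int(violation.argmin())\n",
    "    return idx, False, float(violation[idx])\n",
    "\n",
    "Y = torch.tensor([-1.0, -0.2, -3.0])\n",
    "C = torch.tensor([[-0.1, -0.5], [0.4, -1.0], [-2.0, -0.3]])\n",
    "print(select_best_run(Y, C))  # → (0, True, -1.0)\n",
    "```\n",
    "\n",
    "Run 1 has the highest objective value but violates its first constraint, so the feasible run 0 is selected."
   ]
  },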
  {
   "cell_type": "markdown",
   "id": "db5ae248-3f04-4173-a98a-6c0738e38510",
   "metadata": {},
   "source": [
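    "Plot the monotonic convergence curve: the best feasible objective value found versus the number of evaluations (infeasible samples are assigned the worst evaluated value)."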
   ]
  },
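  {
   "cell_type": "markdown",
   "id": "sketch-running-max-curve",
   "metadata": {},
   "source": [
    "The loop that builds `Y_f_monotonic` in the next cell computes a running maximum. Assuming NumPy, the same curve can be sketched with `np.maximum.accumulate`; the helper `best_feasible_curve` and the toy data are hypothetical, and the helper copies its input so the caller's array is not overwritten.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def best_feasible_curve(Y, C):\n",
    "    # Y: objective values in evaluation order, shape (n,)\n",
    "    # C: constraint values, shape (n, n_constraints)\n",
    "    Y = np.array(Y, dtype=float)  # copy so the caller's array is untouched\n",
    "    infeasible = np.amax(C, axis=1) > 0\n",
    "    # Replace infeasible values with the worst value evaluated\n",
    "    Y[infeasible] = np.amin(Y)\n",
    "    # Running maximum: best feasible value seen after each evaluation\n",
    "    return np.maximum.accumulate(Y)\n",
    "\n",
    "Y = np.array([-5.0, -2.0, -1.0, -3.0])\n",
    "C = np.array([[-1.0, -1.0], [-0.5, -0.2], [0.3, -1.0], [-0.1, -0.4]])\n",
    "print(best_feasible_curve(Y, C))  # → [-5. -2. -2. -2.]\n",
    "```\n",
    "\n",
    "The third sample has the best raw value (-1.0) but violates a constraint, so the curve stays at -2.0 after the second evaluation."
   ]
  },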
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97bf4485-cc1b-4422-a2da-268499904647",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Concatenate the samples from all runs for plotting\n",
    "X_all = torch.concatenate(X_all).to(**tkwargs)\n",
    "Y_all = torch.concatenate(Y_all).to(**tkwargs)\n",
    "C_all = torch.concatenate(C_all).to(**tkwargs)\n",
    "\n",
    "# Transform values and constraints to numpy\n",
    "Y_f = Y_all.cpu().numpy().copy()  # copy so that Y_all is not modified below\n",
    "C_f = np.amax(C_all.cpu().numpy(), axis=1)\n",
    "\n",
    "# Set the objective values of infeasible samples to the worst value evaluated\n",
    "Y_f[np.where(C_f > 0)[0]] = np.amin(Y_f)\n",
    "\n",
    "# Extract a monotonic curve\n",
    "Y_f_monotonic = []\n",
    "for yy in Y_f:\n",
    "    if len(Y_f_monotonic) == 0:\n",
    "        Y_f_monotonic.append(yy)\n",
    "    else:\n",
    "        if yy > Y_f_monotonic[-1]:\n",
    "            Y_f_monotonic.append(yy)\n",
    "        else:\n",
    "            Y_f_monotonic.append(Y_f_monotonic[-1])\n",
    "\n",
    "# Generate x-y values for the plot (the initial DoE is included)\n",
    "y = np.array(Y_f_monotonic).reshape(-1)\n",
    "x = np.linspace(1, len(y), len(y))\n",
    "\n",
    "# Plotting convergence\n",
    "plt.plot(x, y, lw=3)\n",
    "\n",
    "# Plot a reference line at 0, the global maximum of the negated Ackley function\n",
    "plt.plot([0, np.amax(x)], [0, 0], '--k', lw=3)\n",
    "plt.ylabel(\"Function value\", fontsize=18)\n",
    "plt.xlabel(\"Number of evaluations\", fontsize=18)\n",
    "plt.title(\"10D Ackley with 2 outcome constraints\", fontsize=20)\n",
    "plt.xlim([0, len(y)])\n",
    "plt.grid(True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
