{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The max value entropy search acquisition function\n",
    "\n",
    "Max-value entropy search (MES) acquisition function quantifies the information gain about the maximum of a black-box function by observing this black-box function $f$ at the candidate set $\\{\\textbf{x}\\}$ (see [1, 2]). BoTorch provides implementations of the MES acquisition function and its multi-fidelity (MF) version with support for trace observations. In this tutorial, we explain at a high level how the MES acquisition function works, its implementation in BoTorch and how to use the MES acquisition function to query the next point in the optimization process. \n",
    "\n",
    "In general, we recommend using [Ax](https://ax.dev) for a simple BO setup like this one, since this will simplify your setup (including the amount of code you need to write) considerably. You can use a custom BoTorch model and acquisition function in Ax, following the [Using BoTorch with Ax](./custom_botorch_model_in_ax) tutorial. To use the MES acquisition function, it is sufficient to add `\"botorch_acqf_class\": qMaxValueEntropy,` to `model_kwargs`. The linked tutorial shows how to use a custom BoTorch model. If you'd like to let Ax choose which model to use based on the properties of the search space, you can skip the `surrogate` argument in `model_kwargs`.\n",
    "\n",
    "### 1. MES acquisition function for $q=1$ with noisy observation\n",
    "For illustrative purposes, we focus in this section on the non-q-batch-mode case ($q=1$). We also assume that the evaluation of the black-box function is noisy. Let us first introduce some notation: \n",
    "+ $f^* = \\max_\\mathcal{X} (f(\\textbf{x}))$, the maximum of the black-box function $f(\\textbf{x})$ in the design space $\\mathcal{X}$\n",
    "+ $y = f(\\textbf{x}) + \\epsilon, \\epsilon \\sim N(0, \\sigma^2_\\epsilon)$, the noisy observation at the design point $\\textbf{x}$\n",
    "+ $h(Y) = \\mathbb{E}_Y[-\\log(p(y))] = -\\int_\\mathcal{Y} p(y)\\log p(y) dy$, the differential entropy of random variable $Y$ with support $\\mathcal{Y}$: the larger is $h(Y)$, the larger is the uncertainty of $Y$.\n",
    "+ $v(\\mathcal{D}) = -\\mathbb{E}_D[h(F^*\\mid\\mathcal{D})]$, the value of data set $\\mathcal{D}$, where $F^*$ denotes the function maximum (a random variable in our context of our model).\n",
    "\n",
    "\n",
    "The Max-value Entropy Search (MES) acquisition function at $\\textbf{x}$ after observing $\\mathcal{D}_t$ can be written as\n",
    "\\begin{align}\n",
    "    \\alpha_{\\text{MES}}(\\textbf{x}) \n",
    "    &= v(\\mathcal{D}_t\\cup \\{(\\textbf{x}, y)\\}) - v(\\mathcal{D}_t) \\\\\n",
    "    &= - \\mathbb{E}_Y[h(F^* \\mid \\mathcal{D}_t\\cup \\{(\\textbf{x}, Y)\\})] + h(F^*\\mid\\mathcal{D}_t) \\\\\n",
    "    &= - \\mathbb{E}_Y[h(F^* \\mid Y)] + h(F^*) \\\\\n",
    "    &= I(F^*; Y) \\\\\n",
    "    &= I(Y; F^*) \\quad \\text{(symmetry)} \\\\\n",
    "    &= - \\mathbb{E}_{F^*}[h(Y \\mid F^*)] + h(Y) \\\\    \n",
    "\\end{align}\n",
    ", which is the mutual information of random variables \n",
    "$F^*\\mid \\mathcal{D}_t$ and $Y \\mid \\textbf{x}, \\mathcal{D}_t$. \n",
    "Here $F^*$ follows the max value distribution conditioned on $\\mathcal{D}_t$, and $Y$ follows the GP posterior distribution with noise at $\\textbf{x}$ after observing $\\mathcal{D}_t$.\n",
    "\n",
    "Rewrite the above formula as\n",
    "\\begin{align}\n",
    "    \\alpha_{\\text{MES}}(\\textbf{x}) &= - H_1 + H_0, \\\\\n",
    "    H_0 &= h(Y) = \\log \\left(\\sqrt{2\\pi e (\\sigma_f^2 + \\sigma_\\epsilon^2)}\\right) \\\\\n",
    "    H_1 &= \\mathbb{E}_{F^*}[h(Y \\mid F^*)] \\\\\n",
    "        &\\simeq \\frac{1}{\\left|\\mathcal{F}_*\\right|} \\Sigma_{\\mathcal{F}_*} h(Y\\mid f^*))\n",
    "\\end{align}\n",
    ", where $\\mathcal{F}_*$ are the max value samples drawn from the posterior after observing $\\mathcal{D}_t$. Without noise, $p(y \\mid f^*) = p(f \\mid f \\leq f^*)$ is a truncated normal distribution with an analytic expression for its entropy. With noise, $Y\\mid F\\leq f^*$ is not a truncated normal distribution anymore. The question is then how to compute $h(Y\\mid f^*)$ or equivalently $p(y\\mid f \\leq f^*)$?\n",
    "\n",
    "\n",
    "Using Bayes' theorem, \n",
    "\\begin{align}\n",
    "    p(y\\mid f \\leq f^*) = \\frac{P(f \\leq f^* \\mid y) p(y)}{P(f \\leq f^* )}\n",
    "\\end{align}\n",
    ", where \n",
    "+ $p(y)$ is the posterior probability density function (PDF) with observation noise.\n",
    "+ $P(f \\leq f^*)$ is the posterior cummulative distribution function (CDF) without observation noise, given any $f^*$.\n",
    "\n",
    "We also know from the GP predictive distribution\n",
    "\\begin{align}\n",
    "    \\begin{bmatrix}\n",
    "        y \\\\ f\n",
    "    \\end{bmatrix}\n",
    "    \\sim \\mathcal{N} \\left(\n",
    "    \\begin{bmatrix}\n",
    "        \\mu \\\\ \\mu\n",
    "    \\end{bmatrix} , \n",
    "    \\begin{bmatrix}\n",
    "        \\sigma_f^2 + \\sigma_\\epsilon^2 & \\sigma_f^2 \\\\ \n",
    "        \\sigma_f^2 & \\sigma_f^2\n",
    "    \\end{bmatrix}\n",
    "    \\right).\n",
    "\\end{align}\n",
    "So\n",
    "\\begin{align}\n",
    "    f \\mid y \\sim \\mathcal{N} (u, s^2)\n",
    "\\end{align}\n",
    ", where\n",
    "\\begin{align}\n",
    "    u   &= \\frac{\\sigma_f^2(y-\\mu)}{\\sigma_f^2 + \\sigma_\\epsilon^2} + \\mu \\\\\n",
    "    s^2 &= \\sigma_f^2 - \\frac{(\\sigma_f^2)^2}{\\sigma_f^2 + \\sigma_\\epsilon^2}\n",
    "        = \\frac{\\sigma_f^2\\sigma_\\epsilon^2}{\\sigma_f^2 + \\sigma_\\epsilon^2}\n",
    "\\end{align}\n",
    "Thus, $P(f \\leq f^* \\mid y)$ is the CDF of above Gaussian. \n",
    "\n",
    "Finally, given $f^*$, we have  \n",
    "\\begin{align}\n",
    "    h(Y \\mid f^*) \n",
    "    &= -\\int_\\mathcal{Y} p(y \\mid f^*)\\log(p(y \\mid f^*)) dy\\\\\n",
    "    &= -\\int_\\mathcal{Y} Zp(y)\\log(Zp(y)) dy \\\\\n",
    "    &\\simeq -\\frac{1}{\\left|\\mathcal{Y}\\right|} \\Sigma_{\\mathcal{Y}} Z\\log(Zp(y)), \\\\\n",
    "    Z &= \\frac{P(f \\leq f^* \\mid y)}{P(f \\leq f^* )}\n",
    "\\end{align}\n",
    ", where $Z$ is the ratio of two CDFs and $\\mathcal{Y}$ is the samples drawn from the posterior distribution with noisy observation. The above formulation for noisy MES is inspired from the MF-MES formulation proposed by Takeno _et. al_ [1], which is essentially the same as what is outlined above. \n",
    "\n",
    "Putting all together, \n",
    "\\begin{align}\n",
    "    \\alpha_{\\text{MES}}(\\textbf{x}) \n",
    "    &= H_0 - H_1 \\\\\n",
    "    &\\simeq H_0 - H_1^{MC}\\\\\n",
    "    &= \\log \\left(\\sqrt{2\\pi e (\\sigma_f^2 + \\sigma_\\epsilon^2)}\\right) + \\frac{1}{\\left|\\mathcal{F}^*\\right|} \\Sigma_{\\mathcal{F}^*} \\frac{1}{\\left|\\mathcal{Y}\\right|} \\Sigma_{\\mathcal{Y}} (Z\\log Z + Z\\log p(y))\n",
    "\\end{align}\n",
    "\n",
    "The next design point to query is chosen as the point that maximizes this aquisition function, _i. e._, \n",
    "\\begin{align}\n",
    "    \\textbf{x}_{\\text{next}} = \\max_{\\textbf{x} \\in \\mathcal{X}} \\alpha_{\\text{MES}}(\\textbf{x})\n",
    "\\end{align}\n",
    "\n",
    "The implementation in Botorch basically follows the above formulation for both non-MF and MF cases. One difference is that, in order to reduce the variance of the MC estimator for $H_1$, we apply also regression adjustment to get an estimation of $H_1$, \n",
    "\\begin{align}\n",
    "    \\widehat{H}_1 &= H_1^{MC} - \\beta (H_0^{MC} - H_0) \n",
    "\\end{align}\n",
    ", where\n",
    "\\begin{align}\n",
    "    H_0^{MC} &= - \\frac{1}{\\left|\\mathcal{Y}\\right|} \\Sigma_{\\mathcal{Y}} \\log p(y) \\\\\n",
    "    \\beta &= \\frac{Cov(h_1, h_0)}{\\sqrt{Var(h_1)Var(h_0)}} \\\\\n",
    "    h_0 &= -\\log p(y) \\\\\n",
    "    h_1 &= -Z\\log(Zp(y)) \\\\\n",
    "\\end{align}\n",
    "This turns out to reduce the variance of the acquisition value by a significant factor, especially when the acquisition value is small, hence making the algorithm numerically more stable. \n",
    "\n",
    "For the case of $q > 1$, joint optimization becomes difficult, since the q-batch-mode MES acquisiton function becomes not tractable due to the multivariate normal CDF functions in $Z$. Instead, the MES acquisition optimization is solved sequentially and using fantasies, _i. e._, we generate one point each time and when we try to generate the $i$-th point, we condition the models on the $i-1$ points generated prior to this (using the $i-1$ points as fantasies).  \n",
    "\n",
    "<br>\n",
    "__References__\n",
    "\n",
    "[1] [Takeno, S., et al., _Multi-fidelity Bayesian Optimization with Max-value Entropy Search._  arXiv:1901.08275v1, 2019](https://arxiv.org/abs/1901.08275)\n",
    "\n",
    "[2] [Wang, Z., Jegelka, S., _Max-value Entropy Search for Efficient Bayesian Optimization._ arXiv:1703.01968v3, 2018](https://arxiv.org/abs/1703.01968)\n"
   ]
  },
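  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough numerical illustration of the formulas above, the following sketch evaluates $H_0$, $Z$, the MC estimate $H_1^{MC}$ and the regression-adjusted $\\widehat{H}_1$ at a single candidate point in plain PyTorch. The posterior mean and standard deviations (`mu`, `sigma_f`, `sigma_eps`) and the max-value samples `f_star` are made-up values chosen for illustration, and the code is a simplified sketch rather than the BoTorch implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch.distributions import Normal\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical GP posterior at one candidate x (illustrative values only).\n",
    "mu, sigma_f, sigma_eps = 0.3, 0.8, 0.2\n",
    "sigma_y = math.sqrt(sigma_f**2 + sigma_eps**2)\n",
    "# Hypothetical max-value samples, shape |F_*| x 1.\n",
    "f_star = torch.tensor([1.5, 1.8, 2.2]).unsqueeze(-1)\n",
    "\n",
    "# Samples of the noisy observation Y ~ N(mu, sigma_f^2 + sigma_eps^2).\n",
    "y = mu + sigma_y * torch.randn(4096)\n",
    "log_p = Normal(mu, sigma_y).log_prob(y)  # log p(y)\n",
    "\n",
    "# Mean u and standard deviation s of f | y.\n",
    "u = mu + sigma_f**2 * (y - mu) / sigma_y**2\n",
    "s = sigma_f * sigma_eps / sigma_y\n",
    "\n",
    "# Z = P(f <= f* | y) / P(f <= f*), shape |F_*| x |Y|.\n",
    "std_normal = Normal(0.0, 1.0)\n",
    "Z = std_normal.cdf((f_star - u) / s) / std_normal.cdf((f_star - mu) / sigma_f)\n",
    "\n",
    "# Analytic entropy H_0 of the noisy predictive distribution.\n",
    "H0 = 0.5 * math.log(2 * math.pi * math.e * sigma_y**2)\n",
    "\n",
    "# MC estimate of H_1; Z is clamped inside the log to avoid log(0).\n",
    "h1 = -Z * (Z.clamp_min(1e-10).log() + log_p)\n",
    "H1_mc = h1.mean()\n",
    "\n",
    "# Regression adjustment using h_0 = -log p(y) as the control variable.\n",
    "h0 = (-log_p).expand_as(h1)\n",
    "cov = ((h1 - h1.mean()) * (h0 - h0.mean())).mean()\n",
    "beta = cov / (h1.var() * h0.var()).sqrt()\n",
    "H1_hat = H1_mc - beta * (h0.mean() - H0)\n",
    "\n",
    "print(f'plain MC estimate:   {(H0 - H1_mc).item():.4f}')\n",
    "print(f'regression adjusted: {(H0 - H1_hat).item():.4f}')"
   ]
  },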
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### 2. Setting up a toy model\n",
    "We will fit a standard SingleTaskGP model on noisy observations of the synthetic 2D Branin function on the hypercube $[-5,10]\\times [0, 15]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "from botorch.test_functions import Branin\n",
    "from botorch.fit import fit_gpytorch_mll\n",
    "from botorch.models import SingleTaskGP\n",
    "from botorch.utils.transforms import standardize, normalize\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "\n",
    "torch.manual_seed(7)\n",
    "\n",
    "bounds = torch.tensor(Branin._bounds).T\n",
    "train_X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(10, 2)\n",
    "train_Y = Branin(negate=True)(train_X).unsqueeze(-1)\n",
    "\n",
    "train_X = normalize(train_X, bounds=bounds)\n",
    "train_Y = standardize(train_Y + 0.05 * torch.randn_like(train_Y))\n",
    "\n",
    "model = SingleTaskGP(train_X, train_Y)\n",
    "mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "fit_gpytorch_mll(mll);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Defining the MES acquisition function\n",
    "\n",
    "The `qMaxValueEntropy` acquisition function is a subclass of `MCAcquisitionFunction` and supports pending points `X_pending`. Required arguments for the constructor are `model` and `candidate_set` (the discretized candidate points in the design space that will be used to draw max value samples). There are also other optional parameters, such as number of max value samples $\\mathcal{F^*}$, number of $\\mathcal{Y}$ samples and number of fantasies (in case of $q>1$). Two different sampling algorithms are supported for the max value samples: the discretized Thompson sampling and the Gumbel sampling introduced in [2]. Gumbel sampling is the default choice in the acquisition function. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy\n",
    "\n",
    "candidate_set = torch.rand(\n",
    "    1000, bounds.size(1), device=bounds.device, dtype=bounds.dtype\n",
    ")\n",
    "candidate_set = bounds[0] + (bounds[1] - bounds[0]) * candidate_set\n",
    "qMES = qMaxValueEntropy(model, candidate_set)"
   ]
  },
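  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The optional constructor arguments mentioned above can also be set explicitly. The sketch below assumes the `model` and `candidate_set` defined in the previous cells and uses the keyword names `num_mv_samples`, `num_y_samples`, `num_fantasies` and `use_gumbel`. The values are illustrative, and the keyword names should be checked against the installed BoTorch version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy\n",
    "\n",
    "# Same acquisition function, but with discretized Thompson sampling\n",
    "# for the max-value samples instead of the default Gumbel sampling.\n",
    "qMES_thompson = qMaxValueEntropy(\n",
    "    model,\n",
    "    candidate_set,\n",
    "    num_mv_samples=10,  # number of max-value samples\n",
    "    num_y_samples=128,  # number of Y samples\n",
    "    num_fantasies=16,  # fantasies used for pending points (q > 1)\n",
    "    use_gumbel=False,  # False selects discretized Thompson sampling\n",
    ")"
   ]
  },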
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Optimizing the MES acquisition function to get the next candidate points\n",
    "In order to obtain the next candidate point(s) to query, we need to optimize the acquisition function over the design space. For $q=1$ case, we can simply call the `optimize_acqf` function in the library. At $q>1$, due to the intractability of the aquisition function in this case, we need to use either sequential or cyclic optimization (multiple cycles of sequential optimization). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1.5350, 0.0758]]), tensor(0.0121))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from botorch.optim import optimize_acqf\n",
    "\n",
    "# for q = 1\n",
    "candidates, acq_value = optimize_acqf(\n",
    "    acq_function=qMES,\n",
    "    bounds=bounds,\n",
    "    q=1,\n",
    "    num_restarts=10,\n",
    "    raw_samples=512,\n",
    ")\n",
    "candidates, acq_value"
   ]
  },
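  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the acquisition function can also be evaluated directly on a set of points and compared with the optimizer's result. This is a minimal sketch, assuming `qMES`, `candidate_set` and `acq_value` from the cells above. The input needs shape `batch x 1 x d`, since each point is evaluated as its own $q=1$ batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Evaluate the acquisition function at every point of the candidate set.\n",
    "# unsqueeze(1) turns the `1000 x 2` tensor into 1000 batches with q=1.\n",
    "with torch.no_grad():\n",
    "    grid_values = qMES(candidate_set.unsqueeze(1))\n",
    "\n",
    "# Best point on the discrete grid, to compare with the optimized candidate.\n",
    "best_idx = grid_values.argmax()\n",
    "print(candidate_set[best_idx], grid_values[best_idx], acq_value)"
   ]
  },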
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.3238,  0.6565],\n",
       "         [ 1.5349,  0.0748]]), tensor([0.0135, 0.0065]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# for q = 2, sequential optimization\n",
    "candidates_q2, acq_value_q2 = optimize_acqf(\n",
    "    acq_function=qMES,\n",
    "    bounds=bounds,\n",
    "    q=2,\n",
    "    num_restarts=10,\n",
    "    raw_samples=512,\n",
    "    sequential=True,\n",
    ")\n",
    "candidates_q2, acq_value_q2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.3236,  0.6563],\n",
       "         [ 1.5326,  0.0732]]), tensor([0.0101, 0.0064]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from botorch.optim import optimize_acqf_cyclic\n",
    "\n",
    "# for q = 2, cyclic optimization\n",
    "candidates_q2_cyclic, acq_value_q2_cyclic = optimize_acqf_cyclic(\n",
    "    acq_function=qMES,\n",
    "    bounds=bounds,\n",
    "    q=2,\n",
    "    num_restarts=10,\n",
    "    raw_samples=512,\n",
    "    cyclic_options={\"maxiter\": 2},\n",
    ")\n",
    "candidates_q2_cyclic, acq_value_q2_cyclic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The use of the `qMultiFidelityMaxValueEntropy` acquisition function is very similar to `qMaxValueEntropy`, but requires additional optional arguments related to the fidelity and cost models. We will provide more details on the MF-MES acquisition function in a separate tutorial.  "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
