{
 "metadata": {
  "dataExplorerConfig": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "captumWidgetMessage": {},
  "outputWidgetContext": {}
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "7b4aef48-3c3f-4bec-8906-2844d31556ba",
    "showInput": false
   },
   "source": [
    "## Analytic and MC-based Expected Improvement (EI) acquisition\n",
    "\n",
    "In this tutorial, we compare the analytic and MC-based EI acquisition functions and show how to optimize the acquisition function with both `scipy`- and `torch`-based optimizers. The tutorial highlights the modularity of BoTorch: different acquisition functions and accompanying optimization algorithms can easily be tried on the same fitted model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "6d70e250-da2b-460a-8989-5c048f7a6ce8",
    "showInput": false
   },
   "source": [
    "### Comparison of analytic and MC-based EI"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "f678d607-be4c-4f37-aed5-3597158432ce",
    "collapsed": false,
    "requestMsgId": "0aae9d3f-d796-4a18-a4aa-b015b5b582ac",
    "customOutput": null,
    "executionStartTime": 1668649205799,
    "executionStopTime": 1668649205822
   },
   "source": [
    "import torch\n",
    "\n",
    "from botorch.fit import fit_gpytorch_mll\n",
    "from botorch.models import SingleTaskGP\n",
    "from botorch.test_functions import Hartmann\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "\n",
    "neg_hartmann6 = Hartmann(dim=6, negate=True)"
   ],
   "execution_count": 16,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "332a3aeb-187f-4265-b52f-4ca7bdcf5e4a",
    "showInput": false
   },
   "source": [
    "First, we generate some random training data and fit a `SingleTaskGP` to the 6-dimensional synthetic test function Hartmann6. The function is negated (`negate=True`) because BoTorch assumes maximization."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "a7724f86-8b67-4f70-bf57-f0da79b88f52",
    "collapsed": false,
    "requestMsgId": "25794582-0506-4e89-a112-ba362b7c7e59",
    "customOutput": null,
    "executionStartTime": 1668649205895,
    "executionStopTime": 1668649206067
   },
   "source": [
    "train_x = torch.rand(10, 6)\n",
    "train_obj = neg_hartmann6(train_x).unsqueeze(-1)\n",
    "model = SingleTaskGP(train_X=train_x, train_Y=train_obj)\n",
    "mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "fit_gpytorch_mll(mll);"
   ],
   "execution_count": 17,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "576d19e8-2164-4ae9-8416-4ca205d35a77",
    "showInput": false
   },
   "source": [
    "Initialize an analytic EI acquisition function on the fitted model.\n",
    ""
   ]
  },
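  {
   "cell_type": "markdown",
   "metadata": {
    "showInput": false
   },
   "source": [
    "For reference, here is a sketch of the standard closed form that analytic EI evaluates when the posterior at $x$ is Gaussian. The symbols are notation introduced for this note and are not BoTorch identifiers: $\\mu(x)$ and $\\sigma(x)$ denote the posterior mean and standard deviation, $f^*$ is the incumbent value (passed as `best_f`), and $\\Phi$ and $\\varphi$ are the standard normal CDF and PDF.\n",
    "\n",
    "$$\\mathrm{EI}(x) = \\mathbb{E}\\left[\\max\\bigl(f(x) - f^*, 0\\bigr)\\right] = \\sigma(x)\\,\\bigl[z\\,\\Phi(z) + \\varphi(z)\\bigr], \\qquad z = \\frac{\\mu(x) - f^*}{\\sigma(x)}.$$\n",
    "\n",
    "Because this expression is deterministic and differentiable in $x$, it is a natural fit for a gradient-based quasi-Newton optimizer."
   ]
  },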
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "1d611db5-29ee-4e9b-8fea-fe9d274b5222",
    "collapsed": false,
    "requestMsgId": "ba6f0917-f622-486a-818c-bb9ba4580ea5",
    "customOutput": null,
    "executionStartTime": 1668649206124,
    "executionStopTime": 1668649206138
   },
   "source": [
    "from botorch.acquisition import ExpectedImprovement\n",
    "\n",
    "best_value = train_obj.max()\n",
    "EI = ExpectedImprovement(model=model, best_f=best_value)"
   ],
   "execution_count": 18,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "e86567ac-e3f1-46ec-889a-6cc21f2fc918",
    "showInput": false
   },
   "source": [
    "Next, we optimize the analytic EI acquisition function using 20 random restarts chosen from 100 initial raw samples."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "dc5613c6-2f99-4193-8956-6e710fee5fa2",
    "collapsed": false,
    "requestMsgId": "3df2fc12-7f4c-4abb-b1d2-90bb3b8bf05c",
    "customOutput": null,
    "executionStartTime": 1668649206218,
    "executionStopTime": 1668649206938
   },
   "source": [
    "from botorch.optim import optimize_acqf\n",
    "\n",
    "new_point_analytic, _ = optimize_acqf(\n",
    "    acq_function=EI,\n",
    "    bounds=torch.tensor([[0.0] * 6, [1.0] * 6]),\n",
    "    q=1,\n",
    "    num_restarts=20,\n",
    "    raw_samples=100,\n",
    "    options={},\n",
    ")"
   ],
   "execution_count": 19,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "76fb19a3-c2c2-451a-8c0b-50cb14c55460",
    "collapsed": false,
    "requestMsgId": "a5cbada9-0b7c-41a2-934f-10d9bbe2e316",
    "customOutput": null,
    "executionStartTime": 1668649206992,
    "executionStopTime": 1668649207011
   },
   "source": [
    "new_point_analytic"
   ],
   "execution_count": 20,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([[0.4730, 0.0836, 0.8247, 0.5628, 0.2964, 0.6131]])"
     },
     "metadata": {
      "bento_obj_id": "140510701845616"
     },
     "execution_count": 20
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "99862767-916b-4de0-881a-2eafc819f356",
    "showInput": false
   },
   "source": [
    "Now, let's swap out the analytic acquisition function and replace it with an MC version. Note that we are in the `q = 1` case; for `q > 1`, an analytic version does not exist."
   ]
  },
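  {
   "cell_type": "markdown",
   "metadata": {
    "showInput": false
   },
   "source": [
    "As a rough sketch of what the MC version computes: it draws $N$ samples from the joint posterior at the $q$ candidate points and averages the best improvement per sample. The notation is introduced here for illustration only: $\\xi_{ij}$ is the $i$-th posterior sample at the $j$-th candidate, $f^*$ is the incumbent value, and $\\mathcal{D}$ is the training data.\n",
    "\n",
    "$$\\mathrm{qEI}(X) \\approx \\frac{1}{N} \\sum_{i=1}^{N} \\max_{j=1,\\dots,q} \\max\\bigl(\\xi_{ij} - f^*, 0\\bigr), \\qquad \\xi_i \\sim \\mathbb{P}\\bigl(f(X) \\mid \\mathcal{D}\\bigr).$$\n",
    "\n",
    "With $q = 1$ the maximum over $j$ disappears, and the estimate approaches the analytic EI value as $N$ grows, which is why the maximizers of the two acquisition functions are expected to be close."
   ]
  },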
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "aaf04cba-3716-4fbd-8baa-2c75dd068860",
    "collapsed": false,
    "requestMsgId": "0e7691f2-34c7-43df-a247-7f7ba95220f1",
    "customOutput": null,
    "executionStartTime": 1668649207083,
    "executionStopTime": 1668649207929
   },
   "source": [
    "from botorch.acquisition import qExpectedImprovement\n",
    "from botorch.sampling import SobolQMCNormalSampler\n",
    "\n",
    "\n",
    "sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]), seed=0)\n",
    "MC_EI = qExpectedImprovement(model, best_f=best_value, sampler=sampler)\n",
    "torch.manual_seed(seed=0)  # to keep the restart conditions the same\n",
    "new_point_mc, _ = optimize_acqf(\n",
    "    acq_function=MC_EI,\n",
    "    bounds=torch.tensor([[0.0] * 6, [1.0] * 6]),\n",
    "    q=1,\n",
    "    num_restarts=20,\n",
    "    raw_samples=100,\n",
    "    options={},\n",
    ")"
   ],
   "execution_count": 21,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "73ffa9ea-3cff-46eb-91ea-b2f75fdb07f2",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "b780cff4-6e90-4e39-8558-b04136e71e94",
    "executionStartTime": 1668649207976,
    "executionStopTime": 1668649207989
   },
   "source": [
    "new_point_mc"
   ],
   "execution_count": 22,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([[0.4730, 0.0835, 0.8248, 0.5627, 0.2963, 0.6130]])"
     },
     "metadata": {
      "bento_obj_id": "140510701845696"
     },
     "execution_count": 22
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "41f37b1a-df0a-4bb9-8996-fdddd3fe298c",
    "showInput": false
   },
   "source": [
    "Check that the two generated points are close."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "c5c20ba9-82af-4d07-832f-86ede74f8959",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "0b3db1ad-6ddb-4f86-9767-e0a486914b33",
    "executionStartTime": 1668649208035,
    "executionStopTime": 1668649208043
   },
   "source": [
    "torch.norm(new_point_mc - new_point_analytic)"
   ],
   "execution_count": 23,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor(0.0002)"
     },
     "metadata": {
      "bento_obj_id": "140510702063760"
     },
     "execution_count": 23
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "f45c60db-1318-405d-b95a-6b2723bc5f46",
    "showInput": false
   },
   "source": [
    "### Using a torch optimizer on a stochastic acquisition function\n",
    "\n",
    "We can also optimize using a `torch` optimizer. This is particularly useful for stochastic acquisition functions, which we can obtain by using a `StochasticSampler`. First, we illustrate the usage of `torch.optim.Adam`. In the code snippet below, `gen_batch_initial_conditions` uses a heuristic to select a set of restart locations, `gen_candidates_torch` is a wrapper around the `torch` optimizer that maximizes the acquisition value, and `get_best_candidates` picks the best result among the random restarts.\n",
    "\n",
    "Under the hood, `gen_candidates_torch` uses a convergence criterion based on exponential moving averages of the loss."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "434e0156-94e1-4aca-a0f6-00f6ffb3a2b0",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "3e15f8c7-ee79-4c95-8583-3bd376229a39",
    "executionStartTime": 1668649208095,
    "executionStopTime": 1668649208252
   },
   "source": [
    "from botorch.sampling.stochastic_samplers import StochasticSampler\n",
    "from botorch.generation import get_best_candidates, gen_candidates_torch\n",
    "from botorch.optim import gen_batch_initial_conditions\n",
    "\n",
    "resampler = StochasticSampler(sample_shape=torch.Size([512]))\n",
    "MC_EI_resample = qExpectedImprovement(model, best_f=best_value, sampler=resampler)\n",
    "bounds = torch.tensor([[0.0] * 6, [1.0] * 6])\n",
    "\n",
    "batch_initial_conditions = gen_batch_initial_conditions(\n",
    "    acq_function=MC_EI_resample,\n",
    "    bounds=bounds,\n",
    "    q=1,\n",
    "    num_restarts=20,\n",
    "    raw_samples=100,\n",
    ")\n",
    "batch_candidates, batch_acq_values = gen_candidates_torch(\n",
    "    initial_conditions=batch_initial_conditions,\n",
    "    acquisition_function=MC_EI_resample,\n",
    "    lower_bounds=bounds[0],\n",
    "    upper_bounds=bounds[1],\n",
    "    optimizer=torch.optim.Adam,\n",
    "    options={\"maxiter\": 500},\n",
    ")\n",
    "new_point_torch_Adam = get_best_candidates(\n",
    "    batch_candidates=batch_candidates, batch_values=batch_acq_values\n",
    ").detach()"
   ],
   "execution_count": 24,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "81c29b36-c663-47e1-8155-ad034c214f53",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "aac6f703-e046-448a-8abe-1742befb9bf9",
    "executionStartTime": 1668649208304,
    "executionStopTime": 1668649208320
   },
   "source": [
    "new_point_torch_Adam"
   ],
   "execution_count": 25,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([[0.4527, 0.1183, 0.8902, 0.5630, 0.3151, 0.5804]])"
     },
     "metadata": {
      "bento_obj_id": "140510701998384"
     },
     "execution_count": 25
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "17fb0de0-3c5a-414e-9aba-b82710d166c0",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "a13ce358-3ee6-43ad-9e3a-16181a8cdc1e",
    "executionStartTime": 1668649208364,
    "executionStopTime": 1668649208372
   },
   "source": [
    "torch.norm(new_point_torch_Adam - new_point_analytic)"
   ],
   "execution_count": 26,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor(0.0855)"
     },
     "metadata": {
      "bento_obj_id": "140510701610704"
     },
     "execution_count": 26
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "9cce59c6-18f5-4d61-ace0-a0d88cd4b8eb",
    "showInput": false
   },
   "source": [
    "By changing the `optimizer` argument of `gen_candidates_torch`, we can also try `torch.optim.SGD`. Note that without the adaptive step size selection of Adam, basic SGD does a worse job of optimizing unless its optimization parameters are tuned manually."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "dce34cee-fe31-476e-a20d-46237d5861e0",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "47a5a0d7-3281-403d-910a-97054021b811",
    "executionStartTime": 1668649208422,
    "executionStopTime": 1668649208460
   },
   "source": [
    "batch_candidates, batch_acq_values = gen_candidates_torch(\n",
    "    initial_conditions=batch_initial_conditions,\n",
    "    acquisition_function=MC_EI_resample,\n",
    "    lower_bounds=bounds[0],\n",
    "    upper_bounds=bounds[1],\n",
    "    optimizer=torch.optim.SGD,\n",
    "    options={\"maxiter\": 500},\n",
    ")\n",
    "new_point_torch_SGD = get_best_candidates(\n",
    "    batch_candidates=batch_candidates, batch_values=batch_acq_values\n",
    ").detach()"
   ],
   "execution_count": 27,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "350e456d-0d1c-46dc-a618-0fbba9e0a158",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "aa33d42e-c526-4117-88a7-aa3034d82886",
    "executionStartTime": 1668649208505,
    "executionStopTime": 1668649208523
   },
   "source": [
    "new_point_torch_SGD"
   ],
   "execution_count": 28,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([[0.3566, 0.0410, 0.7926, 0.3118, 0.3758, 0.6110]])"
     },
     "metadata": {
      "bento_obj_id": "140510702066640"
     },
     "execution_count": 28
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "originalKey": "e263cfc7-47a0-4b81-ab33-3aa16320c87e",
    "customOutput": null,
    "collapsed": false,
    "requestMsgId": "3c654fc0-ce64-43c7-a8bf-42935257008a",
    "executionStartTime": 1668649208566,
    "executionStopTime": 1668649208574
   },
   "source": [
    "torch.norm(new_point_torch_SGD - new_point_analytic)"
   ],
   "execution_count": 29,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor(0.2928)"
     },
     "metadata": {
      "bento_obj_id": "140510701611584"
     },
     "execution_count": 29
    }
   ]
  }
 ]
}
