{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
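        "\n# Example: Thompson sampling for Bayesian Optimization with GPs\n\nIn this example we show how to implement Thompson sampling for Bayesian optimization with Gaussian processes (GPs).\nThe first few points are drawn uniformly at random; after that, each step fits a GP to the observations collected so far, draws a single function from the GP posterior, and evaluates the objective at the minimizer of that draw. The objective minimized here is a 1D cut of the Ackley function.\nThe implementation is based on this tutorial: https://gdmarmerola.github.io/ts-for-bayesian-optim/\n\n<img src=\"file://../_static/img/examples/thompson_sampling.png\" align=\"center\">\n"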
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import jax.random as random\n",
        "from jax.scipy import linalg\n",
        "\n",
        "import numpyro\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.infer import SVI, Trace_ELBO\n",
        "from numpyro.infer.autoguide import AutoDelta\n",
        "\n",
        "numpyro.enable_x64()\n",
        "\n",
        "\n",
        "# the function to be minimized. At y=0 to get a 1D cut at the origin\n",
        "def ackley_1d(x, y=0):\n",
        "    \"\"\"Ackley function evaluated along the horizontal line at height `y`.\n",
        "\n",
        "    With the default y=0 this is the 1D cut through the origin; the function\n",
        "    attains its global minimum of 0 at x = y = 0.\n",
        "    \"\"\"\n",
        "    out = (\n",
        "        -20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x**2 + y**2)))\n",
        "        - jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))\n",
        "        + jnp.e\n",
        "        + 20\n",
        "    )\n",
        "    return out\n",
        "\n",
        "\n",
        "# matern kernel with nu = 5/2\n",
        "def matern52_kernel(X, Z, var=1.0, length=0.5, jitter=1.0e-6):\n",
        "    \"\"\"Matern-5/2 kernel matrix between 1D inputs `X` (n points) and `Z` (m points).\n",
        "\n",
        "    Distances are scaled by sqrt(0.5) instead of the textbook sqrt(5); this only\n",
        "    changes the meaning of `length`, which is learned from the data anyway.\n",
        "\n",
        "    A non-zero `jitter` is added to the diagonal, which requires n == m; pass\n",
        "    ``jitter=0.0`` for cross-covariances between different point sets.\n",
        "    \"\"\"\n",
        "    # scaled absolute distances |x - z| / length, shape (n, m)\n",
        "    d = jnp.sqrt(0.5) * jnp.abs(X[:, None] - Z) / length\n",
        "    k = var * (1 + d + (d**2) / 3) * jnp.exp(-d)\n",
        "    if jitter:\n",
        "        # we are assuming a noise free process, but add a small jitter for numerical stability\n",
        "        k += jitter * jnp.eye(X.shape[0])\n",
        "    return k\n",
        "\n",
        "\n",
        "def model(X, Y, kernel=matern52_kernel):\n",
        "    \"\"\"Zero-mean, noise-free GP prior on observations `Y` at inputs `X`.\n",
        "\n",
        "    The kernel variance and length scale get log-normal priors; `kernel` is\n",
        "    called as ``kernel(X, X, var, length)``.\n",
        "    \"\"\"\n",
        "    # set uninformative log-normal priors on our kernel hyperparameters\n",
        "    var = numpyro.sample(\"var\", dist.LogNormal(0.0, 1.0))\n",
        "    length = numpyro.sample(\"length\", dist.LogNormal(0.0, 1.0))\n",
        "\n",
        "    # compute kernel\n",
        "    k = kernel(X, X, var, length)\n",
        "\n",
        "    # sample Y according to the standard gaussian process formula\n",
        "    numpyro.sample(\n",
        "        \"Y\",\n",
        "        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),\n",
        "        obs=Y,\n",
        "    )\n",
        "\n",
        "\n",
        "class GP:\n",
        "    \"\"\"Noise-free GP regressor whose kernel hyperparameters are point-estimated with SVI.\"\"\"\n",
        "\n",
        "    def __init__(self, kernel=matern52_kernel):\n",
        "        self.kernel = kernel\n",
        "        self.kernel_params = None\n",
        "\n",
        "    def fit(self, X, Y, rng_key, n_step):\n",
        "        \"\"\"Fit the kernel hyperparameters to (X, Y) and cache posterior quantities.\n",
        "\n",
        "        Returns the fitted hyperparameters as a dict keyed by sample-site name.\n",
        "        \"\"\"\n",
        "        self.X_train = X\n",
        "\n",
        "        # store moments of training y (to normalize); fall back to a unit scale when\n",
        "        # all targets coincide (e.g. a single observation) to avoid dividing by zero\n",
        "        self.y_mean = jnp.mean(Y)\n",
        "        y_std = jnp.std(Y)\n",
        "        self.y_std = jnp.where(y_std > 0, y_std, 1.0)\n",
        "\n",
        "        # normalize y\n",
        "        Y = (Y - self.y_mean) / self.y_std\n",
        "\n",
        "        # setup optimizer and SVI\n",
        "        optim = numpyro.optim.Adam(step_size=0.005, b1=0.5)\n",
        "\n",
        "        # pass our own kernel so the hyperparameters are fitted for the same\n",
        "        # kernel that predict() uses afterwards\n",
        "        svi = SVI(\n",
        "            model,\n",
        "            guide=AutoDelta(model),\n",
        "            optim=optim,\n",
        "            loss=Trace_ELBO(),\n",
        "            X=X,\n",
        "            Y=Y,\n",
        "            kernel=self.kernel,\n",
        "        )\n",
        "\n",
        "        # svi.run returns an SVIRunResult(params, state, losses); only params are needed\n",
        "        params = svi.run(rng_key, n_step).params\n",
        "\n",
        "        # get kernel parameters from guide with proper names\n",
        "        self.kernel_params = svi.guide.median(params)\n",
        "\n",
        "        # store cholesky factor of prior covariance\n",
        "        self.L = linalg.cho_factor(self.kernel(X, X, **self.kernel_params))\n",
        "\n",
        "        # store inverted prior covariance multiplied by y\n",
        "        self.alpha = linalg.cho_solve(self.L, Y)\n",
        "\n",
        "        return self.kernel_params\n",
        "\n",
        "    # do GP prediction for a given set of hyperparameters. this makes use of the well-known\n",
        "    # formula for gaussian process predictions\n",
        "    def predict(self, X, return_std=False):\n",
        "        \"\"\"Posterior predictive at `X`, in the original (un-normalized) units of y.\n",
        "\n",
        "        Returns ``(mean, cov)`` by default, or ``(mean, std)`` if `return_std` is True.\n",
        "        \"\"\"\n",
        "        # compute kernels between train and test data, etc.\n",
        "        k_pp = self.kernel(X, X, **self.kernel_params)\n",
        "        k_pX = self.kernel(X, self.X_train, **self.kernel_params, jitter=0.0)\n",
        "\n",
        "        # compute posterior covariance\n",
        "        K = k_pp - k_pX @ linalg.cho_solve(self.L, k_pX.T)\n",
        "\n",
        "        # compute posterior mean\n",
        "        mean = k_pX @ self.alpha\n",
        "        y_pred = (mean * self.y_std) + self.y_mean\n",
        "\n",
        "        if return_std:\n",
        "            # round-off can make tiny variances slightly negative; clamp before\n",
        "            # the sqrt so the std is never NaN\n",
        "            y_var = jnp.maximum(jnp.diag(K), 0.0) * self.y_std**2\n",
        "            return y_pred, jnp.sqrt(y_var)\n",
        "        return y_pred, K * self.y_std**2\n",
        "\n",
        "    def sample_y(self, rng_key, X):\n",
        "        \"\"\"Draw one joint sample of the posterior predictive at `X`.\"\"\"\n",
        "        # get posterior mean and covariance\n",
        "        y_mean, y_cov = self.predict(X)\n",
        "        # draw one sample\n",
        "        return jax.random.multivariate_normal(rng_key, mean=y_mean, cov=y_cov)\n",
        "\n",
        "\n",
        "# our TS-GP optimizer\n",
        "class ThompsonSamplingGP:\n",
        "    \"\"\"Thompson-sampling minimizer of a 1D objective with a GP surrogate.\n",
        "\n",
        "    Adapted to numpyro from https://gdmarmerola.github.io/ts-for-bayesian-optim/\n",
        "    \"\"\"\n",
        "\n",
        "    # initialization\n",
        "    def __init__(\n",
        "        self, gp, n_random_draws, objective, x_bounds, grid_resolution=1000, seed=123\n",
        "    ):\n",
        "        # Gaussian Process\n",
        "        self.gp = gp\n",
        "\n",
        "        # number of random samples before starting the optimization\n",
        "        self.n_random_draws = n_random_draws\n",
        "\n",
        "        # the objective is the function we're trying to optimize\n",
        "        self.objective = objective\n",
        "\n",
        "        # the bounds tell us the interval of x we can work\n",
        "        self.bounds = x_bounds\n",
        "\n",
        "        # interval resolution is defined as how many points we will use to\n",
        "        # represent the posterior sample\n",
        "        # we also define the x grid\n",
        "        self.grid_resolution = grid_resolution\n",
        "        self.X_grid = np.linspace(self.bounds[0], self.bounds[1], self.grid_resolution)\n",
        "\n",
        "        # also initializing our design matrix and target variable\n",
        "        self.X = np.array([])\n",
        "        self.y = np.array([])\n",
        "\n",
        "        self.rng_key = random.PRNGKey(seed)\n",
        "\n",
        "    # fitting process\n",
        "    def fit(self, X, y, n_step):\n",
        "        \"\"\"Fit the GP on (X, y) with a fresh subkey and return the fitted GP.\"\"\"\n",
        "        self.rng_key, subkey = random.split(self.rng_key)\n",
        "        # fitting the GP\n",
        "        self.gp.fit(X, y, rng_key=subkey, n_step=n_step)\n",
        "\n",
        "        # return the fitted model\n",
        "        return self.gp\n",
        "\n",
        "    # choose the next Thompson sample\n",
        "    def choose_next_sample(self, n_step=2_000):\n",
        "        \"\"\"Pick the next point, evaluate the objective there and record the observation.\n",
        "\n",
        "        Returns ``(X, y, X_grid, posterior_sample, posterior_mean, posterior_std)``;\n",
        "        the last three are placeholders while still in the random-draw phase.\n",
        "        \"\"\"\n",
        "        # if we do not have enough samples, sample randomly from bounds\n",
        "        if self.X.shape[0] < self.n_random_draws:\n",
        "            self.rng_key, subkey = random.split(self.rng_key)\n",
        "            next_sample = random.uniform(\n",
        "                subkey, minval=self.bounds[0], maxval=self.bounds[1], shape=(1,)\n",
        "            )\n",
        "\n",
        "            # define dummy values for sample, mean and std to avoid errors when returning them;\n",
        "            # np.mean of an empty array warns, so use NaN before the first observation\n",
        "            fill_value = np.mean(self.y) if self.y.size else np.nan\n",
        "            posterior_sample = np.full(self.grid_resolution, fill_value)\n",
        "            posterior_mean = np.full(self.grid_resolution, fill_value)\n",
        "            posterior_std = np.zeros(self.grid_resolution)\n",
        "\n",
        "        # if we do, we fit the GP and choose the next point based on the posterior draw minimum\n",
        "        else:\n",
        "            # 1. Fit the GP to the observations we have\n",
        "            self.gp = self.fit(self.X, self.y, n_step=n_step)\n",
        "\n",
        "            # 2. Draw one sample (a function) from the posterior\n",
        "            self.rng_key, subkey = random.split(self.rng_key)\n",
        "            posterior_sample = self.gp.sample_y(subkey, self.X_grid)\n",
        "\n",
        "            # 3. Choose next point as the optimum of the sample\n",
        "            which_min = np.argmin(posterior_sample)\n",
        "            next_sample = self.X_grid[which_min]\n",
        "\n",
        "            # let us also get the std from the posterior, for visualization purposes\n",
        "            posterior_mean, posterior_std = self.gp.predict(\n",
        "                self.X_grid, return_std=True\n",
        "            )\n",
        "\n",
        "        # let us observe the objective and append this new data to our X and y\n",
        "        next_observation = self.objective(next_sample)\n",
        "        self.X = np.append(self.X, next_sample)\n",
        "        self.y = np.append(self.y, next_observation)\n",
        "\n",
        "        # returning values of interest\n",
        "        return (\n",
        "            self.X,\n",
        "            self.y,\n",
        "            self.X_grid,\n",
        "            posterior_sample,\n",
        "            posterior_mean,\n",
        "            posterior_std,\n",
        "        )\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Run Thompson sampling on the 1D Ackley cut and plot every GP-driven step.\"\"\"\n",
        "    # one panel per GP-driven step; the GP needs at least one random draw to fit on\n",
        "    n_plots = args.num_samples - args.num_random\n",
        "    if args.num_random < 1 or n_plots < 1:\n",
        "        raise ValueError(\n",
        "            \"expected --num-random >= 1 and --num-samples > --num-random, got \"\n",
        "            f\"num_random={args.num_random}, num_samples={args.num_samples}\"\n",
        "        )\n",
        "\n",
        "    gp = GP(kernel=matern52_kernel)\n",
        "    # do inference\n",
        "    thompson = ThompsonSamplingGP(\n",
        "        gp, n_random_draws=args.num_random, objective=ackley_1d, x_bounds=(-4, 4)\n",
        "    )\n",
        "\n",
        "    # squeeze=False keeps `axes` indexable even when there is a single panel\n",
        "    fig, axes = plt.subplots(\n",
        "        n_plots, 1, figsize=(6, 12), sharex=True, sharey=True, squeeze=False\n",
        "    )\n",
        "    axes = axes[:, 0]\n",
        "    for i in range(args.num_samples):\n",
        "        (\n",
        "            X,\n",
        "            y,\n",
        "            X_grid,\n",
        "            posterior_sample,\n",
        "            posterior_mean,\n",
        "            posterior_std,\n",
        "        ) = thompson.choose_next_sample(n_step=args.num_step)\n",
        "\n",
        "        if i >= args.num_random:\n",
        "            ax = axes[i - args.num_random]\n",
        "            # plot training data\n",
        "            ax.scatter(X, y, color=\"blue\", marker=\"o\", label=\"samples\")\n",
        "            ax.axvline(\n",
        "                X_grid[posterior_sample.argmin()],\n",
        "                color=\"blue\",\n",
        "                linestyle=\"--\",\n",
        "                label=\"next sample\",\n",
        "            )\n",
        "            ax.plot(X_grid, ackley_1d(X_grid), color=\"black\", linestyle=\"--\")\n",
        "            ax.plot(\n",
        "                X_grid,\n",
        "                posterior_sample,\n",
        "                color=\"red\",\n",
        "                linestyle=\"-\",\n",
        "                label=\"posterior sample\",\n",
        "            )\n",
        "            # shade the posterior mean +/- 1 standard deviation (~68% band for a Gaussian)\n",
        "            ax.fill_between(\n",
        "                X_grid,\n",
        "                posterior_mean - posterior_std,\n",
        "                posterior_mean + posterior_std,\n",
        "                color=\"red\",\n",
        "                alpha=0.5,\n",
        "            )\n",
        "            ax.set_ylabel(\"Y\")\n",
        "            if i == args.num_samples - 1:\n",
        "                ax.set_xlabel(\"X\")\n",
        "\n",
        "    # attach the legend below the bottom panel\n",
        "    axes[-1].legend(\n",
        "        loc=\"upper center\",\n",
        "        bbox_to_anchor=(0.5, -0.15),\n",
        "        fancybox=True,\n",
        "        shadow=True,\n",
        "        ncol=3,\n",
        "    )\n",
        "\n",
        "    fig.suptitle(\"Thompson sampling\")\n",
        "    fig.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"Thompson sampling example\")\n",
        "    parser.add_argument(\n",
        "        \"--num-random\", nargs=\"?\", default=2, type=int, help=\"number of random draws\"\n",
        "    )\n",
        "    parser.add_argument(\n",
        "        \"--num-samples\",\n",
        "        nargs=\"?\",\n",
        "        default=10,\n",
        "        type=int,\n",
        "        help=\"number of Thompson samples\",\n",
        "    )\n",
        "    parser.add_argument(\n",
        "        \"--num-step\",\n",
        "        nargs=\"?\",\n",
        "        default=2_000,\n",
        "        type=int,\n",
        "        help=\"number of steps for optimization\",\n",
        "    )\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # parse_known_args ignores extra argv entries such as the `-f <connection file>`\n",
        "    # that Jupyter passes to its kernel, so this cell also runs inside a notebook\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}