{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Gaussian Process\n\nIn this example we show how to use NUTS to sample from the posterior\nover the hyperparameters of a Gaussian process.\n\n<img src=\"file://../_static/img/examples/gp.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\nimport time\n\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nimport jax\nfrom jax import vmap\nimport jax.numpy as jnp\nimport jax.random as random\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.infer import (\n    MCMC,\n    NUTS,\n    init_to_feasible,\n    init_to_median,\n    init_to_sample,\n    init_to_uniform,\n    init_to_value,\n)\n\nmatplotlib.use(\"Agg\")  # noqa: E402\n\n\n# squared exponential kernel with diagonal noise term\ndef kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):\n    \"\"\"Return the squared exponential kernel matrix between inputs X and Z.\n\n    For the 1-D inputs used in this example, entry (i, j) is\n    var * exp(-0.5 * ((X[i] - Z[j]) / length) ** 2). When include_noise is True,\n    (noise + jitter) times an identity matrix sized from X is added, so that\n    option only makes sense when X and Z have the same length.\n    \"\"\"\n    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)\n    k = var * jnp.exp(-0.5 * deltaXsq)\n    if include_noise:\n        k += (noise + jitter) * jnp.eye(X.shape[0])\n    return k\n\n\ndef model(X, Y):\n    \"\"\"GP regression model over the kernel hyperparameters.\n\n    Places log-normal priors on kernel_var, kernel_noise and kernel_length and\n    observes Y under a zero-mean multivariate normal whose covariance is the\n    kernel matrix of X with itself.\n    \"\"\"\n    # set uninformative log-normal priors on our three kernel hyperparameters\n    var = numpyro.sample(\"kernel_var\", dist.LogNormal(0.0, 10.0))\n    noise = numpyro.sample(\"kernel_noise\", dist.LogNormal(0.0, 10.0))\n    length = numpyro.sample(\"kernel_length\", dist.LogNormal(0.0, 10.0))\n\n    # compute kernel\n    k = kernel(X, X, var, length, noise)\n\n    # sample Y according to the standard gaussian process formula\n    numpyro.sample(\n        \"Y\",\n        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),\n        obs=Y,\n    )\n\n\n# helper function for doing hmc inference\ndef run_inference(model, args, rng_key, X, Y):\n    \"\"\"Run NUTS on model with data (X, Y) and return the posterior samples.\n\n    The initialization strategy and the MCMC settings (num_warmup, num_samples,\n    num_chains, thinning) are read from args. Raises ValueError when\n    args.init_strategy is not one of the supported strategy names.\n    \"\"\"\n    start = time.time()\n    # demonstrate how to use different HMC initialization strategies\n    if args.init_strategy == \"value\":\n        init_strategy = init_to_value(\n            values={\"kernel_var\": 1.0, \"kernel_noise\": 0.05, \"kernel_length\": 0.5}\n        )\n    elif args.init_strategy == \"median\":\n        init_strategy = init_to_median(num_samples=10)\n    elif args.init_strategy == \"feasible\":\n        init_strategy = init_to_feasible()\n    elif args.init_strategy == \"sample\":\n        init_strategy = init_to_sample()\n    elif args.init_strategy == \"uniform\":\n        init_strategy = init_to_uniform(radius=1)\n    else:\n        # fail loudly instead of hitting an UnboundLocalError on the next line\n        raise ValueError(f\"unknown init strategy: {args.init_strategy!r}\")\n    # named nuts_kernel so it does not shadow the module-level kernel() function\n    nuts_kernel = NUTS(model, init_strategy=init_strategy)\n    mcmc = MCMC(\n        nuts_kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        thinning=args.thinning,\n        # hide the progress bar only when NUMPYRO_SPHINXBUILD is set\n        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n    )\n    mcmc.run(rng_key, X, Y)\n    mcmc.print_summary()\n    print(\"\\nMCMC elapsed time:\", time.time() - start)\n    return mcmc.get_samples()\n\n\n# do GP prediction for a given set of hyperparameters. this makes use of the well-known\n# formula for Gaussian process predictions\ndef predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):\n    \"\"\"Predict at X_test for one set of kernel hyperparameters.\n\n    Returns a pair (mean, sample): the predictive mean at X_test, and that mean\n    plus independent per-point Gaussian noise scaled by the square root of the\n    (clipped at zero) diagonal of the predictive covariance. use_cholesky picks\n    a Cholesky solve over an explicit inverse of the training kernel matrix.\n    \"\"\"\n    # compute kernels between train and test data, etc.\n    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)\n    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)\n    k_XX = kernel(X, X, var, length, noise, include_noise=True)\n\n    # since K_xx is symmetric positive-definite, we can use the more efficient and\n    # stable Cholesky decomposition instead of matrix inversion\n    if use_cholesky:\n        K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)\n        K = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))\n        mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, Y))\n    else:\n        K_xx_inv = jnp.linalg.inv(k_XX)\n        K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))\n        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))\n\n    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(\n        rng_key, X_test.shape[:1]\n    )\n\n    # we return both the mean function and a sample from the posterior predictive for the\n    # given set of hyperparameters\n    return mean, mean + sigma_noise\n\n\n# create artificial regression dataset\ndef get_data(N=30, sigma_obs=0.15, N_test=400):\n    \"\"\"Build the artificial dataset and return (X, Y, X_test).\n\n    X holds N evenly spaced points on [-1, 1]; Y is a nonlinear function of X\n    plus Gaussian noise of scale sigma_obs, standardized to zero mean and unit\n    standard deviation; X_test holds N_test evenly spaced points on [-1.3, 1.3].\n    NumPy's global RNG is seeded with 0 so the noise is reproducible.\n    \"\"\"\n    np.random.seed(0)\n    X = jnp.linspace(-1, 1, N)\n    Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)\n    Y += sigma_obs * np.random.randn(N)\n    Y -= jnp.mean(Y)\n    Y /= jnp.std(Y)\n\n    assert X.shape == (N,)\n    assert Y.shape == (N,)\n\n    X_test = jnp.linspace(-1.3, 1.3, N_test)\n\n    return X, Y, X_test\n\n\ndef main(args):\n    \"\"\"Generate data, run inference, predict at the test inputs for every\n    posterior sample, and save the resulting plot to gp_plot.pdf.\"\"\"\n    X, Y, X_test = get_data(N=args.num_data)\n\n    # do inference\n    rng_key, rng_key_predict = random.split(random.PRNGKey(0))\n    samples = run_inference(model, args, rng_key, X, Y)\n\n    # do prediction\n    vmap_args = (\n        random.split(rng_key_predict, samples[\"kernel_var\"].shape[0]),\n        samples[\"kernel_var\"],\n        samples[\"kernel_length\"],\n        samples[\"kernel_noise\"],\n    )\n    means, predictions = vmap(\n        lambda rng_key, var, length, noise: predict(\n            rng_key, X, Y, X_test, var, length, noise, use_cholesky=args.use_cholesky\n        )\n    )(*vmap_args)\n\n    mean_prediction = np.mean(means, axis=0)\n    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)\n\n    # make plots\n    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n\n    # plot training data\n    ax.plot(X, Y, \"kx\")\n    # plot 90% confidence level of predictions\n    ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color=\"lightblue\")\n    # plot mean prediction\n    ax.plot(X_test, mean_prediction, \"blue\", ls=\"solid\", lw=2.0)\n    ax.set(xlabel=\"X\", ylabel=\"Y\", title=\"Mean predictions with 90% CI\")\n\n    plt.savefig(\"gp_plot.pdf\")\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Gaussian Process example\")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--thinning\", nargs=\"?\", default=2, type=int)\n    parser.add_argument(\"--num-data\", nargs=\"?\", default=25, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    parser.add_argument(\n        \"--init-strategy\",\n        default=\"median\",\n        type=str,\n        choices=[\"median\", \"feasible\", \"value\", \"uniform\", \"sample\"],\n    )\n    parser.add_argument(\"--no-cholesky\", dest=\"use_cholesky\", action=\"store_false\")\n    # parse_known_args tolerates the extra command-line arguments a Jupyter kernel\n    # is launched with (e.g. -f <connection file>), which parse_args would reject\n    args, _ = parser.parse_known_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}