{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Bayesian Neural Network\n\nWe demonstrate how to use NUTS to do inference on a simple (small)\nBayesian neural network with two hidden layers.\n\n<img src=\"file://../_static/img/examples/bnn.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import os\n",
        "import time\n",
        "\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "from jax import vmap\n",
        "import jax.numpy as jnp\n",
        "import jax.random as random\n",
        "\n",
        "import numpyro\n",
        "from numpyro import handlers\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.infer import MCMC, NUTS\n",
        "\n",
        "matplotlib.use(\"Agg\")  # noqa: E402\n",
        "\n",
        "\n",
        "# the non-linearity we use in our neural network\n",
        "def nonlin(x):\n",
        "    \"\"\"Apply the element-wise tanh activation used between layers.\"\"\"\n",
        "    return jnp.tanh(x)\n",
        "\n",
        "\n",
        "# a two-layer bayesian neural network with computational flow\n",
        "# given by D_X => D_H => D_H => D_Y where D_H is the number of\n",
        "# hidden units. (note we indicate tensor dimensions in the comments)\n",
        "def model(X, Y, D_H, D_Y=1):\n",
        "    \"\"\"Bayesian neural network with two hidden layers of width D_H.\n",
        "\n",
        "    :param X: input matrix of shape (N, D_X).\n",
        "    :param Y: targets of shape (N, D_Y), or None to have the `Y` site\n",
        "        sampled instead of observed.\n",
        "    :param D_H: number of hidden units in each hidden layer.\n",
        "    :param D_Y: output dimension (defaults to 1).\n",
        "    \"\"\"\n",
        "    N, D_X = X.shape\n",
        "\n",
        "    # sample first layer (we put unit normal priors on all weights)\n",
        "    w1 = numpyro.sample(\"w1\", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))\n",
        "    assert w1.shape == (D_X, D_H)\n",
        "    z1 = nonlin(jnp.matmul(X, w1))  # <= first layer of activations\n",
        "    assert z1.shape == (N, D_H)\n",
        "\n",
        "    # sample second layer\n",
        "    w2 = numpyro.sample(\"w2\", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))\n",
        "    assert w2.shape == (D_H, D_H)\n",
        "    z2 = nonlin(jnp.matmul(z1, w2))  # <= second layer of activations\n",
        "    assert z2.shape == (N, D_H)\n",
        "\n",
        "    # sample final layer of weights and neural network output\n",
        "    w3 = numpyro.sample(\"w3\", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))\n",
        "    assert w3.shape == (D_H, D_Y)\n",
        "    z3 = jnp.matmul(z2, w3)  # <= output of the neural network\n",
        "    assert z3.shape == (N, D_Y)\n",
        "\n",
        "    if Y is not None:\n",
        "        assert z3.shape == Y.shape\n",
        "\n",
        "    # we put a prior on the observation noise\n",
        "    prec_obs = numpyro.sample(\"prec_obs\", dist.Gamma(3.0, 1.0))\n",
        "    sigma_obs = 1.0 / jnp.sqrt(prec_obs)\n",
        "\n",
        "    # observe data\n",
        "    with numpyro.plate(\"data\", N):\n",
        "        # note we use to_event(1) because each observation has shape (1,)\n",
        "        numpyro.sample(\"Y\", dist.Normal(z3, sigma_obs).to_event(1), obs=Y)\n",
        "\n",
        "\n",
        "# helper function for HMC inference\n",
        "def run_inference(model, args, rng_key, X, Y, D_H):\n",
        "    \"\"\"Run NUTS on `model`, print a summary and timing, return the samples.\n",
        "\n",
        "    `args` must provide `num_warmup`, `num_samples` and `num_chains`.\n",
        "    \"\"\"\n",
        "    start = time.time()\n",
        "    kernel = NUTS(model)\n",
        "    mcmc = MCMC(\n",
        "        kernel,\n",
        "        num_warmup=args.num_warmup,\n",
        "        num_samples=args.num_samples,\n",
        "        num_chains=args.num_chains,\n",
        "        # the progress bar is disabled when NUMPYRO_SPHINXBUILD is set\n",
        "        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
        "    )\n",
        "    mcmc.run(rng_key, X, Y, D_H)\n",
        "    mcmc.print_summary()\n",
        "    print(\"\\nMCMC elapsed time:\", time.time() - start)\n",
        "    return mcmc.get_samples()\n",
        "\n",
        "\n",
        "# helper function for prediction\n",
        "def predict(model, rng_key, samples, X, D_H):\n",
        "    \"\"\"Trace `model` once with `samples` substituted; return the `Y` site value.\"\"\"\n",
        "    model = handlers.substitute(handlers.seed(model, rng_key), samples)\n",
        "    # note that Y will be sampled in the model because we pass Y=None here\n",
        "    model_trace = handlers.trace(model).get_trace(X=X, Y=None, D_H=D_H)\n",
        "    return model_trace[\"Y\"][\"value\"]\n",
        "\n",
        "\n",
        "# create artificial regression dataset\n",
        "def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):\n",
        "    \"\"\"Build an artificial 1d regression dataset.\n",
        "\n",
        "    Column k of the feature matrices holds x**k for k in range(D_X).\n",
        "\n",
        "    :return: tuple (X, Y, X_test) with shapes (N, D_X), (N, 1) and\n",
        "        (N_test, D_X); Y is shifted and scaled to zero mean and unit std.\n",
        "    \"\"\"\n",
        "    D_Y = 1  # create 1d outputs\n",
        "    np.random.seed(0)  # NOTE: seeds NumPy's global generator\n",
        "    X = jnp.linspace(-1, 1, N)\n",
        "    X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))\n",
        "    W = 0.5 * np.random.randn(D_X)\n",
        "    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])\n",
        "    Y += sigma_obs * np.random.randn(N)\n",
        "    Y = Y[:, np.newaxis]\n",
        "    Y -= jnp.mean(Y)\n",
        "    Y /= jnp.std(Y)\n",
        "\n",
        "    assert X.shape == (N, D_X)\n",
        "    assert Y.shape == (N, D_Y)\n",
        "\n",
        "    # the test inputs cover [-1.3, 1.3], slightly beyond the training range [-1, 1]\n",
        "    X_test = jnp.linspace(-1.3, 1.3, N_test)\n",
        "    X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))\n",
        "\n",
        "    return X, Y, X_test\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Generate data, run inference, predict on test inputs and save the plot.\n",
        "\n",
        "    The figure is written to `bnn_plot.pdf` in the working directory.\n",
        "    \"\"\"\n",
        "    N, D_X, D_H = args.num_data, 3, args.num_hidden\n",
        "    X, Y, X_test = get_data(N=N, D_X=D_X)\n",
        "\n",
        "    # do inference\n",
        "    rng_key, rng_key_predict = random.split(random.PRNGKey(0))\n",
        "    samples = run_inference(model, args, rng_key, X, Y, D_H)\n",
        "\n",
        "    # predict Y_test at inputs X_test (one PRNG key per posterior sample)\n",
        "    vmap_args = (\n",
        "        samples,\n",
        "        random.split(rng_key_predict, args.num_samples * args.num_chains),\n",
        "    )\n",
        "    predictions = vmap(\n",
        "        lambda samples, rng_key: predict(model, rng_key, samples, X_test, D_H)\n",
        "    )(*vmap_args)\n",
        "    predictions = predictions[..., 0]\n",
        "\n",
        "    # compute the mean prediction and the 5th/95th percentiles (a 90% interval)\n",
        "    mean_prediction = jnp.mean(predictions, axis=0)\n",
        "    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)\n",
        "\n",
        "    # make plots\n",
        "    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n",
        "\n",
        "    # plot training data\n",
        "    ax.plot(X[:, 1], Y[:, 0], \"kx\")\n",
        "    # plot the 90% interval of the predictions\n",
        "    ax.fill_between(\n",
        "        X_test[:, 1], percentiles[0, :], percentiles[1, :], color=\"lightblue\"\n",
        "    )\n",
        "    # plot mean prediction\n",
        "    ax.plot(X_test[:, 1], mean_prediction, \"blue\", ls=\"solid\", lw=2.0)\n",
        "    ax.set(xlabel=\"X\", ylabel=\"Y\", title=\"Mean predictions with 90% CI\")\n",
        "\n",
        "    plt.savefig(\"bnn_plot.pdf\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"Bayesian neural network example\")\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=2000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
        "    parser.add_argument(\"--num-data\", nargs=\"?\", default=100, type=int)\n",
        "    parser.add_argument(\"--num-hidden\", nargs=\"?\", default=5, type=int)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # parse_known_args (rather than parse_args) so that the extra arguments a\n",
        "    # Jupyter kernel is launched with (e.g. `-f <connection file>`) do not abort\n",
        "    # this cell with an \"unrecognized arguments\" error; unknown arguments are\n",
        "    # ignored and the defaults above apply\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "    numpyro.set_host_device_count(args.num_chains)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}