
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/bnn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_bnn.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_bnn.py:


Example: Bayesian Neural Network
================================

We demonstrate how to use NUTS to do inference on a simple (small)
Bayesian neural network with two hidden layers.

.. image:: ../_static/img/examples/bnn.png
    :align: center

.. GENERATED FROM PYTHON SOURCE LINES 14-177

.. code-block:: default


    import argparse
    import os
    import time

    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    from jax import vmap
    import jax.numpy as jnp
    import jax.random as random

    import numpyro
    from numpyro import handlers
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    matplotlib.use("Agg")  # noqa: E402


    # the non-linearity we use in our neural network
    def nonlin(x):
        """Apply the network's element-wise activation (hyperbolic tangent) to ``x``."""
        activated = jnp.tanh(x)
        return activated


    # a two-layer bayesian neural network with computational flow
    # given by D_X => D_H => D_H => D_Y where D_H is the number of
    # hidden units. (note we indicate tensor dimensions in the comments)
    def model(X, Y, D_H, D_Y=1):
        """Bayesian neural network with two hidden layers of width ``D_H``.

        ``X`` is the ``(N, D_X)`` input matrix. ``Y`` is either the ``(N, D_Y)``
        targets, or ``None`` so that the "Y" site is sampled instead of observed.
        All weight matrices get independent unit-normal priors and the
        observation precision gets a ``Gamma(3, 1)`` prior.
        """
        N, D_X = X.shape

        def unit_normal(*shape):
            # independent standard-normal prior over a weight matrix of this shape
            return dist.Normal(jnp.zeros(shape), jnp.ones(shape))

        # input -> first hidden layer
        w1 = numpyro.sample("w1", unit_normal(D_X, D_H))
        assert w1.shape == (D_X, D_H)
        pre_act1 = jnp.matmul(X, w1)
        z1 = nonlin(pre_act1)  # first layer of activations
        assert z1.shape == (N, D_H)

        # first hidden layer -> second hidden layer
        w2 = numpyro.sample("w2", unit_normal(D_H, D_H))
        assert w2.shape == (D_H, D_H)
        pre_act2 = jnp.matmul(z1, w2)
        z2 = nonlin(pre_act2)  # second layer of activations
        assert z2.shape == (N, D_H)

        # second hidden layer -> linear output (no non-linearity here)
        w3 = numpyro.sample("w3", unit_normal(D_H, D_Y))
        assert w3.shape == (D_H, D_Y)
        z3 = jnp.matmul(z2, w3)  # output of the neural network
        assert z3.shape == (N, D_Y)

        # when targets are supplied they must line up with the network output
        if Y is not None:
            assert z3.shape == Y.shape

        # prior on the observation noise, parameterised through its precision
        prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
        sigma_obs = 1.0 / jnp.sqrt(prec_obs)

        # likelihood: one plate entry per data point
        with numpyro.plate("data", N):
            # to_event(1) folds the trailing D_Y axis into the event shape,
            # since each observation is a length-D_Y vector
            likelihood = dist.Normal(z3, sigma_obs).to_event(1)
            numpyro.sample("Y", likelihood, obs=Y)


    # helper function for HMC inference
    def run_inference(model, args, rng_key, X, Y, D_H):
        """Run NUTS on ``model`` and return the collected posterior samples.

        ``args`` supplies ``num_warmup``, ``num_samples`` and ``num_chains``.
        ``rng_key``, ``X``, ``Y`` and ``D_H`` are passed through to ``MCMC.run``.
        A summary and the wall-clock time are printed before returning.
        """
        t_begin = time.time()

        # the progress bar is switched off when the docs build sets this variable
        show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ

        sampler = MCMC(
            NUTS(model),
            num_warmup=args.num_warmup,
            num_samples=args.num_samples,
            num_chains=args.num_chains,
            progress_bar=show_progress,
        )
        sampler.run(rng_key, X, Y, D_H)
        sampler.print_summary()

        elapsed = time.time() - t_begin
        print("\nMCMC elapsed time:", elapsed)
        return sampler.get_samples()


    # helper function for prediction
    def predict(model, rng_key, samples, X, D_H):
        """Return the value of the "Y" site for one set of parameter values.

        The model is seeded with ``rng_key`` and the sites named in ``samples``
        are substituted with the given values before it is traced.
        """
        seeded = handlers.seed(model, rng_key)
        conditioned = handlers.substitute(seeded, samples)
        # Y=None means the "Y" site is sampled by the model rather than observed
        traced = handlers.trace(conditioned).get_trace(X=X, Y=None, D_H=D_H)
        return traced["Y"]["value"]


    # create artificial regression dataset
    def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
        """Create an artificial 1-d regression dataset.

        Inputs are polynomial features: column ``j`` of ``X`` holds the ``j``-th
        power of ``N`` evenly spaced points on ``[-1, 1]``. Because column 1 (the
        raw input) is used to build the targets, ``D_X`` must be at least 2.

        Targets are a random linear function of the features plus a smooth
        non-linear term and Gaussian noise of scale ``sigma_obs``. They are then
        standardised to zero mean and unit standard deviation.

        Returns ``(X, Y, X_test)`` with shapes ``(N, D_X)``, ``(N, 1)`` and
        ``(N_test, D_X)``. ``X_test`` uses the same features on ``[-1.3, 1.3]``.
        """
        D_Y = 1  # create 1d outputs

        # A private, fixed-seed generator keeps the dataset reproducible without
        # reseeding NumPy's process-wide RNG. Calling np.random.seed(0) here would
        # silently reset the random state of every other user of np.random.
        rng = np.random.RandomState(0)

        def poly_features(points):
            # column j holds points ** j for j = 0 .. D_X - 1
            return jnp.power(points[:, np.newaxis], jnp.arange(D_X))

        X = poly_features(jnp.linspace(-1, 1, N))

        # the draw order (weights first, then noise) fixes the generated values
        W = 0.5 * rng.randn(D_X)
        Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
        Y = Y + sigma_obs * rng.randn(N)
        Y = Y[:, np.newaxis]

        # standardise the targets
        Y = Y - jnp.mean(Y)
        Y = Y / jnp.std(Y)

        assert X.shape == (N, D_X)
        assert Y.shape == (N, D_Y)

        X_test = poly_features(jnp.linspace(-1.3, 1.3, N_test))

        return X, Y, X_test


    def main(args):
        """Fit the BNN to synthetic data, predict on a test grid and save a plot.

        ``args`` provides ``num_data``, ``num_hidden``, ``num_samples``,
        ``num_chains`` and ``num_warmup``. The figure is written to
        ``bnn_plot.pdf``.
        """
        D_X = 3
        N, D_H = args.num_data, args.num_hidden
        X, Y, X_test = get_data(N=N, D_X=D_X)

        # separate random keys for inference and for prediction
        key_infer, key_predict = random.split(random.PRNGKey(0))
        samples = run_inference(model, args, key_infer, X, Y, D_H)

        # one prediction key per posterior draw (draws from all chains combined)
        n_draws = args.num_samples * args.num_chains
        predict_keys = random.split(key_predict, n_draws)

        def predict_one(draw, key):
            # predict Y at the test inputs for a single posterior draw
            return predict(model, key, draw, X_test, D_H)

        # vectorise over draws, then drop the trailing output axis
        predictions = vmap(predict_one)(samples, predict_keys)[..., 0]

        # summarise: mean prediction plus the 5th and 95th percentiles
        mean_prediction = jnp.mean(predictions, axis=0)
        lower, upper = np.percentile(predictions, [5.0, 95.0], axis=0)

        fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

        # column 1 of the feature matrices is the raw 1-d input
        x_train = X[:, 1]
        x_grid = X_test[:, 1]

        # training data
        ax.plot(x_train, Y[:, 0], "kx")
        # 90% band of the predictions
        ax.fill_between(x_grid, lower, upper, color="lightblue")
        # mean prediction
        ax.plot(x_grid, mean_prediction, "blue", ls="solid", lw=2.0)
        ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

        plt.savefig("bnn_plot.pdf")


    if __name__ == "__main__":
        # this example is pinned to a specific NumPyro release
        assert numpyro.__version__.startswith("0.13.2")

        parser = argparse.ArgumentParser(description="Bayesian neural network example")

        # integer-valued options: (flags, default), registered in this order
        int_options = (
            (("-n", "--num-samples"), 2000),
            (("--num-warmup",), 1000),
            (("--num-chains",), 1),
            (("--num-data",), 100),
            (("--num-hidden",), 5),
        )
        for flags, default in int_options:
            parser.add_argument(*flags, nargs="?", default=default, type=int)
        parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
        args = parser.parse_args()

        # configure the backend and host device count before running
        numpyro.set_platform(args.device)
        numpyro.set_host_device_count(args.num_chains)

        main(args)


.. _sphx_glr_download_examples_bnn.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: bnn.py <bnn.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: bnn.ipynb <bnn.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
