
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/gaussian_shells.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_gaussian_shells.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_gaussian_shells.py:


Example: Nested Sampling for Gaussian Shells
============================================

This example illustrates the usage of the contrib class NestedSampler,
which is a wrapper around the `jaxns` library ([1]) for use with NumPyro models.

Here we replicate the Gaussian Shells demo from [2] and compare the results
against those of the NUTS sampler.

**References:**

    1. jaxns library: https://github.com/Joshuaalbert/jaxns
    2. dynesty's Gaussian Shells demo:
       https://github.com/joshspeagle/dynesty/blob/master/demos/Examples%20--%20Gaussian%20Shells.ipynb

.. GENERATED FROM PYTHON SOURCE LINES 20-139

.. code-block:: default


    import argparse

    import matplotlib.pyplot as plt

    from jax import random
    import jax.numpy as jnp

    import numpyro
    from numpyro.contrib.nested_sampling import NestedSampler
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs


    class GaussianShell(dist.Distribution):
        """Distribution whose log-density depends only on the distance to ``loc``.

        ``log_prob`` evaluates a one-dimensional Gaussian log-density, with mean
        ``radius`` and standard deviation ``width``, at the Euclidean distance
        between ``value`` and ``loc``.
        """

        support = dist.constraints.real_vector

        def __init__(self, loc, radius, width):
            """Store the parameters and derive the shapes from ``loc``.

            The last axis of ``loc`` becomes the event shape; all leading axes
            become the batch shape.
            """
            self.loc = loc
            self.radius = radius
            self.width = width
            batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
            super().__init__(batch_shape=batch_shape, event_shape=event_shape)

        def sample(self, key, sample_shape=()):
            """Return an all-zeros placeholder; ``key`` is not used.

            The result is only a dummy sample to initialize the samplers.
            """
            full_shape = sample_shape + self.shape()
            return jnp.zeros(full_shape)

        def log_prob(self, value):
            """Return the Gaussian log-density of the distance from ``value`` to ``loc``."""
            # Euclidean distance taken over the last (event) axis.
            distance = jnp.linalg.norm(value - self.loc, axis=-1)
            standardized = (distance - self.radius) / self.width
            log_norm = (-0.5) * (jnp.log(2.0 * jnp.pi) + 2.0 * jnp.log(self.width))
            return log_norm - 0.5 * standardized ** 2


    def model(center1, center2, radius, width, enum=False):
        """Model with a binary shell indicator and a uniformly distributed 2-D point.

        Sample sites:

        * ``"z"``: ``Bernoulli(0.5)`` index selecting ``center1`` (0) or ``center2`` (1);
          marked with ``{"enumerate": "parallel"}`` when ``enum`` is true.
        * ``"x"``: two ``Uniform(-6, 6)`` coordinates treated as a single event.
        * ``"shell"``: ``GaussianShell`` around the selected center, observed at ``x``.
        """
        if enum:
            infer_cfg = {"enumerate": "parallel"}
        else:
            infer_cfg = {}
        z = numpyro.sample("z", dist.Bernoulli(0.5), infer=infer_cfg)

        box_prior = dist.Uniform(-6.0, 6.0).expand([2]).to_event(1)
        x = numpyro.sample("x", box_prior)

        # Index the stacked centers with z to pick the active shell.
        chosen_center = jnp.stack([center1, center2])[z]
        shell = GaussianShell(chosen_center, radius, width)
        numpyro.sample("shell", shell, obs=x)


    def run_inference(args, data):
        """Run nested sampling and MCMC on ``model`` and return the ``"x"`` samples.

        :param args: object providing ``enum``, ``num_samples`` and ``num_warmup``.
        :param data: keyword arguments forwarded to ``model``.
        :return: tuple ``(ns_x, mcmc_x)`` holding the ``"x"`` samples from the
            nested sampler and from MCMC, in that order.
        """
        print("=== Performing Nested Sampling ===")
        nested = NestedSampler(model)
        nested.run(random.PRNGKey(0), **data, enum=args.enum)
        nested.print_summary()
        # The nested sampler yields weighted samples, so a random key is
        # required to resample from them.
        nested_draws = nested.get_samples(
            random.PRNGKey(1), num_samples=args.num_samples
        )

        print("\n=== Performing MCMC Sampling ===")
        # Plain NUTS when enumerating; otherwise wrap it in DiscreteHMCGibbs.
        kernel = NUTS(model)
        if not args.enum:
            kernel = DiscreteHMCGibbs(kernel)
        mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
        mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
        mcmc.print_summary()
        mcmc_draws = mcmc.get_samples()

        return nested_draws["x"], mcmc_draws["x"]


    def main(args):
        """Run both samplers on the two-shell problem and save a comparison plot.

        The figure is written to ``gaussian_shells_plot.pdf`` with the MCMC
        samples in the upper panel and the nested-sampling samples in the lower.
        """
        data = {
            "radius": 2.0,
            "width": 0.1,
            "center1": jnp.array([-3.5, 0.0]),
            "center2": jnp.array([3.5, 0.0]),
        }
        ns_samples, mcmc_samples = run_inference(args, data)

        # plotting: two stacked panels sharing the x axis
        _, panels = plt.subplots(
            2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
        )
        upper, lower = panels
        limits = {"xlim": (-6, 6), "ylim": (-2.5, 2.5)}

        upper.plot(mcmc_samples[:, 0], mcmc_samples[:, 1], "ro", alpha=0.2)
        upper.set(**limits, ylabel="x[1]", title="Gaussian-shell samples using NUTS")

        lower.plot(ns_samples[:, 0], ns_samples[:, 1], "ro", alpha=0.2)
        lower.set(
            **limits,
            xlabel="x[0]",
            ylabel="x[1]",
            title="Gaussian-shell samples using Nested Sampler",
        )

        plt.savefig("gaussian_shells_plot.pdf")


    if __name__ == "__main__":
        # The example is pinned to the NumPyro release it was written for.
        assert numpyro.__version__.startswith("0.13.2")

        arg_parser = argparse.ArgumentParser(
            description="Nested sampler for Gaussian shells"
        )
        # Sampling budget.
        arg_parser.add_argument(
            "-n", "--num-samples", nargs="?", default=10000, type=int
        )
        arg_parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
        # Handling of the discrete latent variable.
        arg_parser.add_argument(
            "--enum",
            action="store_true",
            default=False,
            help="whether to enumerate over the discrete latent variable",
        )
        # Compute platform.
        arg_parser.add_argument(
            "--device", default="cpu", type=str, help='use "cpu" or "gpu".'
        )
        args = arg_parser.parse_args()

        # Select the platform before running any inference.
        numpyro.set_platform(args.device)

        main(args)


.. _sphx_glr_download_examples_gaussian_shells.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: gaussian_shells.py <gaussian_shells.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: gaussian_shells.ipynb <gaussian_shells.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
