{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Nested Sampling for Gaussian Shells\n\nThis example illustrates the usage of the contrib class `NestedSampler`,\na wrapper around the `jaxns` library [1] for running nested sampling on NumPyro models.\n\nHere we replicate the Gaussian Shells demo from [2] and compare the results against\nthe NUTS sampler.\n\n**References:**\n\n1. jaxns library: https://github.com/Joshuaalbert/jaxns\n2. dynesty's Gaussian Shells demo:\n   https://github.com/joshspeagle/dynesty/blob/master/demos/Examples%20--%20Gaussian%20Shells.ipynb\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import sys\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpyro\n",
        "from numpyro.contrib.nested_sampling import NestedSampler\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs\n",
        "\n",
        "\n",
        "class GaussianShell(dist.Distribution):\n",
        "    \"\"\"Gaussian-shell density of radius ``radius`` around ``loc``.\n",
        "\n",
        "    ``log_prob`` is a 1-D Gaussian log-density in the distance\n",
        "    ``d = ||value - loc||`` with mean ``radius`` and scale ``width``; its\n",
        "    normalizer is that of the 1-D Gaussian, not of a density over the full\n",
        "    event space. The class is only used as an observed likelihood, so\n",
        "    ``sample`` returns a dummy value instead of a real draw.\n",
        "    \"\"\"\n",
        "\n",
        "    support = dist.constraints.real_vector\n",
        "\n",
        "    def __init__(self, loc, radius, width):\n",
        "        self.loc, self.radius, self.width = loc, radius, width\n",
        "        # The last axis of ``loc`` is the event (coordinate) dimension.\n",
        "        super().__init__(batch_shape=loc.shape[:-1], event_shape=loc.shape[-1:])\n",
        "\n",
        "    def sample(self, key, sample_shape=()):\n",
        "        # A dummy sample (all zeros) used only to initialize the samplers.\n",
        "        return jnp.zeros(sample_shape + self.shape())\n",
        "\n",
        "    def log_prob(self, value):\n",
        "        normalizer = (-0.5) * (jnp.log(2.0 * jnp.pi) + 2.0 * jnp.log(self.width))\n",
        "        d = jnp.linalg.norm(value - self.loc, axis=-1)\n",
        "        return normalizer - 0.5 * ((d - self.radius) / self.width) ** 2\n",
        "\n",
        "\n",
        "def model(center1, center2, radius, width, enum=False):\n",
        "    \"\"\"Two Gaussian shells with a uniform prior on the 2-D point ``x``.\n",
        "\n",
        "    The Bernoulli latent ``z`` selects which shell (centered at ``center1``\n",
        "    or ``center2``) scores ``x``. With ``enum=True``, ``z`` is marked for\n",
        "    parallel enumeration so NUTS can marginalize it out; otherwise it stays\n",
        "    a discrete latent site that must be sampled (see ``run_inference``).\n",
        "    \"\"\"\n",
        "    z = numpyro.sample(\n",
        "        \"z\", dist.Bernoulli(0.5), infer={\"enumerate\": \"parallel\"} if enum else {}\n",
        "    )\n",
        "    x = numpyro.sample(\"x\", dist.Uniform(-6.0, 6.0).expand([2]).to_event(1))\n",
        "    center = jnp.stack([center1, center2])[z]\n",
        "    numpyro.sample(\"shell\", GaussianShell(center, radius, width), obs=x)\n",
        "\n",
        "\n",
        "def run_inference(args, data):\n",
        "    \"\"\"Sample ``x`` with the nested sampler and with MCMC.\n",
        "\n",
        "    ``data`` holds the keyword arguments for ``model``. Returns the pair\n",
        "    ``(ns_x, mcmc_x)`` of samples of site ``x`` from the two methods.\n",
        "    \"\"\"\n",
        "    print(\"=== Performing Nested Sampling ===\")\n",
        "    ns = NestedSampler(model)\n",
        "    ns.run(random.PRNGKey(0), **data, enum=args.enum)\n",
        "    ns.print_summary()\n",
        "    # samples obtained from nested sampler are weighted, so\n",
        "    # we need to provide random key to resample from those weighted samples\n",
        "    ns_samples = ns.get_samples(random.PRNGKey(1), num_samples=args.num_samples)\n",
        "\n",
        "    print(\"\\n=== Performing MCMC Sampling ===\")\n",
        "    if args.enum:\n",
        "        mcmc = MCMC(\n",
        "            NUTS(model), num_warmup=args.num_warmup, num_samples=args.num_samples\n",
        "        )\n",
        "    else:\n",
        "        # ``z`` is not enumerated, so Gibbs updates handle the discrete site\n",
        "        # while NUTS handles the continuous ``x``.\n",
        "        mcmc = MCMC(\n",
        "            DiscreteHMCGibbs(NUTS(model)),\n",
        "            num_warmup=args.num_warmup,\n",
        "            num_samples=args.num_samples,\n",
        "        )\n",
        "    mcmc.run(random.PRNGKey(2), **data, enum=args.enum)\n",
        "    mcmc.print_summary()\n",
        "    mcmc_samples = mcmc.get_samples()\n",
        "\n",
        "    return ns_samples[\"x\"], mcmc_samples[\"x\"]\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Run both samplers on the two-shell problem and save a comparison plot.\"\"\"\n",
        "    data = dict(\n",
        "        radius=2.0,\n",
        "        width=0.1,\n",
        "        center1=jnp.array([-3.5, 0.0]),\n",
        "        center2=jnp.array([3.5, 0.0]),\n",
        "    )\n",
        "    ns_samples, mcmc_samples = run_inference(args, data)\n",
        "\n",
        "    # plotting\n",
        "    fig, (ax1, ax2) = plt.subplots(\n",
        "        2, 1, sharex=True, figsize=(8, 8), constrained_layout=True\n",
        "    )\n",
        "\n",
        "    ax1.plot(mcmc_samples[:, 0], mcmc_samples[:, 1], \"ro\", alpha=0.2)\n",
        "    ax1.set(\n",
        "        xlim=(-6, 6),\n",
        "        ylim=(-2.5, 2.5),\n",
        "        ylabel=\"x[1]\",\n",
        "        title=\"Gaussian-shell samples using NUTS\",\n",
        "    )\n",
        "\n",
        "    ax2.plot(ns_samples[:, 0], ns_samples[:, 1], \"ro\", alpha=0.2)\n",
        "    ax2.set(\n",
        "        xlim=(-6, 6),\n",
        "        ylim=(-2.5, 2.5),\n",
        "        xlabel=\"x[0]\",\n",
        "        ylabel=\"x[1]\",\n",
        "        title=\"Gaussian-shell samples using Nested Sampler\",\n",
        "    )\n",
        "\n",
        "    fig.savefig(\"gaussian_shells_plot.pdf\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"Nested sampler for Gaussian shells\")\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=10000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\n",
        "        \"--enum\",\n",
        "        action=\"store_true\",\n",
        "        default=False,\n",
        "        help=\"whether to enumerate over the discrete latent variable\",\n",
        "    )\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # Inside a Jupyter kernel, sys.argv carries the kernel launcher's own\n",
        "    # flags (e.g. \"-f <connection-file>\"), which parse_args() rejects, so this\n",
        "    # cell would fail on Run All. Tolerate unknown flags there, but keep\n",
        "    # rejecting them when run as a plain script.\n",
        "    args, unknown = parser.parse_known_args()\n",
        "    if unknown and \"ipykernel\" not in sys.modules:\n",
        "        parser.error(\"unrecognized arguments: \" + \" \".join(unknown))\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}