{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Neal's Funnel\n\nThis example, which is adapted from [1], illustrates how to leverage non-centered\nparameterization using the `numpyro.handlers.reparam` handler.\nWe will examine the difference between two types of parameterizations on the\n10-dimensional Neal's funnel distribution. As we will see, HMC has trouble exploring\nthe neck of the funnel (where the scale of `x`, `exp(y / 2)`, becomes tiny) when the\ncentered parameterization is used. In contrast, the non-centered parameterization\navoids this problem.\n\nUsing non-centered parameterization through `numpyro.infer.reparam.LocScaleReparam`\nor `numpyro.infer.reparam.TransformReparam` in NumPyro has the same effect as\nthe automatic reparameterisation technique introduced in [2].\n\n**References:**\n\n1. *Stan User's Guide*, https://mc-stan.org/docs/2_19/stan-users-guide/reparameterization-section.html\n2. Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), \"Automatic\n   Reparameterisation of Probabilistic Programs\", (https://arxiv.org/abs/1906.03028)\n\n<img src=\"file://../_static/img/examples/funnel.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import os\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpyro\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.handlers import reparam\n",
        "from numpyro.infer import MCMC, NUTS, Predictive\n",
        "from numpyro.infer.reparam import LocScaleReparam\n",
        "\n",
        "\n",
        "def model(dim=10):\n",
        "    \"\"\"Neal's funnel with ``dim`` total dimensions.\n",
        "\n",
        "    ``y`` controls the scale of the ``dim - 1`` coordinates ``x``: their\n",
        "    standard deviation is ``exp(y / 2)``, so the density narrows into a\n",
        "    \"neck\" as ``y`` decreases.\n",
        "    \"\"\"\n",
        "    y = numpyro.sample(\"y\", dist.Normal(0, 3))\n",
        "    numpyro.sample(\"x\", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))\n",
        "\n",
        "\n",
        "# LocScaleReparam(0) makes ``x`` fully non-centered: NUTS samples a\n",
        "# standardized auxiliary variable and ``x`` is recovered from it\n",
        "# deterministically, so the geometry seen by the sampler no longer\n",
        "# depends on ``y``.\n",
        "reparam_model = reparam(model, config={\"x\": LocScaleReparam(0)})\n",
        "\n",
        "\n",
        "def run_inference(model, args, rng_key):\n",
        "    \"\"\"Sample ``model`` with NUTS and print a posterior summary.\n",
        "\n",
        "    Args:\n",
        "        model: NumPyro model function to sample from.\n",
        "        args: namespace providing ``num_warmup``, ``num_samples`` and\n",
        "            ``num_chains``.\n",
        "        rng_key: JAX PRNG key passed to ``MCMC.run``.\n",
        "\n",
        "    Returns:\n",
        "        The fitted ``MCMC`` object.\n",
        "    \"\"\"\n",
        "    kernel = NUTS(model)\n",
        "    mcmc = MCMC(\n",
        "        kernel,\n",
        "        num_warmup=args.num_warmup,\n",
        "        num_samples=args.num_samples,\n",
        "        num_chains=args.num_chains,\n",
        "        # No progress bar when NUMPYRO_SPHINXBUILD is set (presumably the docs build).\n",
        "        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
        "    )\n",
        "    mcmc.run(rng_key)\n",
        "    # Summarize deterministic sites too, not only the sampled ones.\n",
        "    mcmc.print_summary(exclude_deterministic=False)\n",
        "    return mcmc\n",
        "\n",
        "\n",
        "def _plot_funnel(ax, samples, diverging, title):\n",
        "    \"\"\"Scatter ``x[0]`` against ``y`` on ``ax``, highlighting divergences.\n",
        "\n",
        "    Args:\n",
        "        ax: matplotlib Axes to draw on.\n",
        "        samples: mapping with ``\"x\"`` indexed as ``[draw, dim]`` and ``\"y\"``\n",
        "            indexed as ``[draw]``.\n",
        "        diverging: boolean mask over draws; True marks a divergent transition.\n",
        "        title: Axes title.\n",
        "    \"\"\"\n",
        "    ax.plot(\n",
        "        samples[\"x\"][~diverging, 0],\n",
        "        samples[\"y\"][~diverging],\n",
        "        \"o\",\n",
        "        color=\"darkred\",\n",
        "        alpha=0.3,\n",
        "        label=\"Non-diverging\",\n",
        "    )\n",
        "    ax.plot(\n",
        "        samples[\"x\"][diverging, 0],\n",
        "        samples[\"y\"][diverging],\n",
        "        \"o\",\n",
        "        color=\"lime\",\n",
        "        label=\"Diverging\",\n",
        "    )\n",
        "    ax.set(xlim=(-20, 20), ylim=(-9, 9), ylabel=\"y\", title=title)\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Fit the centered and non-centered funnels and plot them together.\n",
        "\n",
        "    Prints an MCMC summary for each fit and saves the comparison figure to\n",
        "    ``funnel_plot.pdf`` in the current working directory.\n",
        "    \"\"\"\n",
        "    rng_key = random.PRNGKey(0)\n",
        "\n",
        "    # do inference with centered parameterization\n",
        "    print(\n",
        "        \"============================= Centered Parameterization ==============================\"\n",
        "    )\n",
        "    mcmc = run_inference(model, args, rng_key)\n",
        "    samples = mcmc.get_samples()\n",
        "    diverging = mcmc.get_extra_fields()[\"diverging\"]\n",
        "\n",
        "    # do inference with non-centered parameterization, reusing the same key\n",
        "    print(\n",
        "        \"\\n=========================== Non-centered Parameterization ============================\"\n",
        "    )\n",
        "    reparam_mcmc = run_inference(reparam_model, args, rng_key)\n",
        "    reparam_samples = reparam_mcmc.get_samples()\n",
        "    reparam_diverging = reparam_mcmc.get_extra_fields()[\"diverging\"]\n",
        "    # collect deterministic sites: ``x`` is computed from the auxiliary\n",
        "    # non-centered variable, so recover it (together with ``y``) via Predictive\n",
        "    reparam_samples = Predictive(\n",
        "        reparam_model, reparam_samples, return_sites=[\"x\", \"y\"]\n",
        "    )(random.PRNGKey(1))\n",
        "\n",
        "    # make plots (panels share the x-axis, so only the bottom one is labelled)\n",
        "    fig, (ax1, ax2) = plt.subplots(\n",
        "        2, 1, sharex=True, figsize=(8, 8), constrained_layout=True\n",
        "    )\n",
        "    _plot_funnel(\n",
        "        ax1, samples, diverging, \"Funnel samples with centered parameterization\"\n",
        "    )\n",
        "    ax1.legend()\n",
        "    _plot_funnel(\n",
        "        ax2,\n",
        "        reparam_samples,\n",
        "        reparam_diverging,\n",
        "        \"Funnel samples with non-centered parameterization\",\n",
        "    )\n",
        "    ax2.set(xlabel=\"x[0]\")\n",
        "\n",
        "    fig.savefig(\"funnel_plot.pdf\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(\n",
        "        description=\"Non-centered reparameterization example\"\n",
        "    )\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # When this cell runs inside a Jupyter kernel, ``sys.argv`` holds the\n",
        "    # kernel launcher's own flags, which parse_args() rejects with SystemExit.\n",
        "    # parse_known_args() sets such unrecognized arguments aside instead.\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "    numpyro.set_host_device_count(args.num_chains)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}