{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Neural Transport\n\nThis example illustrates how to use a trained AutoBNAFNormal autoguide to transform a posterior into a\nGaussian-like one. The transform is then used to get a better mixing rate for the NUTS sampler.\n\n**References:**\n\n1. Hoffman, M. et al. (2019), \"NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport\",\n   (https://arxiv.org/abs/1903.03704)\n\n<img src=\"file://../_static/img/examples/neutra.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import os\n",
        "\n",
        "from matplotlib.gridspec import GridSpec\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "from jax.scipy.special import logsumexp\n",
        "\n",
        "import numpyro\n",
        "from numpyro import optim\n",
        "from numpyro.diagnostics import print_summary\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.distributions import constraints\n",
        "from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO\n",
        "from numpyro.infer.autoguide import AutoBNAFNormal\n",
        "from numpyro.infer.reparam import NeuTraReparam\n",
        "\n",
        "\n",
        "class DualMoonDistribution(dist.Distribution):\n",
        "    \"\"\"Two-dimensional target whose density is defined only through ``log_prob``.\"\"\"\n",
        "\n",
        "    support = constraints.real_vector\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__(event_shape=(2,))\n",
        "\n",
        "    def sample(self, key, sample_shape=()):\n",
        "        \"\"\"Return a zero array of shape ``sample_shape + event_shape`` (``key`` is unused).\"\"\"\n",
        "        # it is enough to return an arbitrary sample with correct shape\n",
        "        return jnp.zeros(sample_shape + self.event_shape)\n",
        "\n",
        "    def log_prob(self, x):\n",
        "        \"\"\"Return ``logsumexp(term2) - term1`` evaluated over the last axis of ``x``.\"\"\"\n",
        "        # term1 penalizes the distance of the norm of x from 2 (scale 0.4)\n",
        "        term1 = 0.5 * ((jnp.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2\n",
        "        # term2 holds two Gaussian-style terms in x[..., 0], centered at +2 and -2 (scale 0.6)\n",
        "        term2 = -0.5 * ((x[..., :1] + jnp.array([-2.0, 2.0])) / 0.6) ** 2\n",
        "        pe = term1 - logsumexp(term2, axis=-1)\n",
        "        return -pe\n",
        "\n",
        "\n",
        "def dual_moon_model():\n",
        "    \"\"\"Model with a single sample site ``x`` drawn from ``DualMoonDistribution``.\"\"\"\n",
        "    numpyro.sample(\"x\", DualMoonDistribution())\n",
        "\n",
        "\n",
        "def _run_nuts(model, rng_key, args):\n",
        "    \"\"\"Run NUTS on ``model`` with the settings in ``args`` and print the summary.\n",
        "\n",
        "    Reads ``args.num_warmup``, ``args.num_samples`` and ``args.num_chains``.\n",
        "    Returns the ``MCMC`` object after ``run`` has completed.\n",
        "    \"\"\"\n",
        "    mcmc = MCMC(\n",
        "        NUTS(model),\n",
        "        num_warmup=args.num_warmup,\n",
        "        num_samples=args.num_samples,\n",
        "        num_chains=args.num_chains,\n",
        "        # the progress bar is turned off when NUMPYRO_SPHINXBUILD is set in the environment\n",
        "        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
        "    )\n",
        "    mcmc.run(rng_key)\n",
        "    mcmc.print_summary()\n",
        "    return mcmc\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Compare vanilla NUTS with NeuTra NUTS on ``dual_moon_model`` and plot the results.\n",
        "\n",
        "    ``args`` must provide ``num_warmup``, ``num_samples``, ``num_chains``,\n",
        "    ``hidden_factor`` and ``num_iters``. The figure is saved to ``neutra.pdf``.\n",
        "    \"\"\"\n",
        "    print(\"Start vanilla HMC...\")\n",
        "    mcmc = _run_nuts(dual_moon_model, random.PRNGKey(0), args)\n",
        "    vanilla_samples = mcmc.get_samples()[\"x\"].copy()\n",
        "\n",
        "    guide = AutoBNAFNormal(\n",
        "        dual_moon_model, hidden_factors=[args.hidden_factor, args.hidden_factor]\n",
        "    )\n",
        "    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())\n",
        "\n",
        "    print(\"Start training guide...\")\n",
        "    svi_result = svi.run(random.PRNGKey(1), args.num_iters)\n",
        "    print(\"Finish training guide. Extract samples...\")\n",
        "    guide_samples = guide.sample_posterior(\n",
        "        random.PRNGKey(2), svi_result.params, sample_shape=(args.num_samples,)\n",
        "    )[\"x\"].copy()\n",
        "\n",
        "    print(\"\\nStart NeuTra HMC...\")\n",
        "    neutra = NeuTraReparam(guide, svi_result.params)\n",
        "    neutra_model = neutra.reparam(dual_moon_model)\n",
        "    mcmc = _run_nuts(neutra_model, random.PRNGKey(3), args)\n",
        "    zs = mcmc.get_samples(group_by_chain=True)[\"auto_shared_latent\"]\n",
        "    print(\"Transform samples into unwarped space...\")\n",
        "    samples = neutra.transform_sample(zs)\n",
        "    print_summary(samples)\n",
        "    # merge the chain and draw axes for plotting\n",
        "    zs = zs.reshape(-1, 2)\n",
        "    samples = samples[\"x\"].reshape(-1, 2).copy()\n",
        "\n",
        "    # make plots\n",
        "\n",
        "    # guide samples (for plotting)\n",
        "    guide_base_samples = dist.Normal(jnp.zeros(2), 1.0).sample(\n",
        "        random.PRNGKey(4), (1000,)\n",
        "    )\n",
        "    guide_trans_samples = neutra.transform_sample(guide_base_samples)[\"x\"]\n",
        "\n",
        "    # density of the target on a 100 x 100 grid over [-3, 3] x [-3, 3]\n",
        "    x1 = jnp.linspace(-3, 3, 100)\n",
        "    x2 = jnp.linspace(-3, 3, 100)\n",
        "    X1, X2 = jnp.meshgrid(x1, x2)\n",
        "    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))\n",
        "\n",
        "    fig = plt.figure(figsize=(12, 8), constrained_layout=True)\n",
        "    gs = GridSpec(2, 3, figure=fig)\n",
        "    ax1 = fig.add_subplot(gs[0, 0])\n",
        "    ax2 = fig.add_subplot(gs[1, 0])\n",
        "    ax3 = fig.add_subplot(gs[0, 1])\n",
        "    ax4 = fig.add_subplot(gs[1, 1])\n",
        "    ax5 = fig.add_subplot(gs[0, 2])\n",
        "    ax6 = fig.add_subplot(gs[1, 2])\n",
        "\n",
        "    ax1.plot(svi_result.losses[1000:])\n",
        "    ax1.set_title(\"Autoguide training loss\\n(after 1000 steps)\")\n",
        "\n",
        "    # `levels` is the documented kdeplot parameter; `n_levels` is only an alias for it\n",
        "    ax2.contourf(X1, X2, P, cmap=\"OrRd\")\n",
        "    sns.kdeplot(x=guide_samples[:, 0], y=guide_samples[:, 1], levels=30, ax=ax2)\n",
        "    ax2.set(\n",
        "        xlim=[-3, 3],\n",
        "        ylim=[-3, 3],\n",
        "        xlabel=\"x0\",\n",
        "        ylabel=\"x1\",\n",
        "        title=\"Posterior using\\nAutoBNAFNormal guide\",\n",
        "    )\n",
        "\n",
        "    sns.scatterplot(\n",
        "        x=guide_base_samples[:, 0],\n",
        "        y=guide_base_samples[:, 1],\n",
        "        ax=ax3,\n",
        "        hue=guide_trans_samples[:, 0] < 0.0,\n",
        "    )\n",
        "    ax3.set(\n",
        "        xlim=[-3, 3],\n",
        "        ylim=[-3, 3],\n",
        "        xlabel=\"x0\",\n",
        "        ylabel=\"x1\",\n",
        "        title=\"AutoBNAFNormal base samples\\n(True=left moon; False=right moon)\",\n",
        "    )\n",
        "\n",
        "    ax4.contourf(X1, X2, P, cmap=\"OrRd\")\n",
        "    sns.kdeplot(x=vanilla_samples[:, 0], y=vanilla_samples[:, 1], levels=30, ax=ax4)\n",
        "    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], \"bo-\", alpha=0.5)\n",
        "    ax4.set(\n",
        "        xlim=[-3, 3],\n",
        "        ylim=[-3, 3],\n",
        "        xlabel=\"x0\",\n",
        "        ylabel=\"x1\",\n",
        "        title=\"Posterior using\\nvanilla HMC sampler\",\n",
        "    )\n",
        "\n",
        "    sns.scatterplot(\n",
        "        x=zs[:, 0],\n",
        "        y=zs[:, 1],\n",
        "        ax=ax5,\n",
        "        hue=samples[:, 0] < 0.0,\n",
        "        s=30,\n",
        "        alpha=0.5,\n",
        "        edgecolor=\"none\",\n",
        "    )\n",
        "    ax5.set(\n",
        "        xlim=[-5, 5],\n",
        "        ylim=[-5, 5],\n",
        "        xlabel=\"x0\",\n",
        "        ylabel=\"x1\",\n",
        "        title=\"Samples from the\\nwarped posterior - p(z)\",\n",
        "    )\n",
        "\n",
        "    ax6.contourf(X1, X2, P, cmap=\"OrRd\")\n",
        "    sns.kdeplot(x=samples[:, 0], y=samples[:, 1], levels=30, ax=ax6)\n",
        "    ax6.plot(samples[-50:, 0], samples[-50:, 1], \"bo-\", alpha=0.2)\n",
        "    ax6.set(\n",
        "        xlim=[-3, 3],\n",
        "        ylim=[-3, 3],\n",
        "        xlabel=\"x0\",\n",
        "        ylabel=\"x1\",\n",
        "        title=\"Posterior using\\nNeuTra HMC sampler\",\n",
        "    )\n",
        "\n",
        "    plt.savefig(\"neutra.pdf\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"NeuTra HMC\")\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=4000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
        "    parser.add_argument(\"--hidden-factor\", nargs=\"?\", default=8, type=int)\n",
        "    parser.add_argument(\"--num-iters\", nargs=\"?\", default=10000, type=int)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # parse_known_args() rather than parse_args(): a Jupyter kernel process is started\n",
        "    # with extra command-line arguments (e.g. ``-f <connection file>``) that parse_args()\n",
        "    # would reject with SystemExit. Unrecognized arguments are ignored here.\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "    numpyro.set_host_device_count(args.num_chains)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}