{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Hamiltonian Monte Carlo with Energy Conserving Subsampling\n\nThis example illustrates the use of data subsampling in HMC using Energy Conserving Subsampling. Data subsampling\nis applicable when the likelihood factorizes as a product of N terms.\n\n**References:**\n\n1. *Hamiltonian Monte Carlo with energy conserving subsampling*,\n   Dang, K. D., Quiroz, M., Kohn, R., Tran, M.-N., & Villani, M. (2019)\n\n<img src=\"file://../_static/img/examples/hmcecs.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom jax import random\nimport jax.numpy as jnp\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import HIGGS, load_dataset\nfrom numpyro.infer import HMC, HMCECS, MCMC, NUTS, SVI, Trace_ELBO, autoguide\n\n\ndef model(data, obs, subsample_size):\n    \"\"\"Logistic regression whose likelihood is evaluated on a subsample of the plate `N`.\n\n    :param data: feature matrix of shape (n, m).\n    :param obs: observed outcomes for the Bernoulli likelihood, one per row of `data`.\n    :param subsample_size: subsample size handed to the plate (`run_hmc` passes None).\n    \"\"\"\n    n, m = data.shape\n    theta = numpyro.sample(\"theta\", dist.Normal(jnp.zeros(m), 0.5 * jnp.ones(m)))\n    with numpyro.plate(\"N\", n, subsample_size=subsample_size):\n        batch_feats = numpyro.subsample(data, event_dim=1)\n        batch_obs = numpyro.subsample(obs, event_dim=0)\n        numpyro.sample(\n            \"obs\", dist.Bernoulli(logits=theta @ batch_feats.T), obs=batch_obs\n        )\n\n\ndef run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):\n    \"\"\"Fit reference parameters with SVI, then sample with HMCECS using a Taylor proxy.\n\n    :return: tuple of (SVI losses, posterior samples).\n    \"\"\"\n    svi_key, mcmc_key = random.split(hmcecs_key)\n\n    # find reference parameters for second order taylor expansion to estimate likelihood (taylor_proxy)\n    optimizer = numpyro.optim.Adam(step_size=1e-3)\n    guide = autoguide.AutoDelta(model)\n    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)\n    params, losses = svi_result.params, svi_result.losses\n    ref_params = {\"theta\": params[\"theta_auto_loc\"]}\n\n    # taylor proxy estimates log likelihood (ll) by\n    # taylor_expansion(ll, theta_curr) +\n    #     sum_{i in subsample} ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr) around ref_params\n    proxy = HMCECS.taylor_proxy(ref_params)\n\n    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)\n    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)\n\n    mcmc.run(mcmc_key, data, obs, args.subsample_size)\n    mcmc.print_summary()\n    return losses, mcmc.get_samples()\n\n\ndef run_hmc(mcmc_key, args, data, obs, kernel):\n    \"\"\"Run MCMC with `kernel`, passing None as the subsample size, and return the samples.\"\"\"\n    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)\n    mcmc.run(mcmc_key, data, obs, None)\n    mcmc.print_summary()\n    return mcmc.get_samples()\n\n\ndef main(args):\n    \"\"\"Run HMCECS and plain MCMC with the same inner kernel on the chosen data and plot a comparison.\"\"\"\n    assert (\n        11_000_000 >= args.num_datapoints\n    ), \"11,000,000 data points in the Higgs dataset\"\n    # full dataset takes hours for plain hmc!\n    if args.dataset == \"higgs\":\n        _, fetch = load_dataset(\n            HIGGS, shuffle=False, num_datapoints=args.num_datapoints\n        )\n        data, obs = fetch()\n    else:\n        # seed the mock features from --rng_seed so that reruns give the same data\n        mock_rng = np.random.default_rng(args.rng_seed)\n        data, obs = (mock_rng.normal(size=(10, 28)), np.ones(10))\n\n    hmcecs_key, hmc_key = random.split(random.PRNGKey(args.rng_seed))\n\n    # choose inner_kernel\n    if args.inner_kernel == \"hmc\":\n        inner_kernel = HMC(model)\n    else:\n        inner_kernel = NUTS(model)\n\n    start = time.time()\n    losses, hmcecs_samples = run_hmcecs(hmcecs_key, args, data, obs, inner_kernel)\n    hmcecs_runtime = time.time() - start\n\n    start = time.time()\n    hmc_samples = run_hmc(hmc_key, args, data, obs, inner_kernel)\n    hmc_runtime = time.time() - start\n\n    summary_plot(losses, hmc_samples, hmcecs_samples, hmc_runtime, hmcecs_runtime)\n\n\ndef summary_plot(losses, hmc_samples, hmcecs_samples, hmc_runtime, hmcecs_runtime):\n    \"\"\"Plot SVI losses, runtimes and sorted posterior means/variances; save to hmcecs_plot.pdf.\"\"\"\n    fig, ax = plt.subplots(2, 2)\n    ax[0, 0].plot(losses, \"r\")\n    ax[0, 0].set_title(\"SVI losses\")\n    # Trace_ELBO reports the negative ELBO as its loss\n    ax[0, 0].set_ylabel(\"Negative ELBO\")\n\n    # draw the taller bar first so the shorter one stays visible on top of it\n    if hmc_runtime > hmcecs_runtime:\n        ax[0, 1].bar([0], hmc_runtime, label=\"hmc\", color=\"b\")\n        ax[0, 1].bar([0], hmcecs_runtime, label=\"hmcecs\", color=\"r\")\n    else:\n        ax[0, 1].bar([0], hmcecs_runtime, label=\"hmcecs\", color=\"r\")\n        ax[0, 1].bar([0], hmc_runtime, label=\"hmc\", color=\"b\")\n    ax[0, 1].set_title(\"Runtime\")\n    ax[0, 1].set_ylabel(\"Seconds\")\n    ax[0, 1].legend()\n    ax[0, 1].set_xticks([])\n\n    ax[1, 0].plot(jnp.sort(hmc_samples[\"theta\"].mean(0)), \"or\")\n    ax[1, 0].plot(jnp.sort(hmcecs_samples[\"theta\"].mean(0)), \"b\")\n    ax[1, 0].set_title(r\"$\\mathrm{\\mathbb{E}}[\\theta]$\")\n\n    ax[1, 1].plot(jnp.sort(hmc_samples[\"theta\"].var(0)), \"or\")\n    ax[1, 1].plot(jnp.sort(hmcecs_samples[\"theta\"].var(0)), \"b\")\n    ax[1, 1].set_title(r\"Var$[\\theta]$\")\n\n    for a in ax[1, :]:\n        a.set_xticks([])\n\n    fig.tight_layout()\n    fig.savefig(\"hmcecs_plot.pdf\", bbox_inches=\"tight\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        \"Hamiltonian Monte Carlo with Energy Conserving Subsampling\"\n    )\n    parser.add_argument(\"--subsample_size\", type=int, default=1300)\n    parser.add_argument(\"--num_svi_steps\", type=int, default=5000)\n    parser.add_argument(\"--num_blocks\", type=int, default=100)\n    parser.add_argument(\"--num_warmup\", type=int, default=500)\n    parser.add_argument(\"--num_samples\", type=int, default=500)\n    parser.add_argument(\"--num_datapoints\", type=int, default=1_500_000)\n    parser.add_argument(\n        \"--dataset\", type=str, choices=[\"higgs\", \"mock\"], default=\"higgs\"\n    )\n    parser.add_argument(\n        \"--inner_kernel\", type=str, choices=[\"nuts\", \"hmc\"], default=\"nuts\"\n    )\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, choices=[\"cpu\", \"gpu\"])\n    parser.add_argument(\n        \"--rng_seed\", default=37, type=int, help=\"random number generator seed\"\n    )\n\n    # parse_known_args tolerates the extra arguments a Jupyter kernel leaves in sys.argv\n    # (e.g. -f <connection file>), on which parse_args() would exit with an error\n    args, _ = parser.parse_known_args()\n\n    numpyro.set_platform(args.device)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}