{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: MCMC Methods for Tall Data\n\nThis example illustrates the usage of several MCMC methods that are suitable for tall data, i.e. datasets with a very large number of observations (here, logistic regression on the COVTYPE dataset):\n\n- `algo=\"SA\"` uses the sample adaptive MCMC method in [1].\n- `algo=\"HMCECS\"` uses the energy conserving subsampling method in [2].\n- `algo=\"FlowHMCECS\"` uses a normalizing flow to neutralize the posterior\n  geometry into a Gaussian-like one [3], then draws the posterior samples with\n  HMCECS. Currently, this method gives the best mixing rate among these methods.\n\nThe full-data samplers `algo=\"HMC\"` and `algo=\"NUTS\"` are also available as baselines.\n\n**References:**\n\n1. *Sample Adaptive MCMC*,\n   Michael Zhu (2019)\n2. *Hamiltonian Monte Carlo with energy conserving subsampling*,\n   Dang, K. D., Quiroz, M., Kohn, R., Tran, M.-N., & Villani, M. (2019)\n3. *NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport*,\n   Hoffman, M. et al. (2019)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport sys\nimport time\n\nimport matplotlib.pyplot as plt\n\nfrom jax import random\nimport jax.numpy as jnp\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import COVTYPE, load_dataset\nfrom numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value\nfrom numpyro.infer.autoguide import AutoBNAFNormal\nfrom numpyro.infer.reparam import NeuTraReparam\n\n# algorithms accepted by `benchmark_hmc` and by the `--algo` flag\nALGOS = (\"HMC\", \"NUTS\", \"HMCECS\", \"SA\", \"FlowHMCECS\")\n\n\ndef _load_dataset():\n    \"\"\"Load COVTYPE, standardize the features and binarize the labels.\n\n    Returns:\n        features: standardized feature matrix with an all-ones intercept\n            column appended.\n        labels: boolean array, True for the selected category.\n    \"\"\"\n    _, fetch = load_dataset(COVTYPE, shuffle=False)\n    features, labels = fetch()\n\n    # normalize features and add intercept\n    features = (features - features.mean(0)) / features.std(0)\n    features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])\n\n    # make binary feature\n    # NOTE(review): `specific_category` is a position in the sorted unique\n    # labels, not necessarily a label value. This presumably mirrors the\n    # preprocessing behind the reference MAP estimate in `benchmark_hmc`, so\n    # changing it would also require recomputing `ref_params` -- confirm first.\n    _, counts = jnp.unique(labels, return_counts=True)\n    specific_category = jnp.argmax(counts)\n    labels = labels == specific_category\n\n    num_data = features.shape[0]\n    num_positive = labels.sum()\n    print(\"Data shape:\", features.shape)\n    print(\n        \"Label distribution: {} has label 1, {} has label 0\".format(\n            num_positive, num_data - num_positive\n        )\n    )\n    return features, labels\n\n\ndef model(data, labels, subsample_size=None):\n    \"\"\"Bayesian logistic regression with a standard normal prior on `coefs`.\n\n    Args:\n        data: design matrix; its number of columns sets the number of coefficients.\n        labels: observed binary outcomes, one per row of `data`.\n        subsample_size: if given, the likelihood is evaluated on a subsample of\n            this many rows (used by the HMCECS variants); None uses all rows.\n    \"\"\"\n    dim = data.shape[1]\n    coefs = numpyro.sample(\"coefs\", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))\n    with numpyro.plate(\"N\", data.shape[0], subsample_size=subsample_size) as idx:\n        logits = jnp.dot(data[idx], coefs)\n        return numpyro.sample(\"obs\", dist.Bernoulli(logits=logits), obs=labels[idx])\n\n\ndef benchmark_hmc(args, features, labels):\n    \"\"\"Run the MCMC algorithm selected by `args.algo` and print diagnostics.\n\n    Args:\n        args: parsed CLI arguments; reads `algo`, `num_steps`, `dense_mass`,\n            `num_warmup`, `num_samples` and (optionally) `num_chains`.\n        features: design matrix returned by `_load_dataset`.\n        labels: binary labels returned by `_load_dataset`.\n\n    Raises:\n        ValueError: if `args.algo` is not one of `ALGOS`.\n    \"\"\"\n    rng_key = random.PRNGKey(1)\n    # a MAP estimate taken from the following source\n    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117\n    ref_params = {\n        \"coefs\": jnp.array(\n            [\n                +2.03420663e00,\n                -3.53567265e-02,\n                -1.49223924e-01,\n                -3.07049364e-01,\n                -1.00028366e-01,\n                -1.46827862e-01,\n                -1.64167881e-01,\n                -4.20344204e-01,\n                +9.47479829e-02,\n                -1.12681836e-02,\n                +2.64442056e-01,\n                -1.22087866e-01,\n                -6.00568838e-02,\n                -3.79419506e-01,\n                -1.06668741e-01,\n                -2.97053963e-01,\n                -2.05253899e-01,\n                -4.69537191e-02,\n                -2.78072730e-02,\n                -1.43250525e-01,\n                -6.77954629e-02,\n                -4.34899796e-03,\n                +5.90927452e-02,\n                +7.23133609e-02,\n                +1.38526391e-02,\n                -1.24497898e-01,\n                -1.50733739e-02,\n                -2.68872194e-02,\n                -1.80925727e-02,\n                +3.47936489e-02,\n                +4.03552800e-02,\n                -9.98773426e-03,\n                +6.20188080e-02,\n                +1.15002751e-01,\n                +1.32145107e-01,\n                +2.69109547e-01,\n                +2.45785132e-01,\n                +1.19035013e-01,\n                -2.59744357e-02,\n                +9.94279515e-04,\n                +3.39266285e-02,\n                -1.44057125e-02,\n                -6.95222765e-02,\n                -7.52013028e-02,\n                +1.21171586e-01,\n                +2.29205526e-02,\n                +1.47308692e-01,\n                -8.34354162e-02,\n                -9.34122875e-02,\n                -2.97472421e-02,\n                -3.03937674e-01,\n                -1.70958012e-01,\n                -1.59496680e-01,\n                -1.88516974e-01,\n                -1.20889175e00,\n            ]\n        )\n    }\n    if args.algo == \"HMC\":\n        step_size = jnp.sqrt(0.5 / features.shape[0])\n        trajectory_length = step_size * args.num_steps\n        kernel = HMC(\n            model,\n            step_size=step_size,\n            trajectory_length=trajectory_length,\n            adapt_step_size=False,\n            dense_mass=args.dense_mass,\n        )\n        subsample_size = None\n    elif args.algo == \"NUTS\":\n        kernel = NUTS(model, dense_mass=args.dense_mass)\n        subsample_size = None\n    elif args.algo == \"HMCECS\":\n        subsample_size = 1000\n        inner_kernel = NUTS(\n            model,\n            init_strategy=init_to_value(values=ref_params),\n            dense_mass=args.dense_mass,\n        )\n        # note: with subsample_size=1000 and num_blocks=100, only 10 subsample\n        # indices are updated at each MCMC step, so it takes on the order of\n        # len(features) / 10 MCMC steps to cycle through the whole dataset\n        kernel = HMCECS(\n            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)\n        )\n    elif args.algo == \"SA\":\n        # NB: this kernel requires large num_warmup and num_samples\n        # and running on GPU is much faster than on CPU\n        kernel = SA(\n            model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)\n        )\n        subsample_size = None\n    elif args.algo == \"FlowHMCECS\":\n        subsample_size = 1000\n        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])\n        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())\n        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)\n        params, losses = svi_result.params, svi_result.losses\n        _, ax = plt.subplots()\n        ax.plot(losses)\n        ax.set(title=\"SVI loss (AutoBNAFNormal guide)\", xlabel=\"step\", ylabel=\"loss\")\n        plt.show()\n\n        neutra = NeuTraReparam(guide, params)\n        neutra_model = neutra.reparam(model)\n        # the flattened latent has one entry per regression coefficient\n        neutra_ref_params = {\"auto_shared_latent\": jnp.zeros(features.shape[1])}\n        # no need to adapt mass matrix if the flow does a good job\n        inner_kernel = NUTS(\n            neutra_model,\n            init_strategy=init_to_value(values=neutra_ref_params),\n            adapt_mass_matrix=False,\n        )\n        kernel = HMCECS(\n            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)\n        )\n    else:\n        raise ValueError(\n            \"Invalid algorithm {!r}, expected one of: {}.\".format(\n                args.algo, \", \".join(map(repr, ALGOS))\n            )\n        )\n    # start timing here so that SVI training and the loss plot are not\n    # counted as MCMC time\n    start = time.time()\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=getattr(args, \"num_chains\", 1),\n    )\n    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=(\"accept_prob\",))\n    print(\"Mean accept prob:\", jnp.mean(mcmc.get_extra_fields()[\"accept_prob\"]))\n    mcmc.print_summary(exclude_deterministic=False)\n    print(\"\\nMCMC elapsed time:\", time.time() - start)\n\n\ndef main(args):\n    \"\"\"Load the COVTYPE data and run the benchmark configured by `args`.\"\"\"\n    features, labels = _load_dataset()\n    benchmark_hmc(args, features, labels)\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"parse args\")\n    parser.add_argument(\n        \"-n\", \"--num-samples\", default=1000, type=int, help=\"number of samples\"\n    )\n    parser.add_argument(\n        \"--num-warmup\", default=1000, type=int, help=\"number of warmup steps\"\n    )\n    parser.add_argument(\n        \"--num-steps\", default=10, type=int, help='number of steps (for \"HMC\")'\n    )\n    parser.add_argument(\n        \"--num-chains\", default=1, type=int, help=\"number of MCMC chains\"\n    )\n    parser.add_argument(\n        \"--algo\",\n        default=\"HMCECS\",\n        type=str,\n        choices=ALGOS,\n        help='whether to run \"HMC\", \"NUTS\", \"HMCECS\", \"SA\" or \"FlowHMCECS\"',\n    )\n    parser.add_argument(\"--dense-mass\", action=\"store_true\")\n    parser.add_argument(\"--x64\", action=\"store_true\")\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    # inside a Jupyter kernel, sys.argv carries the kernel's own flags\n    # (typically `-f <connection file>`), so ignore unknown arguments there;\n    # keep strict parsing when run as a script\n    if \"ipykernel\" in sys.modules:\n        args, _ = parser.parse_known_args()\n    else:\n        args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n    if args.x64:\n        numpyro.enable_x64()\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}