{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: AR2 process\n\nIn this example we show how to use NumPyro's ``scan`` primitive (built on\n``jax.lax.scan``) to avoid writing a (slow) Python for-loop. In this toy\nexample, with ``--num-data=1000``, the speed-up is almost 3x.\n\nTo demonstrate, we will be implementing an AR2 process.\nThe idea is that we have some time series\n\n\\begin{align}y_0, y_1, ..., y_T\\end{align}\n\nand we seek parameters $c$, $\\alpha_1$, and $\\alpha_2$\nsuch that for each $t$ between $2$ and $T$, we have\n\n\\begin{align}y_t = c + \\alpha_1 y_{t-1} + \\alpha_2 y_{t-2} + \\epsilon_t\\end{align}\n\nwhere $\\epsilon_t \\sim \\mathcal{N}(0, \\sigma^2)$ is an error term with unknown scale $\\sigma$.\n\n<img src=\"file://../_static/img/examples/ar2.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import os\n",
        "import sys\n",
        "import time\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpyro\n",
        "from numpyro.contrib.control_flow import scan\n",
        "import numpyro.distributions as dist\n",
        "\n",
        "\n",
        "def ar2_scan(y):\n",
        "    \"\"\"AR(2) model whose time recursion is written with NumPyro's ``scan``.\n",
        "\n",
        "    ``y[0]`` and ``y[1]`` seed the recursion and ``y[2:]`` is observed through\n",
        "    the ``\"y\"`` site. Using ``scan`` instead of a Python for-loop is what makes\n",
        "    this variant faster than ``ar2_for_loop``.\n",
        "    \"\"\"\n",
        "    alpha_1 = numpyro.sample(\"alpha_1\", dist.Normal(0, 1))\n",
        "    alpha_2 = numpyro.sample(\"alpha_2\", dist.Normal(0, 1))\n",
        "    const = numpyro.sample(\"const\", dist.Normal(0, 1))\n",
        "    sigma = numpyro.sample(\"sigma\", dist.HalfNormal(1))\n",
        "\n",
        "    def transition(carry, _):\n",
        "        # carry = (y_{t-1}, y_{t-2}); returns the shifted pair and no per-step output.\n",
        "        y_prev, y_prev_prev = carry\n",
        "        m_t = const + alpha_1 * y_prev + alpha_2 * y_prev_prev\n",
        "        y_t = numpyro.sample(\"y\", dist.Normal(m_t, sigma))\n",
        "        carry = (y_t, y_prev)\n",
        "        return carry, None\n",
        "\n",
        "    # One scan step per observation from index 2 onward.\n",
        "    timesteps = jnp.arange(y.shape[0] - 2)\n",
        "    init = (y[1], y[0])\n",
        "\n",
        "    with numpyro.handlers.condition(data={\"y\": y[2:]}):\n",
        "        scan(transition, init, timesteps)\n",
        "\n",
        "\n",
        "def ar2_for_loop(y):\n",
        "    \"\"\"AR(2) model written as an unrolled Python for-loop (slower reference).\n",
        "\n",
        "    Creates one observed site ``\"y_<i>\"`` per time step ``i >= 2``.\n",
        "    \"\"\"\n",
        "    alpha_1 = numpyro.sample(\"alpha_1\", dist.Normal(0, 1))\n",
        "    alpha_2 = numpyro.sample(\"alpha_2\", dist.Normal(0, 1))\n",
        "    const = numpyro.sample(\"const\", dist.Normal(0, 1))\n",
        "    sigma = numpyro.sample(\"sigma\", dist.HalfNormal(1))\n",
        "\n",
        "    y_prev = y[1]\n",
        "    y_prev_prev = y[0]\n",
        "\n",
        "    for i in range(2, len(y)):\n",
        "        m_t = const + alpha_1 * y_prev + alpha_2 * y_prev_prev\n",
        "        y_t = numpyro.sample(\"y_{}\".format(i), dist.Normal(m_t, sigma), obs=y[i])\n",
        "        y_prev_prev = y_prev\n",
        "        y_prev = y_t\n",
        "\n",
        "\n",
        "def run_inference(model, args, rng_key, y):\n",
        "    \"\"\"Fit ``model`` to ``y`` with NUTS, print a summary and the elapsed time.\n",
        "\n",
        "    ``args`` supplies ``num_warmup``, ``num_samples`` and ``num_chains``.\n",
        "    Returns ``mcmc.get_samples()``.\n",
        "    \"\"\"\n",
        "    start = time.time()\n",
        "    sampler = numpyro.infer.NUTS(model)\n",
        "    mcmc = numpyro.infer.MCMC(\n",
        "        sampler,\n",
        "        num_warmup=args.num_warmup,\n",
        "        num_samples=args.num_samples,\n",
        "        num_chains=args.num_chains,\n",
        "        # NUMPYRO_SPHINXBUILD is presumably set by the docs build; hide the bar there.\n",
        "        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
        "    )\n",
        "    mcmc.run(rng_key, y=y)\n",
        "    mcmc.print_summary()\n",
        "    print(\"\\nMCMC elapsed time:\", time.time() - start)\n",
        "    return mcmc.get_samples()\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Simulate a noisy sine series of length ``args.num_data`` and fit an AR(2) model.\"\"\"\n",
        "    # Generate an artificial dataset. Split the seed so the data noise and the\n",
        "    # MCMC sampler draw from independent PRNG keys (JAX keys must not be reused).\n",
        "    rng_key = jax.random.PRNGKey(0)\n",
        "    data_key, mcmc_key = random.split(rng_key)\n",
        "    num_data = args.num_data\n",
        "    t = jnp.arange(0, num_data)\n",
        "    y = jnp.sin(t) + random.normal(data_key, (num_data,)) * 0.1\n",
        "\n",
        "    # Do inference.\n",
        "    if args.unroll_loop:\n",
        "        # slower\n",
        "        model = ar2_for_loop\n",
        "    else:\n",
        "        # faster\n",
        "        model = ar2_scan\n",
        "\n",
        "    run_inference(model, args, mcmc_key, y)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"AR2 example\")\n",
        "    parser.add_argument(\"--num-data\", nargs=\"?\", default=142, type=int)\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    parser.add_argument(\n",
        "        \"--unroll-loop\",\n",
        "        action=\"store_true\",\n",
        "        help=\"whether to unroll for-loop (note: slower)\",\n",
        "    )\n",
        "    # Inside a Jupyter kernel sys.argv carries the kernel's own arguments\n",
        "    # (e.g. \"-f <connection file>\"), which argparse would reject; fall back to\n",
        "    # the defaults there.\n",
        "    argv = [] if \"ipykernel\" in sys.modules else None\n",
        "    args = parser.parse_args(argv)\n",
        "    if args.num_data < 3:\n",
        "        parser.error(\"--num-data must be at least 3 to fit an AR(2) model\")\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "    numpyro.set_host_device_count(args.num_chains)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}