{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Predator-Prey Model\n\nThis example replicates the great case study [1], which leverages the Lotka-Volterra\nequations [2] to describe the dynamics of Canada lynx (predator) and snowshoe hare\n(prey) populations. We will use the dataset obtained from [3] and run MCMC to\ninfer the parameters of the differential equations governing the dynamics.\n\n**References:**\n\n1. Bob Carpenter (2018), [\"Predator-Prey Population Dynamics: the Lotka-Volterra model in Stan\"](https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html).\n2. https://en.wikipedia.org/wiki/Lotka-Volterra_equations\n3. http://people.whitman.edu/~hundledr/courses/M250F03/M250.html\n\n<img src=\"file://../_static/img/examples/ode.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nimport matplotlib\nimport matplotlib.pyplot as plt\n\nfrom jax.experimental.ode import odeint\nimport jax.numpy as jnp\nfrom jax.random import PRNGKey\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import LYNXHARE, load_dataset\nfrom numpyro.infer import MCMC, NUTS, Predictive\n\nmatplotlib.use(\"Agg\")  # noqa: E402\n\n\ndef dz_dt(z, t, theta):\n    \"\"\"\n    Lotka\u2013Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`\n    describes the interaction of two species.\n    \"\"\"\n    u = z[0]\n    v = z[1]\n    alpha, beta, gamma, delta = (\n        theta[..., 0],\n        theta[..., 1],\n        theta[..., 2],\n        theta[..., 3],\n    )\n    du_dt = (alpha - beta * v) * u\n    dv_dt = (-gamma + delta * u) * v\n    return jnp.stack([du_dt, dv_dt])\n\n\ndef model(N, y=None):\n    \"\"\"\n    :param int N: number of measurement times\n    :param numpy.ndarray y: measured populations with shape (N, 2)\n    \"\"\"\n    # initial population\n    z_init = numpyro.sample(\"z_init\", dist.LogNormal(jnp.log(10), 1).expand([2]))\n    # measurement times\n    ts = jnp.arange(float(N))\n    # parameters alpha, beta, gamma, delta of dz_dt\n    theta = numpyro.sample(\n        \"theta\",\n        dist.TruncatedNormal(\n            low=0.0,\n            loc=jnp.array([1.0, 0.05, 1.0, 0.05]),\n            scale=jnp.array([0.5, 0.05, 0.5, 0.05]),\n        ),\n    )\n    # integrate dz/dt, the result will have shape N x 2\n    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)\n    # measurement errors\n    sigma = numpyro.sample(\"sigma\", dist.LogNormal(-1, 1).expand([2]))\n    # measured populations\n    numpyro.sample(\"y\", dist.LogNormal(jnp.log(z), sigma), obs=y)\n\n\ndef main(args):\n    _, fetch = load_dataset(LYNXHARE, shuffle=False)\n    year, data = fetch()  # data is in hare -> lynx order\n\n    # use dense_mass for better mixing rate\n  
  mcmc = MCMC(\n        NUTS(model, dense_mass=True),\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=False if \"NUMPYRO_SPHINXBUILD\" in os.environ else True,\n    )\n    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)\n    mcmc.print_summary()\n\n    # predict populations\n    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])[\"y\"]\n    mu = jnp.mean(pop_pred, 0)\n    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)\n    plt.figure(figsize=(8, 6), constrained_layout=True)\n    plt.plot(year, data[:, 0], \"ko\", mfc=\"none\", ms=4, label=\"true hare\", alpha=0.67)\n    plt.plot(year, data[:, 1], \"bx\", label=\"true lynx\")\n    plt.plot(year, mu[:, 0], \"k-.\", label=\"pred hare\", lw=1, alpha=0.67)\n    plt.plot(year, mu[:, 1], \"b--\", label=\"pred lynx\")\n    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color=\"k\", alpha=0.2)\n    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color=\"b\", alpha=0.3)\n    plt.gca().set(ylim=(0, 160), xlabel=\"year\", ylabel=\"population (in thousands)\")\n    plt.title(\"Posterior predictive (80% CI) with predator-prey pattern.\")\n    plt.legend()\n\n    plt.savefig(\"ode_plot.pdf\")\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Predator-Prey Model\")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}