{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Zero-Inflated Poisson regression model\n\nIn this example, we model and predict how many fish are caught by visitors to a state park.\nMany groups of visitors catch zero fish, either because they did not fish at all or because\nthey were unlucky. We would like to explicitly model this bimodal behavior (zero versus non-zero)\nand ascertain which variables contribute to each behavior.\n\nWe answer this question by fitting a zero-inflated poisson regression model. We use MAP,\nVI and MCMC as estimation methods. Finally, from the MCMC samples, we identify the variables that\ncontribute to the zero and non-zero components of the zero-inflated poisson likelihood.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\nimport random\n\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom sklearn.metrics import mean_squared_error\n\nimport jax.numpy as jnp\nfrom jax.random import PRNGKey\nimport jax.scipy as jsp\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide\n\nmatplotlib.use(\"Agg\")  # noqa: E402\n\n\ndef set_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n\n\ndef model(X, Y):\n    D_X = X.shape[1]\n    b1 = numpyro.sample(\"b1\", dist.Normal(0.0, 1.0).expand([D_X]).to_event(1))\n    b2 = numpyro.sample(\"b2\", dist.Normal(0.0, 1.0).expand([D_X]).to_event(1))\n\n    q = jsp.special.expit(jnp.dot(X, b1[:, None])).reshape(-1)\n    lam = jnp.exp(jnp.dot(X, b2[:, None]).reshape(-1))\n\n    with numpyro.plate(\"obs\", X.shape[0]):\n        numpyro.sample(\"Y\", dist.ZeroInflatedPoisson(gate=q, rate=lam), obs=Y)\n\n\ndef run_mcmc(model, args, X, Y):\n    kernel = NUTS(model)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=False if \"NUMPYRO_SPHINXBUILD\" in os.environ else True,\n    )\n    mcmc.run(PRNGKey(1), X, Y)\n    mcmc.print_summary()\n    return mcmc.get_samples()\n\n\ndef run_svi(model, guide_family, args, X, Y):\n    if guide_family == \"AutoDelta\":\n        guide = autoguide.AutoDelta(model)\n    elif guide_family == \"AutoDiagonalNormal\":\n        guide = autoguide.AutoDiagonalNormal(model)\n\n    optimizer = numpyro.optim.Adam(0.001)\n    svi = SVI(model, guide, optimizer, Trace_ELBO())\n    svi_results = svi.run(PRNGKey(1), args.maxiter, X=X, Y=Y)\n    params = svi_results.params\n\n    return params, guide\n\n\ndef main(args):\n    set_seed(args.seed)\n\n    # prepare dataset\n    df = pd.read_stata(\"http://www.stata-press.com/data/r11/fish.dta\")\n    df[\"intercept\"] = 1\n    cols = [\"livebait\", \"camper\", \"persons\", \"child\", \"intercept\"]\n\n    mask = np.random.randn(len(df)) < args.train_size\n    df_train = df[mask]\n    df_test = df[~mask]\n    X_train = jnp.asarray(df_train[cols].values)\n    y_train = jnp.asarray(df_train[\"count\"].values)\n    X_test = jnp.asarray(df_test[cols].values)\n    y_test = jnp.asarray(df_test[\"count\"].values)\n\n    print(\"run MAP.\")\n    map_params, map_guide = run_svi(model, \"AutoDelta\", args, X_train, y_train)\n\n    print(\"run VI.\")\n    vi_params, vi_guide = run_svi(model, \"AutoDiagonalNormal\", args, X_train, y_train)\n\n    print(\"run MCMC.\")\n    posterior_samples = run_mcmc(model, args, X_train, y_train)\n\n    # evaluation\n\n    def svi_predict(model, guide, params, args, X):\n        predictive = Predictive(\n            model=model, guide=guide, params=params, num_samples=args.num_samples\n        )\n        predictions = predictive(PRNGKey(1), X=X, Y=None)\n        svi_predictions = jnp.rint(predictions[\"Y\"].mean(0))\n        return svi_predictions\n\n    map_predictions = svi_predict(model, map_guide, map_params, args, X_test)\n    vi_predictions = svi_predict(model, vi_guide, vi_params, args, X_test)\n\n    predictive = Predictive(model, posterior_samples=posterior_samples)\n    predictions = predictive(PRNGKey(1), X=X_test, Y=None)\n    mcmc_predictions = jnp.rint(predictions[\"Y\"].mean(0))\n\n    print(\n        \"MAP RMSE: \",\n        mean_squared_error(y_test.to_py(), map_predictions.to_py(), squared=False),\n    )\n    print(\n        \"VI RMSE: \",\n        mean_squared_error(y_test.to_py(), vi_predictions.to_py(), squared=False),\n    )\n    print(\n        \"MCMC RMSE: \",\n        mean_squared_error(y_test.to_py(), mcmc_predictions.to_py(), squared=False),\n    )\n\n    # make plot\n    fig, axes = plt.subplots(2, 1, figsize=(6, 6), constrained_layout=True)\n\n    def add_fig(var_name, title, ax):\n        ax.set_title(title)\n        ax.violinplot(\n            [posterior_samples[var_name][:, i].to_py() for i in range(len(cols))]\n        )\n        ax.set_xticks(np.arange(1, len(cols) + 1))\n        ax.set_xticklabels(cols, rotation=45, fontsize=10)\n\n    add_fig(\"b1\", \"Coefficients for probability of catching fish\", axes[0])\n    add_fig(\"b2\", \"Coefficients for the number of fish caught\", axes[1])\n\n    plt.savefig(\"zip_fish.png\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Zero-Inflated Poisson Regression\")\n    parser.add_argument(\"--seed\", nargs=\"?\", default=42, type=int)\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=2000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--num-data\", nargs=\"?\", default=100, type=int)\n    parser.add_argument(\"--maxiter\", nargs=\"?\", default=5000, type=int)\n    parser.add_argument(\"--train-size\", nargs=\"?\", default=0.8, type=float)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}