{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Generalized Linear Mixed Models\n\nThe UCBadmit data is sourced from the study [1] of gender bias in graduate admissions at\nUC Berkeley in Fall 1973:\n\n.. table:: UCBadmit dataset\n   :align: center\n\n   ====== ====== ============== =======\n    dept   male   applications   admit\n   ====== ====== ============== =======\n     0      1         825         512\n     0      0         108          89\n     1      1         560         353\n     1      0          25          17\n     2      1         325         120\n     2      0         593         202\n     3      1         417         138\n     3      0         375         131\n     4      1         191          53\n     4      0         393          94\n     5      1         373          22\n     5      0         341          24\n   ====== ====== ============== =======\n\nThis example replicates the multilevel model `m_glmm5` at [3], which is used to evaluate whether\nthe data contain evidence of gender bias in admissions across departments. This is a form of\nGeneralized Linear Mixed Model for a binomial regression problem, which models\n\n    - varying intercepts across departments,\n    - varying slopes (or the effects of being male) across departments,\n    - correlation between intercepts and slopes,\n\nand uses non-centered parameterization (or whitening).\n\nA more comprehensive explanation of binomial regression and non-centered parameterization can be\nfound in Chapter 10 (Counting and Classification) and Chapter 13 (Adventures in Covariance) of [2].\n\n**References:**\n\n    1. Bickel, P. J., Hammel, E. A., and O'Connell, J. W. (1975), \"Sex Bias in Graduate Admissions:\n       Data from Berkeley\", Science, 187(4175), 398-404.\n    2. McElreath, R. (2018), \"Statistical Rethinking: A Bayesian Course with Examples in R and Stan\",\n       Chapman and Hall/CRC.\n    3. https://github.com/rmcelreath/rethinking/tree/Experimental#multilevel-model-formulas\n\n<img src=\"file://../_static/img/examples/ucbadmit.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom jax import random\nimport jax.numpy as jnp\nfrom jax.scipy.special import expit\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import UCBADMIT, load_dataset\nfrom numpyro.infer import MCMC, NUTS, Predictive\n\n\ndef glmm(dept, male, applications, admit=None):\n    \"\"\"Binomial GLMM with correlated per-department intercepts and slopes.\n\n    :param dept: integer department index of each case.\n    :param male: per-case indicator that multiplies the slope term.\n    :param applications: per-case total count of the Binomial likelihood.\n    :param admit: observed admissions; when ``None``, the admission probabilities\n        are additionally recorded in a ``probs`` site.\n    \"\"\"\n    # population-level intercept and slope\n    v_mu = numpyro.sample(\"v_mu\", dist.Normal(0, jnp.array([4.0, 1.0])))\n\n    sigma = numpyro.sample(\"sigma\", dist.HalfNormal(jnp.ones(2)))\n    L_Rho = numpyro.sample(\"L_Rho\", dist.LKJCholesky(2, concentration=2))\n    scale_tril = sigma[..., jnp.newaxis] * L_Rho\n    # non-centered parameterization\n    num_dept = len(np.unique(dept))\n    z = numpyro.sample(\"z\", dist.Normal(jnp.zeros((num_dept, 2)), 1))\n    v = jnp.dot(scale_tril, z.T).T\n\n    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male\n    if admit is None:\n        # we use a Delta site to record probs for predictive distribution\n        probs = expit(logits)\n        numpyro.sample(\"probs\", dist.Delta(probs), obs=probs)\n    numpyro.sample(\"admit\", dist.Binomial(applications, logits=logits), obs=admit)\n\n\ndef run_inference(dept, male, applications, admit, rng_key, args):\n    \"\"\"Run NUTS on ``glmm`` and return the posterior samples.\n\n    ``args`` must provide ``num_warmup``, ``num_samples`` and ``num_chains``.\n    \"\"\"\n    kernel = NUTS(glmm)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        # the progress bar is disabled when NUMPYRO_SPHINXBUILD is set\n        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n    )\n    mcmc.run(rng_key, dept, male, applications, admit)\n    return mcmc.get_samples()\n\n\ndef print_results(header, preds, dept, male, probs):\n    \"\"\"Print, for each case, the actual rate and the 25/50/75% quantiles of ``preds``.\n\n    Quantiles are taken over axis 0 of ``preds``; one row is printed per entry of ``dept``.\n    \"\"\"\n    columns = [\"Dept\", \"Male\", \"ActualProb\", \"Pred(p25)\", \"Pred(p50)\", \"Pred(p75)\"]\n    header_format = \"{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}\"\n    row_format = \"{:>10.0f} {:>10.0f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}\"\n    quantiles = jnp.quantile(preds, jnp.array([0.25, 0.5, 0.75]), axis=0)\n    print(\"\\n\", header, \"\\n\")\n    print(header_format.format(*columns))\n    for i in range(len(dept)):\n        print(row_format.format(dept[i], male[i], probs[i], *quantiles[:, i]), \"\\n\")\n\n\ndef main(args):\n    \"\"\"Fit ``glmm`` on the UCBadmit train split, print a summary table and save\n    a posterior predictive plot to ``ucbadmit_plot.pdf``.\n    \"\"\"\n    _, fetch_train = load_dataset(UCBADMIT, split=\"train\", shuffle=False)\n    dept, male, applications, admit = fetch_train()\n    rng_key, rng_key_predict = random.split(random.PRNGKey(1))\n    zs = run_inference(dept, male, applications, admit, rng_key, args)\n    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)[\n        \"probs\"\n    ]\n    header = \"=\" * 30 + \"glmm - TRAIN\" + \"=\" * 30\n    print_results(header, pred_probs, dept, male, admit / applications)\n\n    # make plots\n    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n\n    # one x position per case, derived from the data instead of a hard-coded 12\n    cases = range(1, len(dept) + 1)\n    ax.plot(cases, admit / applications, \"o\", ms=7, label=\"actual rate\")\n    ax.errorbar(\n        cases,\n        jnp.mean(pred_probs, 0),\n        jnp.std(pred_probs, 0),\n        fmt=\"o\",\n        c=\"k\",\n        mfc=\"none\",\n        ms=7,\n        elinewidth=1,\n        label=r\"mean $\\pm$ std\",\n    )\n    ax.plot(cases, jnp.percentile(pred_probs, 5, 0), \"k+\")\n    ax.plot(cases, jnp.percentile(pred_probs, 95, 0), \"k+\")\n    ax.set(\n        xlabel=\"cases\",\n        ylabel=\"admit rate\",\n        title=\"Posterior Predictive Check with 90% CI\",\n    )\n    ax.legend()\n\n    plt.savefig(\"ucbadmit_plot.pdf\")\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(\n        description=\"UCBadmit gender discrimination using HMC\"\n    )\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=2000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=500, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    # parse_known_args tolerates the extra arguments a Jupyter kernel is launched\n    # with (e.g. -f <connection file>); parse_args() would exit on them.\n    # Trade-off: unrecognized command-line options are ignored silently.\n    args, _ = parser.parse_known_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}