{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Baseball Batting Average\n\nOriginal example from Pyro:\nhttps://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py\n\nThis example has been adapted from [1]. It demonstrates how to do Bayesian inference using\nvarious MCMC kernels in NumPyro (HMC, NUTS, SA), and the use of some common inference utilities.\n\nAs in the Stan tutorial, this uses the small baseball dataset of Efron and Morris [2]\nto estimate players' batting averages, i.e. the fraction of times a player got a\nbase hit out of the number of times they went up at bat.\n\nThe dataset separates the initial 45 at-bats statistics from the remaining season.\nWe use the hits data from the initial 45 at-bats to estimate the batting average\nfor each player. We then use the remaining season's data to validate the predictions\nfrom our models.\n\nThree models are evaluated:\n\n- Complete pooling model: The success probability of scoring a hit is shared\n  amongst all players.\n- No pooling model: Each individual player's success probability is distinct and\n  there is no data sharing amongst players.\n- Partial pooling model: A hierarchical model with partial data sharing. Two\n  parameterizations are fitted: Beta-distributed success probabilities and\n  normally distributed logits.\n\nWe recommend Radford Neal's tutorial on HMC ([3]) to users who would like to get a\nmore comprehensive understanding of HMC and its variants, and [4] for details on\nthe No U-Turn Sampler, which provides an efficient and automated way (i.e. limited\nhyper-parameters) of running HMC on different problems.\n\nNote that the Sample Adaptive (SA) kernel, which is implemented based on [5],\nrequires large `num_warmup` and `num_samples` (e.g. 15,000 and 300,000), so\nit is better to disable the progress bar to avoid dispatching overhead.\n\n**References:**\n\n1. Carpenter B. (2016), [\"Hierarchical Partial Pooling for Repeated Binary Trials\"](http://mc-stan.org/users/documentation/case-studies/pool-binary-trials.html/).\n2. Efron B., Morris C. (1975), \"Data analysis using Stein's estimator and its\n   generalizations\", J. Amer. Statist. Assoc., 70, 311-319.\n3. Neal, R. (2012), \"MCMC using Hamiltonian Dynamics\",\n   (https://arxiv.org/pdf/1206.1901.pdf)\n4. Hoffman, M. D. and Gelman, A. (2014), \"The No-U-turn sampler: Adaptively setting\n   path lengths in Hamiltonian Monte Carlo\", (https://arxiv.org/abs/1111.4246)\n5. Michael Zhu (2019), \"Sample Adaptive MCMC\",\n   (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nimport jax.numpy as jnp\nimport jax.random as random\nfrom jax.scipy.special import logsumexp\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import BASEBALL, load_dataset\nfrom numpyro.infer import HMC, MCMC, NUTS, SA, Predictive, log_likelihood\n\n\ndef fully_pooled(at_bats, hits=None):\n    r\"\"\"\n    Number of hits in $K$ at bats for each player has a Binomial\n    distribution with a common probability of success, $\\phi$.\n\n    :param (jnp.ndarray) at_bats: Number of at bats for each player.\n    :param (jnp.ndarray) hits: Number of hits for the given at bats.\n    :return: Number of hits predicted by the model.\n    \"\"\"\n    phi_prior = dist.Uniform(0, 1)\n    phi = numpyro.sample(\"phi\", phi_prior)\n    num_players = at_bats.shape[0]\n    with numpyro.plate(\"num_players\", num_players):\n        return numpyro.sample(\"obs\", dist.Binomial(at_bats, probs=phi), obs=hits)\n\n\ndef not_pooled(at_bats, hits=None):\n    r\"\"\"\n    Number of hits in $K$ at bats for each player has a Binomial\n    distribution with independent probability of success, $\\phi_i$.\n\n    :param (jnp.ndarray) at_bats: Number of at bats for each player.\n    :param (jnp.ndarray) hits: Number of hits for the given at bats.\n    :return: Number of hits predicted by the model.\n    \"\"\"\n    num_players = at_bats.shape[0]\n    with numpyro.plate(\"num_players\", num_players):\n        phi_prior = dist.Uniform(0, 1)\n        phi = numpyro.sample(\"phi\", phi_prior)\n        return numpyro.sample(\"obs\", dist.Binomial(at_bats, probs=phi), obs=hits)\n\n\ndef partially_pooled(at_bats, hits=None):\n    r\"\"\"\n    Number of hits has a Binomial distribution with independent\n    probability of success, $\\phi_i$. Each $\\phi_i$ follows a Beta\n    distribution with concentration parameters $c_1$ and $c_2$, where\n    $c_1 = m * \\kappa$, $c_2 = (1 - m) * \\kappa$, $m \\sim Uniform(0, 1)$,\n    and $\\kappa \\sim Pareto(1, 1.5)$.\n\n    :param (jnp.ndarray) at_bats: Number of at bats for each player.\n    :param (jnp.ndarray) hits: Number of hits for the given at bats.\n    :return: Number of hits predicted by the model.\n    \"\"\"\n    m = numpyro.sample(\"m\", dist.Uniform(0, 1))\n    kappa = numpyro.sample(\"kappa\", dist.Pareto(1, 1.5))\n    num_players = at_bats.shape[0]\n    with numpyro.plate(\"num_players\", num_players):\n        phi_prior = dist.Beta(m * kappa, (1 - m) * kappa)\n        phi = numpyro.sample(\"phi\", phi_prior)\n        return numpyro.sample(\"obs\", dist.Binomial(at_bats, probs=phi), obs=hits)\n\n\ndef partially_pooled_with_logit(at_bats, hits=None):\n    r\"\"\"\n    Number of hits has a Binomial distribution with a logit link function.\n    The logits $\\alpha$ for each player are normally distributed with the\n    mean and scale parameters sharing a common prior.\n\n    :param (jnp.ndarray) at_bats: Number of at bats for each player.\n    :param (jnp.ndarray) hits: Number of hits for the given at bats.\n    :return: Number of hits predicted by the model.\n    \"\"\"\n    loc = numpyro.sample(\"loc\", dist.Normal(-1, 1))\n    scale = numpyro.sample(\"scale\", dist.HalfCauchy(1))\n    num_players = at_bats.shape[0]\n    with numpyro.plate(\"num_players\", num_players):\n        alpha = numpyro.sample(\"alpha\", dist.Normal(loc, scale))\n        return numpyro.sample(\"obs\", dist.Binomial(at_bats, logits=alpha), obs=hits)\n\n\ndef run_inference(model, at_bats, hits, rng_key, args):\n    \"\"\"\n    Draw posterior samples for ``model`` with the MCMC kernel named by ``args.algo``.\n\n    :param model: Model function called with ``(at_bats, hits)``.\n    :param at_bats: Number of at bats for each player.\n    :param hits: Number of hits for the given at bats.\n    :param rng_key: Random key passed to ``MCMC.run``.\n    :param args: Namespace providing ``algo``, ``num_warmup``, ``num_samples``,\n        ``num_chains`` and ``disable_progbar``.\n    :return: The result of ``MCMC.get_samples()``.\n    :raises ValueError: If ``args.algo`` is not one of NUTS, HMC or SA.\n    \"\"\"\n    kernels = {\"NUTS\": NUTS, \"HMC\": HMC, \"SA\": SA}\n    if args.algo not in kernels:\n        # Reject unknown kernel names with a clear message.\n        raise ValueError(\n            \"Unknown algo {!r}; expected one of {}.\".format(args.algo, sorted(kernels))\n        )\n    kernel = kernels[args.algo](model)\n    # No progress bar when NUMPYRO_SPHINXBUILD is set or when disabled on request.\n    show_progress = not (\"NUMPYRO_SPHINXBUILD\" in os.environ or args.disable_progbar)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=show_progress,\n    )\n    mcmc.run(rng_key, at_bats, hits)\n    return mcmc.get_samples()\n\n\ndef predict(model, at_bats, hits, z, rng_key, player_names, train=True):\n    \"\"\"\n    Print posterior predictive quantiles of hits per player and, when ``train``\n    is False, the log pointwise predictive density of the observed ``hits``.\n\n    :param model: Model function called with ``(at_bats, hits)``.\n    :param at_bats: Number of at bats for each player.\n    :param hits: Observed number of hits for the given at bats.\n    :param z: Posterior samples passed to ``Predictive`` and ``log_likelihood``.\n    :param rng_key: Random key used to draw the predictions.\n    :param player_names: Names printed in the first column of the table.\n    :param bool train: Selects the TRAIN/TEST header and whether the log\n        predictive density is reported.\n    \"\"\"\n    header = model.__name__ + (\" - TRAIN\" if train else \" - TEST\")\n    predictions = Predictive(model, posterior_samples=z)(rng_key, at_bats)[\"obs\"]\n    print_results(\n        \"=\" * 30 + header + \"=\" * 30, predictions, player_names, at_bats, hits\n    )\n    if not train:\n        post_loglik = log_likelihood(model, z, at_bats, hits)[\"obs\"]\n        # computes expected log predictive density at each data point\n        exp_log_density = logsumexp(post_loglik, axis=0) - jnp.log(\n            jnp.shape(post_loglik)[0]\n        )\n        # reports log predictive density of all test points\n        print(\n            \"\\nLog pointwise predictive density: {:.2f}\\n\".format(exp_log_density.sum())\n        )\n\n\ndef print_results(header, preds, player_names, at_bats, hits):\n    \"\"\"\n    Print a table with at-bats, actual hits and the 25%/50%/75% quantiles of\n    ``preds`` (taken over its first axis) for each player.\n\n    :param header: Title line printed above the table.\n    :param preds: Predicted hits; the second axis is indexed by player.\n    :param player_names: Names printed in the first column.\n    :param at_bats: Number of at bats for each player.\n    :param hits: Actual number of hits for each player.\n    \"\"\"\n    columns = [\"\", \"At-bats\", \"ActualHits\", \"Pred(p25)\", \"Pred(p50)\", \"Pred(p75)\"]\n    header_format = \"{:>20} {:>10} {:>10} {:>10} {:>10} {:>10}\"\n    row_format = \"{:>20} {:>10.0f} {:>10.0f} {:>10.2f} {:>10.2f} {:>10.2f}\"\n    quantiles = jnp.quantile(preds, jnp.array([0.25, 0.5, 0.75]), axis=0)\n    print(\"\\n\", header, \"\\n\")\n    print(header_format.format(*columns))\n    for i, p in enumerate(player_names):\n        print(row_format.format(p, at_bats[i], hits[i], *quantiles[:, i]), \"\\n\")\n\n\ndef main(args):\n    \"\"\"\n    Fit each of the four models on the train split and print predictions for\n    both the train and the test split.\n\n    :param args: Parsed command-line options forwarded to ``run_inference``.\n    \"\"\"\n    _, fetch_train = load_dataset(BASEBALL, split=\"train\", shuffle=False)\n    train, player_names = fetch_train()\n    _, fetch_test = load_dataset(BASEBALL, split=\"test\", shuffle=False)\n    test, _ = fetch_test()\n    at_bats, hits = train[:, 0], train[:, 1]\n    season_at_bats, season_hits = test[:, 0], test[:, 1]\n    for i, model in enumerate(\n        (fully_pooled, not_pooled, partially_pooled, partially_pooled_with_logit)\n    ):\n        rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))\n        zs = run_inference(model, at_bats, hits, rng_key, args)\n        predict(model, at_bats, hits, zs, rng_key_predict, player_names)\n        predict(\n            model,\n            season_at_bats,\n            season_hits,\n            zs,\n            rng_key_predict,\n            player_names,\n            train=False,\n        )\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Baseball batting average using MCMC\")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=3000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1500, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\n        \"--algo\", default=\"NUTS\", type=str, help='whether to run \"HMC\", \"NUTS\", or \"SA\"'\n    )\n    parser.add_argument(\n        \"-dp\",\n        \"--disable-progbar\",\n        action=\"store_true\",\n        default=False,\n        help=\"whether to disable progress bar\",\n    )\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    # A Jupyter kernel is launched with its own argv (e.g. ``-f <connection file>``),\n    # which parse_args() would reject; unrecognised arguments are ignored instead.\n    args, _ = parser.parse_known_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}