{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Modelling mortality over space and time\n\nThis example is adapted from [1]. The model in the paper estimates death rates for 6791 small\nareas in England for 19 age groups (0, 1-4, 5-9, 10-14, ..., 80-84, 85+ years) from 2002 to 2019.\n\nWhen modelling mortality at high spatial resolutions, the number of deaths in each age group,\nspatial unit and year is small, meaning that death rates calculated from observed data have an\napparent variability which is larger than the true differences in the risk of dying. A Bayesian\nmultilevel modelling framework can overcome small-number problems by sharing information across ages,\nspace and time to obtain smoothed death rates and capture the uncertainty in the estimates.\n\nAs well as a global intercept ($\\alpha_0$) and slope ($\\beta_0$), the model includes\nthe following effects:\n\n- Age ($\\alpha_{2a}$, $\\beta_{2a}$). Each age group has a different intercept\n  and slope, with a random-walk structure over age groups to allow for non-linear age associations.\n- Space ($\\alpha_{1s}$). Each spatial unit has an intercept.\n  The spatial effects are defined by a nested hierarchy of random effects following the\n  administrative hierarchy of local government. The spatial term at the lower-level unit is\n  centred on the spatial term of the higher-level unit (e.g., $\\alpha_{1s_1}$) containing\n  that lower-level unit.\n\nThe model also has a random-walk effect over time ($\\pi_{t}$).\n\nDeath rates are linked to the death and population data using a binomial likelihood. 
The\nfull generative model of death rates is written as\n\n\\begin{align}\n    \\alpha_{1s_1} & \\sim \\text{Normal}(0,\\sigma_{\\alpha_{s_1}}^2) \\\\\n    \\alpha_{1s} & \\sim \\text{Normal}(\\alpha_{1s_1(s_2)},\\sigma_{\\alpha_{s_2}}^2) \\\\\n    \\alpha_{2a} & \\sim \\text{Normal}(\\alpha_{2,a-1},\\sigma_{\\alpha_a}^2), \\quad \\alpha_{2,0} = \\alpha_0 \\\\\n    \\beta_{2a} & \\sim \\text{Normal}(\\beta_{2,a-1},\\sigma_{\\beta_a}^2), \\quad \\beta_{2,0} = \\beta_0 \\\\\n    \\pi_{t} & \\sim \\text{Normal}(\\pi_{t-1},\\sigma_{\\pi}^2), \\quad \\pi_{0} = 0 \\\\\n    \\text{logit}(m_{ast}) & = \\alpha_{1s} + \\alpha_{2a} + \\beta_{2a} t + \\pi_{t}\n\\end{align}\n\nwith the hyperpriors\n\n\\begin{align}\n    \\alpha_0 & \\sim \\text{Normal}(0,10), \\\\\n    \\beta_0 & \\sim \\text{Normal}(0,10), \\\\\n    \\sigma_i & \\sim \\text{Half-Normal}(1)\n\\end{align}\n\nFurther detail about the model terms can be found in [1].\n\nThe NumPyro implementation below uses `numpyro.plate` notation to declare the batch\ndimensions of the age, space and time variables. This allows us to efficiently broadcast arrays\nin the likelihood.\n\nAs written above, the model includes many centred random effects. The NUTS algorithm benefits\nfrom a non-centred reparametrisation to overcome difficult posterior geometries [2]. Rather than\nmanually writing out the non-centred parametrisation, we make use of NumPyro's automatic\nreparametrisation in `numpyro.infer.reparam.LocScaleReparam`.\n\nDeath data at the spatial resolution in [1] are identifiable, so this example uses\nsimulated data. Compared to [1], the simulated data have fewer spatial units and a two-tier (rather than\nthree-tier) spatial hierarchy. 
There are still 19 age groups and 18 years as in the original study.\nThe data here have (event) dimensions of `(19, 113, 18)` (age, space, time).\n\nThe original implementation in nimble is at [3].\n\n**References**\n\n1. Rashid, T., Bennett, J.E. et al. (2021).\n   Life expectancy and risk of death in 6791 communities in England from 2002\n   to 2019: high-resolution spatiotemporal analysis of civil registration data.\n   The Lancet Public Health, 6, e805-e816.\n2. Stan User's Guide. https://mc-stan.org/docs/2_28/stan-users-guide/reparameterization.html\n3. Mortality using Bayesian hierarchical models. https://github.com/theorashid/mortality-statsmodel\n"
      ]
    },
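    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The non-centred reparametrisation referred to in the introduction can be illustrated for the two spatial terms. This is only a sketch: the standard-normal auxiliary variables $z_{s_1}$ and $z_{s}$ are notation introduced here for illustration and do not appear in [1]. Instead of sampling the effects directly, one samples the auxiliary variables and then shifts and scales them:\n\n\\begin{align}\n    z_{s_1} & \\sim \\text{Normal}(0, 1), & \\alpha_{1s_1} & = \\sigma_{\\alpha_{s_1}} z_{s_1} \\\\\n    z_{s} & \\sim \\text{Normal}(0, 1), & \\alpha_{1s} & = \\alpha_{1s_1(s_2)} + \\sigma_{\\alpha_{s_2}} z_{s}\n\\end{align}\n\nThe implied distributions of $\\alpha_{1s_1}$ and $\\alpha_{1s}$ are unchanged, but the sampler explores the $z$ variables, whose prior scale does not depend on the $\\sigma$ hyperparameters. This is understood to be the transformation that `LocScaleReparam(0)` applies to each listed site in the code below.\n"
      ]
    },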
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nimport numpy as np\n\nfrom jax import random\nimport jax.numpy as jnp\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import MORTALITY, load_dataset\nfrom numpyro.infer import MCMC, NUTS\nfrom numpyro.infer.reparam import LocScaleReparam\n\n\ndef create_lookup(s1, s2):\n    \"\"\"\n    Create a map between s1 indices and unique s2 indices\n    \"\"\"\n    lookup = np.column_stack([s1, s2])\n    lookup = np.unique(lookup, axis=0)\n    lookup = lookup[lookup[:, 1].argsort()]\n    return lookup[:, 0]\n\n\nreparam_config = {\n    k: LocScaleReparam(0)\n    for k in [\n        \"alpha_s1\",\n        \"alpha_s2\",\n        \"alpha_age_drift\",\n        \"beta_age_drift\",\n        \"pi_drift\",\n    ]\n}\n\n\n@numpyro.handlers.reparam(config=reparam_config)\ndef model(age, space, time, lookup, population, deaths=None):\n    N_s1 = len(np.unique(lookup))\n    N_s2 = len(np.unique(space))\n    N_age = len(np.unique(age))\n    N_t = len(np.unique(time))\n    N = len(population)\n\n    # plates\n    age_plate = numpyro.plate(\"age_groups\", N_age, dim=-3)\n    space_plate = numpyro.plate(\"space\", N_s2, dim=-2)\n    year_plate = numpyro.plate(\"year\", N_t - 1, dim=-1)\n\n    # hyperparameters\n    sigma_alpha_s1 = numpyro.sample(\"sigma_alpha_s1\", dist.HalfNormal(1.0))\n    sigma_alpha_s2 = numpyro.sample(\"sigma_alpha_s2\", dist.HalfNormal(1.0))\n    sigma_alpha_age = numpyro.sample(\"sigma_alpha_age\", dist.HalfNormal(1.0))\n    sigma_beta_age = numpyro.sample(\"sigma_beta_age\", dist.HalfNormal(1.0))\n    sigma_pi = numpyro.sample(\"sigma_pi\", dist.HalfNormal(1.0))\n\n    # spatial hierarchy\n    with numpyro.plate(\"s1\", N_s1, dim=-2):\n        alpha_s1 = numpyro.sample(\"alpha_s1\", dist.Normal(0, sigma_alpha_s1))\n    with space_plate:\n        alpha_s2 = numpyro.sample(\n            \"alpha_s2\", dist.Normal(alpha_s1[lookup], sigma_alpha_s2)\n        )\n\n    # age\n    with 
age_plate:\n        alpha_age_drift_scale = jnp.pad(\n            jnp.broadcast_to(sigma_alpha_age, N_age - 1),\n            (1, 0),\n            constant_values=10.0,  # pad so first term is alpha0, prior N(0, 10)\n        )[:, jnp.newaxis, jnp.newaxis]\n        alpha_age_drift = numpyro.sample(\n            \"alpha_age_drift\", dist.Normal(0, alpha_age_drift_scale)\n        )\n        alpha_age = jnp.cumsum(alpha_age_drift, -3)\n\n        beta_age_drift_scale = jnp.pad(\n            jnp.broadcast_to(sigma_beta_age, N_age - 1), (1, 0), constant_values=10.0\n        )[:, jnp.newaxis, jnp.newaxis]\n        beta_age_drift = numpyro.sample(\n            \"beta_age_drift\", dist.Normal(0, beta_age_drift_scale)\n        )\n        beta_age = jnp.cumsum(beta_age_drift, -3)\n        beta_age_cum = jnp.outer(beta_age, jnp.arange(N_t))[:, jnp.newaxis, :]\n\n    # random walk over time\n    with year_plate:\n        pi_drift = numpyro.sample(\"pi_drift\", dist.Normal(0, sigma_pi))\n        pi = jnp.pad(jnp.cumsum(pi_drift, -1), (1, 0))\n\n    # likelihood\n    latent_rate = alpha_age + beta_age_cum + alpha_s2 + pi\n    with numpyro.plate(\"N\", N):\n        mu_logit = latent_rate[age, space, time]\n        numpyro.sample(\"deaths\", dist.Binomial(population, logits=mu_logit), obs=deaths)\n\n\ndef print_model_shape(model, age, space, time, lookup, population):\n    with numpyro.handlers.seed(rng_seed=1):\n        trace = numpyro.handlers.trace(model).get_trace(\n            age=age,\n            space=space,\n            time=time,\n            lookup=lookup,\n            population=population,\n        )\n    print(numpyro.util.format_shapes(trace))\n\n\ndef run_inference(model, age, space, time, lookup, population, deaths, rng_key, args):\n    kernel = NUTS(model)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=False if \"NUMPYRO_SPHINXBUILD\" in 
os.environ else True,\n    )\n    mcmc.run(rng_key, age, space, time, lookup, population, deaths)\n    mcmc.print_summary()\n    return mcmc.get_samples()\n\n\ndef main(args):\n    print(\"Fetching simulated data...\")\n    _, fetch = load_dataset(MORTALITY, shuffle=False)\n    a, s1, s2, t, deaths, population = fetch()\n\n    lookup = create_lookup(s1, s2)\n\n    print(\"Model shape:\")\n    print_model_shape(model, a, s2, t, lookup, population)\n\n    print(\"Starting inference...\")\n    rng_key = random.PRNGKey(args.rng_seed)\n    run_inference(model, a, s2, t, lookup, population, deaths, rng_key, args)\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n\n    parser = argparse.ArgumentParser(description=\"Mortality regression model\")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=500, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=200, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    parser.add_argument(\n        \"--rng_seed\", default=21, type=int, help=\"random number generator seed\"\n    )\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.enable_x64()\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}