{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: CJS Capture-Recapture Model for Ecological Data\n\nThis example is ported from [8].\n\nWe show how to implement several variants of the Cormack-Jolly-Seber (CJS)\n[4, 5, 6] model used in ecology to analyze animal capture-recapture data.\nFor a discussion of these models, see reference [1].\n\nWe make use of two datasets:\n\n- the European Dipper (Cinclus cinclus) data from reference [2]\n  (this is Norway's national bird);\n- the meadow voles data from reference [3].\n\nCompare to the Stan implementations in [7].\n\n**References**\n\n1. Kery, M., & Schaub, M. (2011). Bayesian population analysis using\n   WinBUGS: a hierarchical perspective. Academic Press.\n2. Lebreton, J.D., Burnham, K.P., Clobert, J., & Anderson, D.R. (1992).\n   Modeling survival and testing biological hypotheses using marked animals:\n   a unified approach with case studies. Ecological Monographs, 62(1), 67-118.\n3. Nichols, Pollock, & Hines (1984). The use of a robust capture-recapture design\n   in small mammal population studies: A field example with Microtus pennsylvanicus.\n   Acta Theriologica, 29, 357-365.\n4. Cormack, R.M. (1964). Estimates of survival from the sighting of marked animals.\n   Biometrika, 51, 429-438.\n5. Jolly, G.M. (1965). Explicit estimates from capture-recapture data with both death\n   and immigration-stochastic model. Biometrika, 52, 225-247.\n6. Seber, G.A.F. (1965). A note on the multiple recapture census. Biometrika, 52, 249-259.\n7. https://github.com/stan-dev/example-models/tree/master/BPA/Ch.07\n8. http://pyro.ai/examples/capture_recapture.html\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nfrom jax import random\nimport jax.numpy as jnp\nfrom jax.scipy.special import expit, logit\n\nimport numpyro\nfrom numpyro import handlers\nfrom numpyro.contrib.control_flow import scan\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import DIPPER_VOLE, load_dataset\nfrom numpyro.infer import HMC, MCMC, NUTS\nfrom numpyro.infer.reparam import LocScaleReparam"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Our first and simplest CJS model variant only has two continuous\n(scalar) latent random variables: i) the survival probability phi;\nand ii) the recapture probability rho. These are treated as fixed\neffects with no temporal or individual/group variation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def model_1(capture_history, sex):\n    N, T = capture_history.shape\n    phi = numpyro.sample(\"phi\", dist.Uniform(0.0, 1.0))  # survival probability\n    rho = numpyro.sample(\"rho\", dist.Uniform(0.0, 1.0))  # recapture probability\n\n    def transition_fn(carry, y):\n        first_capture_mask, z = carry\n        with numpyro.plate(\"animals\", N, dim=-1):\n            with handlers.mask(mask=first_capture_mask):\n                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)\n                # NumPyro exactly sums out the discrete states z_t.\n                z = numpyro.sample(\n                    \"z\",\n                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                mu_y_t = rho * z\n                numpyro.sample(\n                    \"y\", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y\n                )\n\n        first_capture_mask = first_capture_mask | y.astype(bool)\n        return (first_capture_mask, z), None\n\n    z = jnp.ones(N, dtype=jnp.int32)\n    # we use this mask to eliminate extraneous log probabilities\n    # that arise for a given individual before its first capture.\n    first_capture_mask = capture_history[:, 0].astype(bool)\n    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it\n    scan(\n        transition_fn,\n        (first_capture_mask, z),\n        jnp.swapaxes(capture_history[:, 1:], 0, 1),\n    )"
      ]
    },
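    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `infer={\"enumerate\": \"parallel\"}` annotation asks NumPyro to sum the discrete alive/dead states `z_t` out of the likelihood exactly. As a rough illustration of what that marginalization computes for one animal, the next cell sketches the CJS forward recursion in plain JAX. The helper `cjs_marginal_likelihood` is hypothetical (it is not part of NumPyro); it assumes scalar `phi` and `rho` as in `model_1` and, like `first_capture_mask`, conditions on the first capture.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\n\ndef cjs_marginal_likelihood(y, phi, rho):\n    # Hypothetical helper: exact CJS likelihood of one 0/1 capture history `y`,\n    # conditional on the first capture, with the alive/dead state summed out.\n    y = jnp.asarray(y)\n    first = int(jnp.argmax(y))  # first capture occasion (assumes y contains a 1)\n    # forward probabilities of the two states, ordered [dead, alive];\n    # the animal is known to be alive at its first capture\n    alpha = jnp.array([0.0, 1.0])\n    for t in range(first + 1, y.shape[0]):\n        # transition: alive -> alive w.p. phi, alive -> dead w.p. 1 - phi\n        alpha = jnp.array([alpha[0] + alpha[1] * (1.0 - phi), alpha[1] * phi])\n        # emission: dead animals are never recaptured, live ones w.p. rho\n        emission = jnp.where(\n            y[t] == 1, jnp.array([0.0, rho]), jnp.array([1.0, 1.0 - rho])\n        )\n        alpha = alpha * emission\n    return alpha.sum()\n\n\n# for the history [1, 0, 1] the closed form is phi * (1 - rho) * phi * rho\nprint(cjs_marginal_likelihood([1, 0, 1], phi=0.8, rho=0.5))"
      ]
    },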
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In our second model variant there is a time-varying survival probability phi_t for\nT-1 of the T time periods of the capture data; each phi_t is treated as a fixed effect.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def model_2(capture_history, sex):\n    N, T = capture_history.shape\n    rho = numpyro.sample(\"rho\", dist.Uniform(0.0, 1.0))  # recapture probability\n\n    def transition_fn(carry, y):\n        first_capture_mask, z = carry\n        # note that phi_t needs to be outside the plate, since\n        # phi_t is shared across all N individuals\n        phi_t = numpyro.sample(\"phi\", dist.Uniform(0.0, 1.0))\n\n        with numpyro.plate(\"animals\", N, dim=-1):\n            with handlers.mask(mask=first_capture_mask):\n                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)\n                # NumPyro exactly sums out the discrete states z_t.\n                z = numpyro.sample(\n                    \"z\",\n                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                mu_y_t = rho * z\n                numpyro.sample(\n                    \"y\", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y\n                )\n\n        first_capture_mask = first_capture_mask | y.astype(bool)\n        return (first_capture_mask, z), None\n\n    z = jnp.ones(N, dtype=jnp.int32)\n    # we use this mask to eliminate extraneous log probabilities\n    # that arise for a given individual before its first capture.\n    first_capture_mask = capture_history[:, 0].astype(bool)\n    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it\n    scan(\n        transition_fn,\n        (first_capture_mask, z),\n        jnp.swapaxes(capture_history[:, 1:], 0, 1),\n    )"
      ]
    },
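    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In `model_2` the scan draws one `phi_t` per step, i.e. one survival probability for each of the T-1 intervals between consecutive capture occasions. One way to sanity-check such a model is to fit it to synthetic data with known parameters. The next cell is a hypothetical simulator (not part of NumPyro or of the original example); it assumes each animal's first capture falls uniformly on one of the first T-1 occasions and returns a 0/1 array with the same (N, T) layout as `capture_history`. Because `model_2` ignores its `sex` argument, the simulated array can be passed to it directly in place of `capture_history`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\nfrom jax import random\n\n\ndef simulate_cjs(rng_key, N, phi, rho):\n    # Hypothetical simulator: `phi` holds one survival probability per interval\n    # (length T-1) and `rho` is a scalar recapture probability.\n    phi = jnp.asarray(phi)\n    T = phi.shape[0] + 1\n    key_first, key_surv, key_capt = random.split(rng_key, 3)\n    # assumption: first-capture occasions are uniform over 0, ..., T-2\n    first = random.randint(key_first, (N,), 0, T - 1)\n    surv_u = random.uniform(key_surv, (N, T))\n    capt_u = random.uniform(key_capt, (N, T))\n    z = first == 0  # alive and already marked\n    columns = [z]\n    for t in range(1, T):\n        # survive the interval t-1 -> t, or get marked for the first time at t\n        z = (z & (surv_u[:, t] < phi[t - 1])) | (first == t)\n        # captured for sure at first capture, afterwards w.p. rho while alive\n        y_t = (first == t) | (z & (t > first) & (capt_u[:, t] < rho))\n        columns.append(y_t)\n    return jnp.stack(columns, axis=1).astype(jnp.int32)\n\n\ncapture_history_sim = simulate_cjs(random.PRNGKey(0), 50, jnp.full(6, 0.7), 0.6)\nprint(capture_history_sim.shape)"
      ]
    },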
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In our third model variant there is a survival probability phi_t for T-1\nof the T time periods of the capture data (just like in model_2), but here\neach phi_t is treated as a random effect.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def model_3(capture_history, sex):\n    N, T = capture_history.shape\n    phi_mean = numpyro.sample(\n        \"phi_mean\", dist.Uniform(0.0, 1.0)\n    )  # mean survival probability\n    phi_logit_mean = logit(phi_mean)\n    # controls temporal variability of survival probability\n    phi_sigma = numpyro.sample(\"phi_sigma\", dist.Uniform(0.0, 10.0))\n    rho = numpyro.sample(\"rho\", dist.Uniform(0.0, 1.0))  # recapture probability\n\n    def transition_fn(carry, y):\n        first_capture_mask, z = carry\n        with handlers.reparam(config={\"phi_logit\": LocScaleReparam(0)}):\n            phi_logit_t = numpyro.sample(\n                \"phi_logit\", dist.Normal(phi_logit_mean, phi_sigma)\n            )\n        phi_t = expit(phi_logit_t)\n        with numpyro.plate(\"animals\", N, dim=-1):\n            with handlers.mask(mask=first_capture_mask):\n                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)\n                # NumPyro exactly sums out the discrete states z_t.\n                z = numpyro.sample(\n                    \"z\",\n                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                mu_y_t = rho * z\n                numpyro.sample(\n                    \"y\", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y\n                )\n\n        first_capture_mask = first_capture_mask | y.astype(bool)\n        return (first_capture_mask, z), None\n\n    z = jnp.ones(N, dtype=jnp.int32)\n    # we use this mask to eliminate extraneous log probabilities\n    # that arise for a given individual before its first capture.\n    first_capture_mask = capture_history[:, 0].astype(bool)\n    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it\n    scan(\n        transition_fn,\n        (first_capture_mask, z),\n        jnp.swapaxes(capture_history[:, 1:], 0, 1),\n    )"
      ]
    },
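    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `LocScaleReparam(0)` handler in `model_3` samples the `phi_logit` site in fully non-centered form, which can ease the funnel-shaped posterior geometry that NUTS and HMC struggle with when the data only weakly constrain `phi_sigma`. The next cell is a hand-written sketch of the same idea for a single draw; the function and the site name `phi_logit_eps` are hypothetical (NumPyro chooses its own name for the auxiliary site). It draws standard-normal noise and then shifts and scales it deterministically.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpyro\nimport numpyro.distributions as dist\nfrom numpyro import handlers\n\n\ndef noncentered_phi_logit(phi_logit_mean, phi_sigma):\n    # standard-normal noise replaces the centered Normal(phi_logit_mean, phi_sigma)\n    eps = numpyro.sample(\"phi_logit_eps\", dist.Normal(0.0, 1.0))\n    # a deterministic shift and scale turns eps into a draw from the centered prior\n    return numpyro.deterministic(\"phi_logit\", phi_logit_mean + phi_sigma * eps)\n\n\nexample_trace = handlers.trace(handlers.seed(noncentered_phi_logit, 0)).get_trace(\n    0.5, 2.0\n)\nprint(example_trace[\"phi_logit_eps\"][\"value\"], example_trace[\"phi_logit\"][\"value\"])"
      ]
    },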
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In our fourth model variant we include group-level fixed effects\nfor sex (male, female).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def model_4(capture_history, sex):\n    N, T = capture_history.shape\n    # survival probabilities for males/females\n    phi_male = numpyro.sample(\"phi_male\", dist.Uniform(0.0, 1.0))\n    phi_female = numpyro.sample(\"phi_female\", dist.Uniform(0.0, 1.0))\n    # we construct an N-dimensional vector that contains the appropriate\n    # phi for each individual given its sex (female = 0, male = 1)\n    phi = sex * phi_male + (1.0 - sex) * phi_female\n    rho = numpyro.sample(\"rho\", dist.Uniform(0.0, 1.0))  # recapture probability\n\n    def transition_fn(carry, y):\n        first_capture_mask, z = carry\n        with numpyro.plate(\"animals\", N, dim=-1):\n            with handlers.mask(mask=first_capture_mask):\n                mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)\n                # NumPyro exactly sums out the discrete states z_t.\n                z = numpyro.sample(\n                    \"z\",\n                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                mu_y_t = rho * z\n                numpyro.sample(\n                    \"y\", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y\n                )\n\n        first_capture_mask = first_capture_mask | y.astype(bool)\n        return (first_capture_mask, z), None\n\n    z = jnp.ones(N, dtype=jnp.int32)\n    # we use this mask to eliminate extraneous log probabilities\n    # that arise for a given individual before its first capture.\n    first_capture_mask = capture_history[:, 0].astype(bool)\n    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it\n    scan(\n        transition_fn,\n        (first_capture_mask, z),\n        jnp.swapaxes(capture_history[:, 1:], 0, 1),\n    )"
      ]
    },
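    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The expression `sex * phi_male + (1.0 - sex) * phi_female` in `model_4` selects each animal's survival probability from its 0/1 sex code. The next cell checks this with made-up values (the arrays are hypothetical, not taken from the dipper data) against direct indexing into `[phi_female, phi_male]`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\n# hypothetical sex codes: female = 0, male = 1\nsex_example = jnp.array([0.0, 1.0, 1.0, 0.0])\nphi_male_example, phi_female_example = 0.7, 0.5\n\n# arithmetic selection, as in model_4\nphi_example = sex_example * phi_male_example + (1.0 - sex_example) * phi_female_example\n# the same selection by indexing into [phi_female, phi_male]\nphi_indexed = jnp.array([phi_female_example, phi_male_example])[\n    sex_example.astype(jnp.int32)\n]\nprint(phi_example, phi_indexed)"
      ]
    },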
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In our final model variant we include both fixed group effects and fixed\ntime effects for the survival probability phi:\n\n$$\\mathrm{logit}(\\phi_t) = \\beta_{\\text{group}} + \\gamma_t$$\n\nWe need to take care that the model is not overparameterized; to do this\nwe effectively let a single scalar beta encode the difference in male\nand female survival probabilities. Females (sex = 0) have\n$\\mathrm{logit}(\\phi_t) = \\gamma_t$, while males (sex = 1) have\n$\\mathrm{logit}(\\phi_t) = \\beta + \\gamma_t$.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def model_5(capture_history, sex):\n    N, T = capture_history.shape\n\n    # phi_beta controls the survival probability differential\n    # for males versus females (in logit space)\n    phi_beta = numpyro.sample(\"phi_beta\", dist.Normal(0.0, 10.0))\n    phi_beta = sex * phi_beta\n    rho = numpyro.sample(\"rho\", dist.Uniform(0.0, 1.0))  # recapture probability\n\n    def transition_fn(carry, y):\n        first_capture_mask, z = carry\n        phi_gamma_t = numpyro.sample(\"phi_gamma\", dist.Normal(0.0, 10.0))\n        phi_t = expit(phi_beta + phi_gamma_t)\n        with numpyro.plate(\"animals\", N, dim=-1):\n            with handlers.mask(mask=first_capture_mask):\n                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)\n                # NumPyro exactly sums out the discrete states z_t.\n                z = numpyro.sample(\n                    \"z\",\n                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                mu_y_t = rho * z\n                numpyro.sample(\n                    \"y\", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y\n                )\n\n        first_capture_mask = first_capture_mask | y.astype(bool)\n        return (first_capture_mask, z), None\n\n    z = jnp.ones(N, dtype=jnp.int32)\n    # we use this mask to eliminate extraneous log probabilities\n    # that arise for a given individual before its first capture.\n    first_capture_mask = capture_history[:, 0].astype(bool)\n    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it\n    scan(\n        transition_fn,\n        (first_capture_mask, z),\n        jnp.swapaxes(capture_history[:, 1:], 0, 1),\n    )"
      ]
    },
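    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why a single `phi_beta` is enough, suppose each sex had its own effect, $\\mathrm{logit}(\\phi_t) = \\beta_{\\text{sex}} + \\gamma_t$. Adding a constant to both group effects and subtracting it from every $\\gamma_t$ leaves all survival probabilities unchanged, so the likelihood cannot distinguish these parameter values. Fixing the female effect at zero, as `model_5` does, removes that freedom. The next cell checks both statements numerically with arbitrary, hypothetical values: it absorbs the female effect into the time effects and confirms that the original probabilities are reproduced.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\nfrom jax.scipy.special import expit\n\ngamma_t = jnp.array([0.2, -0.1, 0.4])  # hypothetical time effects\nbeta_female, beta_male = 0.5, 1.3  # hypothetical per-sex effects\nshift = 0.3\n\n# shifting both group effects by +shift and the time effects by -shift gives\n# identical survival probabilities, so this parameterization is not identified\nphi_female_t = expit(beta_female + gamma_t)\nphi_female_shifted = expit((beta_female + shift) + (gamma_t - shift))\n\n# model_5 parameterization: female logit = gamma, male logit = phi_beta + gamma\ngamma_ref = gamma_t + beta_female\nphi_beta_ref = beta_male - beta_female\nphi_male_t = expit(beta_male + gamma_t)\nphi_male_ref = expit(phi_beta_ref + gamma_ref)\nprint(\n    jnp.allclose(phi_female_t, phi_female_shifted),\n    jnp.allclose(expit(gamma_ref), phi_female_t),\n    jnp.allclose(phi_male_t, phi_male_ref),\n)"
      ]
    },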
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Do inference\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "models = {\n    name[len(\"model_\") :]: model\n    for name, model in globals().items()\n    if name.startswith(\"model_\")\n}\n\n\ndef run_inference(model, capture_history, sex, rng_key, args):\n    if args.algo == \"NUTS\":\n        kernel = NUTS(model)\n    elif args.algo == \"HMC\":\n        kernel = HMC(model)\n    else:\n        raise ValueError(\"Available algorithms are 'NUTS' and 'HMC'.\")\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        # hide the progress bar when the docs are built with Sphinx\n        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n    )\n    mcmc.run(rng_key, capture_history, sex)\n    mcmc.print_summary()\n    return mcmc.get_samples()\n\n\ndef main(args):\n    # load data\n    if args.dataset == \"dipper\":\n        capture_history, sex = load_dataset(DIPPER_VOLE, split=\"dipper\", shuffle=False)[\n            1\n        ]()\n    elif args.dataset == \"vole\":\n        if args.model in [\"4\", \"5\"]:\n            raise ValueError(\n                \"Cannot run model_{} on meadow voles data, since we lack sex \"\n                \"information for these animals.\".format(args.model)\n            )\n        (capture_history,) = load_dataset(DIPPER_VOLE, split=\"vole\", shuffle=False)[1]()\n        sex = None\n    else:\n        raise ValueError(\"Available datasets are 'dipper' and 'vole'.\")\n\n    N, T = capture_history.shape\n    print(\n        \"Loaded {} capture history for {} individuals collected over {} time periods.\".format(\n            args.dataset, N, T\n        )\n    )\n\n    model = models[args.model]\n    rng_key = random.PRNGKey(args.rng_seed)\n    run_inference(model, capture_history, sex, rng_key, args)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"CJS capture-recapture model for ecological data\"\n    )\n    parser.add_argument(\n        \"-m\",\n        \"--model\",\n        default=\"1\",\n        type=str,\n        help=\"one of: {}\".format(\", \".join(sorted(models.keys()))),\n    )\n    parser.add_argument(\"-d\", \"--dataset\", default=\"dipper\", type=str)\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\n        \"--rng_seed\", default=0, type=int, help=\"random number generator seed\"\n    )\n    parser.add_argument(\n        \"--algo\", default=\"NUTS\", type=str, help='whether to run \"NUTS\" or \"HMC\"'\n    )\n    args = parser.parse_args()\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}