
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/capture_recapture.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_capture_recapture.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_capture_recapture.py:


Example: CJS Capture-Recapture Model for Ecological Data
========================================================

This example is ported from [8].

We show how to implement several variants of the Cormack-Jolly-Seber (CJS)
[4, 5, 6] model used in ecology to analyze animal capture-recapture data.
For a discussion of these models see reference [1].

We make use of two datasets:

    - the European Dipper (Cinclus cinclus) data from reference [2]
      (the dipper is Norway's national bird);
    - the meadow vole (Microtus pennsylvanicus) data from reference [3].

Compare to the Stan implementations in [7].

**References**

    1. Kery, M., & Schaub, M. (2011). Bayesian population analysis using
       WinBUGS: a hierarchical perspective. Academic Press.
    2. Lebreton, J.D., Burnham, K.P., Clobert, J., & Anderson, D.R. (1992).
       Modeling survival and testing biological hypotheses using marked animals:
       a unified approach with case studies. Ecological Monographs, 62(1), 67-118.
    3. Nichols, Pollock, & Hines (1984). The use of a robust capture-recapture design
       in small mammal population studies: a field example with Microtus pennsylvanicus.
       Acta Theriologica, 29, 357-365.
    4. Cormack, R.M. (1964). Estimates of survival from the sighting of marked animals.
       Biometrika, 51, 429-438.
    5. Jolly, G.M. (1965). Explicit estimates from capture-recapture data with both death
       and immigration-stochastic model. Biometrika, 52, 225-247.
    6. Seber, G.A.F. (1965). A note on the multiple recapture census. Biometrika, 52, 249-259.
    7. https://github.com/stan-dev/example-models/tree/master/BPA/Ch.07
    8. http://pyro.ai/examples/capture_recapture.html

.. GENERATED FROM PYTHON SOURCE LINES 41-57

.. code-block:: default


    import argparse
    import os

    from jax import random
    import jax.numpy as jnp
    from jax.scipy.special import expit, logit

    import numpyro
    from numpyro import handlers
    from numpyro.contrib.control_flow import scan
    import numpyro.distributions as dist
    from numpyro.examples.datasets import DIPPER_VOLE, load_dataset
    from numpyro.infer import HMC, MCMC, NUTS
    from numpyro.infer.reparam import LocScaleReparam


.. GENERATED FROM PYTHON SOURCE LINES 58-62

Our first and simplest CJS model variant has only two continuous
(scalar) latent random variables: (i) the survival probability phi
and (ii) the recapture probability rho. Both are treated as fixed
effects with no temporal or individual/group variation.
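
As a rough illustration of what exactly summing out the discrete alive/dead
states means for this variant, here is a minimal pure-Python sketch of the
marginal likelihood of a single capture history, conditioned on the first
capture. It is not NumPyro's enumeration machinery; the function names
``cjs_likelihood`` and ``cjs_likelihood_brute_force`` are hypothetical, and the
brute-force version is included only as a cross-check.

```python
from itertools import product


def cjs_likelihood(history, phi, rho):
    """Marginal probability of a capture history after its first capture.

    Forward recursion over the hidden alive/dead state, conditioning on
    the first capture (the animal is known to be alive at that time).
    """
    if 1 not in history:
        return 1.0  # never captured: contributes nothing to the likelihood
    first = history.index(1)
    alive, dead = 1.0, 0.0  # joint probability of (state, observations so far)
    for y in history[first + 1:]:
        # Transition: an alive animal survives with probability phi.
        alive, dead = alive * phi, dead + alive * (1.0 - phi)
        # Emission: only an alive animal can be recaptured, with probability rho.
        if y == 1:
            alive, dead = alive * rho, 0.0
        else:
            alive = alive * (1.0 - rho)
    return alive + dead


def cjs_likelihood_brute_force(history, phi, rho):
    """Same quantity, by explicit enumeration of every hidden state sequence."""
    if 1 not in history:
        return 1.0
    first = history.index(1)
    ys = history[first + 1:]
    total = 0.0
    for zs in product([0, 1], repeat=len(ys)):
        prob, z_prev = 1.0, 1
        for z, y in zip(zs, ys):
            p_alive = phi * z_prev  # a dead animal stays dead
            prob *= p_alive if z == 1 else 1.0 - p_alive
            p_seen = rho * z  # a dead animal is never recaptured
            prob *= p_seen if y == 1 else 1.0 - p_seen
            z_prev = z
        total += prob
    return total
```

For the history ``[1, 0, 1, 0]`` with phi = 0.6 and rho = 0.9, both functions
give approximately 0.014904. The forward recursion costs time linear in the
number of capture occasions, whereas the enumeration is exponential.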

.. GENERATED FROM PYTHON SOURCE LINES 62-100

.. code-block:: default



    def model_1(capture_history, sex):
        N, T = capture_history.shape
        phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
        rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

        def transition_fn(carry, y):
            first_capture_mask, z = carry
            with numpyro.plate("animals", N, dim=-1):
                with handlers.mask(mask=first_capture_mask):
                    mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                    # NumPyro exactly sums out the discrete states z_t.
                    z = numpyro.sample(
                        "z",
                        dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                        infer={"enumerate": "parallel"},
                    )
                    mu_y_t = rho * z
                    numpyro.sample(
                        "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                    )

            first_capture_mask = first_capture_mask | y.astype(bool)
            return (first_capture_mask, z), None

        z = jnp.ones(N, dtype=jnp.int32)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = capture_history[:, 0].astype(bool)
        # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
        scan(
            transition_fn,
            (first_capture_mask, z),
            jnp.swapaxes(capture_history[:, 1:], 0, 1),
        )



.. GENERATED FROM PYTHON SOURCE LINES 101-103

In our second model variant the survival probability varies over time: there is
a separate phi_t for each of the T-1 transitions between the T time periods of
the capture data, and each phi_t is treated as a fixed effect.
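
To see how T-1 survival probabilities enter the generative process, the
following pure-Python sketch simulates one capture history. It is only an
illustration under stated assumptions: ``simulate_history`` is a hypothetical
helper, and the occasion of first capture is supplied as an input rather than
modeled.

```python
import random


def simulate_history(phis, rho, first, rng):
    """Simulate one capture history with T = len(phis) + 1 occasions.

    `phis[t - 1]` is the probability of surviving from occasion t - 1 to t;
    `first` is the occasion of first capture (an assumed input).
    """
    T = len(phis) + 1
    history = [0] * T
    history[first] = 1  # the animal is alive and seen at first capture
    alive = True
    for t in range(first + 1, T):
        # Survival uses the period-specific probability phi_t.
        alive = alive and rng.random() < phis[t - 1]
        # Recapture is only possible while the animal is alive.
        history[t] = int(alive and rng.random() < rho)
    return history


rng = random.Random(0)
history = simulate_history([0.5, 0.7, 0.9], rho=0.8, first=1, rng=rng)
```

With four occasions there are three survival probabilities, matching the T-1
values of phi_t sampled inside the scan.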

.. GENERATED FROM PYTHON SOURCE LINES 103-144

.. code-block:: default



    def model_2(capture_history, sex):
        N, T = capture_history.shape
        rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

        def transition_fn(carry, y):
            first_capture_mask, z = carry
            # note that phi_t needs to be outside the plate, since
            # phi_t is shared across all N individuals
            phi_t = numpyro.sample("phi", dist.Uniform(0.0, 1.0))

            with numpyro.plate("animals", N, dim=-1):
                with handlers.mask(mask=first_capture_mask):
                    mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                    # NumPyro exactly sums out the discrete states z_t.
                    z = numpyro.sample(
                        "z",
                        dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                        infer={"enumerate": "parallel"},
                    )
                    mu_y_t = rho * z
                    numpyro.sample(
                        "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                    )

            first_capture_mask = first_capture_mask | y.astype(bool)
            return (first_capture_mask, z), None

        z = jnp.ones(N, dtype=jnp.int32)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = capture_history[:, 0].astype(bool)
        # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
        scan(
            transition_fn,
            (first_capture_mask, z),
            jnp.swapaxes(capture_history[:, 1:], 0, 1),
        )



.. GENERATED FROM PYTHON SOURCE LINES 145-148

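In our third model variant there is a survival probability phi_t for T-1
of the T time periods of the capture data (just like in model_2), but here
each phi_t is treated as a random effect: its logit is drawn from a shared
normal distribution with mean logit(phi_mean) and standard deviation phi_sigma.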
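
The random effect can be pictured with a small pure-Python sketch that draws
each logit(phi_t) as a shifted and scaled standard-normal variable. This
non-centered form is our reading of what ``LocScaleReparam(0)`` arranges in the
code that follows; the helper ``sample_survival_probs`` is hypothetical and
performs forward sampling only, not inference.

```python
import math
import random


def logit(p):
    return math.log(p / (1.0 - p))


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def sample_survival_probs(phi_mean, phi_sigma, num_periods, rng):
    """Draw one phi_t per period using a non-centered parameterization."""
    phi_logit_mean = logit(phi_mean)
    phis = []
    for _ in range(num_periods):
        eps = rng.gauss(0.0, 1.0)  # standard-normal auxiliary variable
        phi_logit_t = phi_logit_mean + phi_sigma * eps  # shift and scale
        phis.append(expit(phi_logit_t))  # map back to a probability
    return phis


rng = random.Random(0)
phis = sample_survival_probs(phi_mean=0.6, phi_sigma=0.5, num_periods=6, rng=rng)
```

When ``phi_sigma`` is zero every phi_t collapses to ``phi_mean``, which recovers
the constant-survival model; larger values allow more variation between periods.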

.. GENERATED FROM PYTHON SOURCE LINES 148-196

.. code-block:: default



    def model_3(capture_history, sex):
        N, T = capture_history.shape
        phi_mean = numpyro.sample(
            "phi_mean", dist.Uniform(0.0, 1.0)
        )  # mean survival probability
        phi_logit_mean = logit(phi_mean)
        # controls temporal variability of survival probability
        phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
        rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

        def transition_fn(carry, y):
            first_capture_mask, z = carry
            with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
                phi_logit_t = numpyro.sample(
                    "phi_logit", dist.Normal(phi_logit_mean, phi_sigma)
                )
            phi_t = expit(phi_logit_t)
            with numpyro.plate("animals", N, dim=-1):
                with handlers.mask(mask=first_capture_mask):
                    mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                    # NumPyro exactly sums out the discrete states z_t.
                    z = numpyro.sample(
                        "z",
                        dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                        infer={"enumerate": "parallel"},
                    )
                    mu_y_t = rho * z
                    numpyro.sample(
                        "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                    )

            first_capture_mask = first_capture_mask | y.astype(bool)
            return (first_capture_mask, z), None

        z = jnp.ones(N, dtype=jnp.int32)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = capture_history[:, 0].astype(bool)
        # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
        scan(
            transition_fn,
            (first_capture_mask, z),
            jnp.swapaxes(capture_history[:, 1:], 0, 1),
        )



.. GENERATED FROM PYTHON SOURCE LINES 197-199

In our fourth model variant we include group-level fixed effects
for sex (male, female).
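
The group effect amounts to selecting one of two survival probabilities per
individual using the sex indicator (female = 0, male = 1). A minimal
pure-Python sketch of that selection, with a hypothetical helper name:

```python
def per_individual_phi(sex, phi_male, phi_female):
    """Pick the survival probability matching each individual's sex.

    `sex` holds 0 for females and 1 for males, so the weighted sum
    reduces to phi_male or phi_female for each entry.
    """
    return [s * phi_male + (1.0 - s) * phi_female for s in sex]


phi = per_individual_phi([1, 0, 0, 1], phi_male=0.6, phi_female=0.5)
# phi == [0.6, 0.5, 0.5, 0.6]
```

The model code performs the same arithmetic on arrays, which yields one survival
probability per individual without any explicit branching.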

.. GENERATED FROM PYTHON SOURCE LINES 199-242

.. code-block:: default



    def model_4(capture_history, sex):
        N, T = capture_history.shape
        # survival probabilities for males/females
        phi_male = numpyro.sample("phi_male", dist.Uniform(0.0, 1.0))
        phi_female = numpyro.sample("phi_female", dist.Uniform(0.0, 1.0))
        # we construct an N-dimensional vector that contains the appropriate
        # phi for each individual given its sex (female = 0, male = 1)
        phi = sex * phi_male + (1.0 - sex) * phi_female
        rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

        def transition_fn(carry, y):
            first_capture_mask, z = carry
            with numpyro.plate("animals", N, dim=-1):
                with handlers.mask(mask=first_capture_mask):
                    mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask)
                    # NumPyro exactly sums out the discrete states z_t.
                    z = numpyro.sample(
                        "z",
                        dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                        infer={"enumerate": "parallel"},
                    )
                    mu_y_t = rho * z
                    numpyro.sample(
                        "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                    )

            first_capture_mask = first_capture_mask | y.astype(bool)
            return (first_capture_mask, z), None

        z = jnp.ones(N, dtype=jnp.int32)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = capture_history[:, 0].astype(bool)
        # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
        scan(
            transition_fn,
            (first_capture_mask, z),
            jnp.swapaxes(capture_history[:, 1:], 0, 1),
        )



.. GENERATED FROM PYTHON SOURCE LINES 243-249

In our final model variant we include both fixed group effects and fixed
time effects for the survival probability phi, with
``logit(phi_t) = beta_group + gamma_t``.
We need to take care that the model is not overparameterized; to do this
we effectively let a single scalar beta encode the difference in male
and female survival probabilities.
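
The overparameterization can be checked numerically. In the pure-Python sketch
below (the helper ``survival_probs`` and all numbers are made up for
illustration), shifting every gamma_t by a constant while shifting both group
effects the opposite way leaves every survival probability unchanged, so
separate male and female effects are not identifiable alongside free time
effects.

```python
import math


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def survival_probs(beta_female, beta_male, gammas):
    """Survival probability per sex and period: logit(phi) = beta_group + gamma_t."""
    return {
        "female": [expit(beta_female + g) for g in gammas],
        "male": [expit(beta_male + g) for g in gammas],
    }


gammas = [0.1, -0.2, 0.4]
original = survival_probs(0.3, 0.8, gammas)

# Shift every gamma_t up by c and both group effects down by c:
# the parameters change, but every phi is unchanged.
c = 1.5
shifted = survival_probs(0.3 - c, 0.8 - c, [g + c for g in gammas])

# Pinning the female effect to zero (as `sex * phi_beta` does, with female = 0)
# leaves one free scalar: the male-minus-female difference in logit space.
pinned = survival_probs(0.0, 0.8 - 0.3, [g + 0.3 for g in gammas])
```

All three parameter settings produce the same survival probabilities, and only
the pinned form has a unique representation.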

.. GENERATED FROM PYTHON SOURCE LINES 249-293

.. code-block:: default



    def model_5(capture_history, sex):
        N, T = capture_history.shape

        # phi_beta controls the survival probability differential
        # for males versus females (in logit space)
        phi_beta = numpyro.sample("phi_beta", dist.Normal(0.0, 10.0))
        phi_beta = sex * phi_beta
        rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

        def transition_fn(carry, y):
            first_capture_mask, z = carry
            phi_gamma_t = numpyro.sample("phi_gamma", dist.Normal(0.0, 10.0))
            phi_t = expit(phi_beta + phi_gamma_t)
            with numpyro.plate("animals", N, dim=-1):
                with handlers.mask(mask=first_capture_mask):
                    mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                    # NumPyro exactly sums out the discrete states z_t.
                    z = numpyro.sample(
                        "z",
                        dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                        infer={"enumerate": "parallel"},
                    )
                    mu_y_t = rho * z
                    numpyro.sample(
                        "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                    )

            first_capture_mask = first_capture_mask | y.astype(bool)
            return (first_capture_mask, z), None

        z = jnp.ones(N, dtype=jnp.int32)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = capture_history[:, 0].astype(bool)
        # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
        scan(
            transition_fn,
            (first_capture_mask, z),
            jnp.swapaxes(capture_history[:, 1:], 0, 1),
        )



.. GENERATED FROM PYTHON SOURCE LINES 294-295

Run inference with NUTS or HMC.

.. GENERATED FROM PYTHON SOURCE LINES 295-373

.. code-block:: default



    models = {
        name[len("model_") :]: model
        for name, model in globals().items()
        if name.startswith("model_")
    }


    def run_inference(model, capture_history, sex, rng_key, args):
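        if args.algo == "NUTS":
            kernel = NUTS(model)
        elif args.algo == "HMC":
            kernel = HMC(model)
        else:
            # fail early instead of hitting an unbound `kernel` below
            raise ValueError("Available algorithms are 'NUTS' and 'HMC'.")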
        mcmc = MCMC(
            kernel,
            num_warmup=args.num_warmup,
            num_samples=args.num_samples,
            num_chains=args.num_chains,
            progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
        )
        mcmc.run(rng_key, capture_history, sex)
        mcmc.print_summary()
        return mcmc.get_samples()


    def main(args):
        # load data
        if args.dataset == "dipper":
            _, fetch = load_dataset(DIPPER_VOLE, split="dipper", shuffle=False)
            capture_history, sex = fetch()
        elif args.dataset == "vole":
            if args.model in ["4", "5"]:
                raise ValueError(
                    "Cannot run model_{} on meadow voles data, since we lack sex "
                    "information for these animals.".format(args.model)
                )
            _, fetch = load_dataset(DIPPER_VOLE, split="vole", shuffle=False)
            (capture_history,) = fetch()
            sex = None
        else:
            raise ValueError("Available datasets are 'dipper' and 'vole'.")

        N, T = capture_history.shape
        print(
            "Loaded {} capture history for {} individuals collected over {} time periods.".format(
                args.dataset, N, T
            )
        )

        model = models[args.model]
        rng_key = random.PRNGKey(args.rng_seed)
        run_inference(model, capture_history, sex, rng_key, args)


    if __name__ == "__main__":
        parser = argparse.ArgumentParser(
            description="CJS capture-recapture model for ecological data"
        )
        parser.add_argument(
            "-m",
            "--model",
            default="1",
            type=str,
            help="one of: {}".format(", ".join(sorted(models.keys()))),
        )
        parser.add_argument("-d", "--dataset", default="dipper", type=str)
        parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
        parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
        parser.add_argument("--num-chains", nargs="?", default=1, type=int)
        parser.add_argument(
            "--rng_seed", default=0, type=int, help="random number generator seed"
        )
        parser.add_argument(
            "--algo", default="NUTS", type=str, help='whether to run "NUTS" or "HMC"'
        )
        args = parser.parse_args()
        main(args)

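
As a usage illustration, the sketch below rebuilds a subset of the script's
command-line flags and parses an argument list equivalent to running
``python capture_recapture.py -m 3 -d vole -n 500``. The helpers
``build_parser`` and ``needs_sex_covariate`` are hypothetical and exist only for
this illustration; no inference is run.

```python
import argparse


def build_parser():
    # Mirrors a subset of the flags defined by the script above.
    parser = argparse.ArgumentParser(
        description="CJS capture-recapture model for ecological data"
    )
    parser.add_argument("-m", "--model", default="1", type=str)
    parser.add_argument("-d", "--dataset", default="dipper", type=str)
    parser.add_argument("-n", "--num-samples", default=1000, type=int)
    parser.add_argument("--rng_seed", default=0, type=int)
    return parser


def needs_sex_covariate(model_name):
    # Models 4 and 5 use the `sex` covariate, which the vole data lack.
    return model_name in ["4", "5"]


args = build_parser().parse_args(["-m", "3", "-d", "vole", "-n", "500"])
ok_for_vole = not needs_sex_covariate(args.model)
```

Note that argparse exposes ``--num-samples`` as ``args.num_samples``, while
``--rng_seed`` is spelled with an underscore on the command line as well.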

.. _sphx_glr_download_examples_capture_recapture.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: capture_recapture.py <capture_recapture.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: capture_recapture.ipynb <capture_recapture.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
