
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/annotation.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_annotation.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_annotation.py:


Example: Bayesian Models of Annotation
======================================

In this example, we run MCMC for several of the crowdsourced annotation models described in [1].

All models have discrete latent variables. Under the hood, we enumerate over
(marginalize out) those discrete latent sites during inference. The models differ in
complexity, so they are good references for those who are new to the Pyro/NumPyro
enumeration mechanism. We recommend that readers compare the implementations with the
corresponding plate diagrams in [1] to see how concise a Pyro/NumPyro program is.

Interested readers can also refer to [3] for a more detailed explanation of enumeration.
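
To make the idea of enumeration concrete, the following NumPy sketch marginalizes a single discrete class label by hand for a multinomial-style likelihood. The parameter values and the helper name `marginal_log_likelihood` are illustrative assumptions and are not part of this example; it only shows the kind of sum over classes that enumeration carries out.

```python
import numpy as np


def marginal_log_likelihood(y, pi, zeta):
    """Log-probability of one item's annotations with the class summed out.

    y:    integer annotations for one item, shape (num_positions,)
    pi:   class prevalence, shape (num_classes,)
    zeta: per-class annotation probabilities, shape (num_classes, num_classes)
    """
    # log p(c) + sum_j log p(y_j | c), evaluated for every candidate class c
    log_joint = np.log(pi) + np.log(zeta[:, y]).sum(axis=-1)
    # numerically stable log-sum-exp over the enumerated classes
    m = log_joint.max()
    return m + np.log(np.exp(log_joint - m).sum())


# Hypothetical parameters for a 2-class problem (not taken from this example's data).
pi = np.array([0.6, 0.4])
zeta = np.array([[0.9, 0.1], [0.2, 0.8]])
y = np.array([0, 0, 1])

log_p = marginal_log_likelihood(y, pi, zeta)
# Brute-force check: sum the joint probability over both classes explicitly.
brute = sum(pi[c] * np.prod(zeta[c, y]) for c in range(2))
```

Here `np.exp(log_p)` and `brute` agree, because both sum the joint probability over every value of the latent class.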

The data is taken from Table 1 of reference [2].

Currently, this example does not include postprocessing steps to deal with the
"label switching" issue (mentioned in Section 6.2 of [1]).
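
Label switching arises because relabeling the latent classes leaves the marginal likelihood unchanged, so the posterior can have several equivalent modes. The NumPy sketch below illustrates this invariance for a multinomial-style likelihood; the parameter values and the helper `log_lik` are hypothetical and chosen only for illustration.

```python
import numpy as np


def log_lik(annotations, pi, zeta):
    """Marginal log-likelihood of all items, summing over the latent class."""
    # per_class[i, c] = pi[c] * prod_j zeta[c, annotations[i, j]]
    per_class = pi * np.prod(zeta[:, annotations], axis=-1).T
    return np.log(per_class.sum(axis=-1)).sum()


# Hypothetical 3-class parameters and a tiny annotation table.
pi = np.array([0.5, 0.3, 0.2])
zeta = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]])
annotations = np.array([[0, 0, 1], [2, 2, 2], [1, 0, 1]])

# Relabel the latent classes: permute `pi` and the rows of `zeta` together.
perm = np.array([2, 0, 1])
original = log_lik(annotations, pi, zeta)
switched = log_lik(annotations, pi[perm], zeta[perm])
```

The two values `original` and `switched` coincide, which is why per-class posterior summaries can be hard to interpret without a relabeling step.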

**References:**

    1. Paun, S., Carpenter, B., Chamberlain, J., Hovy, D., Kruschwitz, U.,
       and Poesio, M. (2018). "Comparing Bayesian Models of Annotation"
       (https://www.aclweb.org/anthology/Q18-1040/)
    2. Dawid, A. P., and Skene, A. M. (1979).
       "Maximum likelihood estimation of observer error‐rates using the EM algorithm"
    3. "Inference with Discrete Latent Variables"
       (http://pyro.ai/examples/enumeration.html)

.. GENERATED FROM PYTHON SOURCE LINES 34-304

.. code-block:: default


    import argparse
    import os

    import numpy as np

    from jax import nn, random, vmap
    import jax.numpy as jnp

    import numpyro
    from numpyro import handlers
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, Predictive
    from numpyro.infer.reparam import LocScaleReparam
    from numpyro.ops.indexing import Vindex


    def get_data():
        """
        :return: a tuple of annotator indices and class indices. The first term has shape
            `num_positions` whose entries take values from `0` to `num_annotators - 1`.
            The second term has shape `num_items x num_positions` whose entries take values
            from `0` to `num_classes - 1`.
        """
        # NB: the first annotator assessed each item 3 times
        positions = np.array([1, 1, 1, 2, 3, 4, 5])
        # fmt: off
        annotations = np.array(
            [[1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 4, 3, 3, 4], [1, 1, 2, 2, 1, 2, 2],
             [2, 2, 2, 3, 1, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 2, 3, 3, 2, 2],
             [1, 2, 2, 2, 1, 1, 1], [3, 3, 3, 3, 4, 3, 3], [2, 2, 2, 2, 2, 2, 3],
             [2, 3, 2, 2, 2, 2, 3], [4, 4, 4, 4, 4, 4, 4], [2, 2, 2, 3, 3, 4, 3],
             [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 3, 2, 1, 2], [1, 2, 1, 1, 1, 1, 1],
             [1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1],
             [2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 1, 3, 2, 2], [2, 2, 2, 2, 2, 2, 2],
             [2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 1, 2, 2, 2, 2],
             [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 3, 2, 2, 2, 2, 2],
             [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 2, 1, 1, 2, 1],
             [1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 3, 2, 3, 3], [1, 1, 1, 1, 1, 1, 1],
             [2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 3, 2, 3, 2], [4, 3, 3, 4, 3, 4, 3],
             [2, 2, 1, 2, 2, 3, 2], [2, 3, 2, 3, 2, 3, 3], [3, 3, 3, 3, 4, 3, 2],
             [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 2, 1, 2, 1, 1, 1],
             [2, 3, 2, 2, 2, 2, 2], [1, 2, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2]])
        # fmt: on
        # subtract 1 because Python indices start at 0
        return positions - 1, annotations - 1


    def multinomial(annotations):
        """
        This model corresponds to the plate diagram in Figure 1 of reference [1].
        """
        num_classes = int(np.max(annotations)) + 1
        num_items, num_positions = annotations.shape

        with numpyro.plate("class", num_classes):
            zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

        with numpyro.plate("item", num_items, dim=-2):
            c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

            with numpyro.plate("position", num_positions):
                numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)


    def dawid_skene(positions, annotations):
        """
        This model corresponds to the plate diagram in Figure 2 of reference [1].
        """
        num_annotators = int(np.max(positions)) + 1
        num_classes = int(np.max(annotations)) + 1
        num_items, num_positions = annotations.shape

        with numpyro.plate("annotator", num_annotators, dim=-2):
            with numpyro.plate("class", num_classes):
                beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

        with numpyro.plate("item", num_items, dim=-2):
            c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

            # here we use Vindex to allow broadcasting for the second index `c`
            # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
            with numpyro.plate("position", num_positions):
                numpyro.sample(
                    "y", dist.Categorical(Vindex(beta)[positions, c, :]), obs=annotations
                )


    def mace(positions, annotations):
        """
        This model corresponds to the plate diagram in Figure 3 of reference [1].
        """
        num_annotators = int(np.max(positions)) + 1
        num_classes = int(np.max(annotations)) + 1
        num_items, num_positions = annotations.shape

        with numpyro.plate("annotator", num_annotators):
            epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10)))
            theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

        with numpyro.plate("item", num_items, dim=-2):
            c = numpyro.sample(
                "c",
                dist.DiscreteUniform(0, num_classes - 1),
                infer={"enumerate": "parallel"},
            )

            with numpyro.plate("position", num_positions):
                s = numpyro.sample(
                    "s",
                    dist.Bernoulli(1 - theta[positions]),
                    infer={"enumerate": "parallel"},
                )
                probs = jnp.where(
                    s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]
                )
                numpyro.sample("y", dist.Categorical(probs), obs=annotations)


    def hierarchical_dawid_skene(positions, annotations):
        """
        This model corresponds to the plate diagram in Figure 4 of reference [1].
        """
        num_annotators = int(np.max(positions)) + 1
        num_classes = int(np.max(annotations)) + 1
        num_items, num_positions = annotations.shape

        with numpyro.plate("class", num_classes):
            # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
            # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
            # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
            zeta = numpyro.sample(
                "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
            )
            omega = numpyro.sample(
                "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
            )

        with numpyro.plate("annotator", num_annotators, dim=-2):
            with numpyro.plate("class", num_classes):
                # non-centered parameterization
                with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                    beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                # append a fixed 0 as the logit of the last class
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

        with numpyro.plate("item", num_items, dim=-2):
            c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

            with numpyro.plate("position", num_positions):
                logits = Vindex(beta)[positions, c, :]
                numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)


    def item_difficulty(annotations):
        """
        This model corresponds to the plate diagram in Figure 5 of reference [1].
        """
        num_classes = int(np.max(annotations)) + 1
        num_items, num_positions = annotations.shape

        with numpyro.plate("class", num_classes):
            eta = numpyro.sample(
                "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
            )
            chi = numpyro.sample(
                "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
            )

        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

        with numpyro.plate("item", num_items, dim=-2):
            c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

            with handlers.reparam(config={"theta": LocScaleReparam(0)}):
                theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
                theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

            with numpyro.plate("position", annotations.shape[-1]):
                numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations)


    def logistic_random_effects(positions, annotations):
        """
        This model corresponds to the plate diagram in Figure 6 of reference [1].
        """
        num_annotators = int(np.max(positions)) + 1
        num_classes = int(np.max(annotations)) + 1
        num_items, num_positions = annotations.shape

        with numpyro.plate("class", num_classes):
            zeta = numpyro.sample(
                "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1)
            )
            omega = numpyro.sample(
                "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
            )
            chi = numpyro.sample(
                "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1)
            )

        with numpyro.plate("annotator", num_annotators, dim=-2):
            with numpyro.plate("class", num_classes):
                with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                    beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                    beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

        pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

        with numpyro.plate("item", num_items, dim=-2):
            c = numpyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

            with handlers.reparam(config={"theta": LocScaleReparam(0)}):
                theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
                theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

            with numpyro.plate("position", num_positions):
                logits = Vindex(beta)[positions, c, :] - theta
                numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)


    NAME_TO_MODEL = {
        "mn": multinomial,
        "ds": dawid_skene,
        "mace": mace,
        "hds": hierarchical_dawid_skene,
        "id": item_difficulty,
        "lre": logistic_random_effects,
    }


    def main(args):
        annotators, annotations = get_data()
        model = NAME_TO_MODEL[args.model]
        data = (
            (annotations,)
            if model in [multinomial, item_difficulty]
            else (annotators, annotations)
        )

        mcmc = MCMC(
            NUTS(model),
            num_warmup=args.num_warmup,
            num_samples=args.num_samples,
            num_chains=args.num_chains,
            progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
        )
        mcmc.run(random.PRNGKey(0), *data)
        mcmc.print_summary()

        posterior_samples = mcmc.get_samples()
        predictive = Predictive(model, posterior_samples, infer_discrete=True)
        discrete_samples = predictive(random.PRNGKey(1), *data)

        item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
            discrete_samples["c"].squeeze(-1)
        )
        print("Histogram of the predicted class of each item:")
        row_format = "{:>10}" * 5
        print(row_format.format("", *["c={}".format(i) for i in range(4)]))
        for i, row in enumerate(item_class):
            print(row_format.format(f"item[{i}]", *row))

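The lookup `Vindex(beta)[positions, c, :]` in `dawid_skene` selects, for every item and position, the row of the annotator's confusion matrix that matches the item's class. The plain NumPy sketch below reproduces that lookup for the simple case without enumeration dimensions; the shapes and random values are made up for illustration, and the example itself uses `Vindex` so that the index `c` can broadcast, as its source comment notes.

```python
import numpy as np

rng = np.random.default_rng(0)
num_annotators, num_classes, num_items = 5, 4, 3

# positions[j] is the annotator who produced the j-th annotation of each item
positions = np.array([0, 0, 0, 1, 2, 3, 4])
# beta[a, k] is annotator a's distribution over labels when the true class is k
beta = rng.dirichlet(np.ones(num_classes), size=(num_annotators, num_classes))
# one hypothetical class per item, with a trailing axis so that it broadcasts
# against `positions` (this mirrors the "item" plate at dim=-2)
c = np.array([[1], [3], [0]])

# Broadcasting the two integer indices yields one probability vector per
# (item, position) pair: shape (num_items, num_positions, num_classes).
probs = beta[positions, c]
```

For instance, `probs[1, 4]` equals `beta[2, 3]`: the item at index 1 has class 3, and the position at index 4 belongs to the annotator with index 2.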

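The hierarchical models fix the last logit to 0 because a categorical distribution is unchanged when a constant is added to all of its logits. The NumPy sketch below illustrates both the padding step and this invariance; the `softmax` helper and the numeric values are assumptions made for illustration only.

```python
import numpy as np


def softmax(logits):
    """Softmax over the last axis, shifted for numerical stability."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical free logits for 2 annotators and 4 classes: only the first
# num_classes - 1 entries are parameters.
free = np.array([[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]])

# Append a fixed 0 along the last axis only, mirroring the jnp.pad call
# used in the example.
padded = np.pad(free, [(0, 0)] * (free.ndim - 1) + [(0, 1)])

# Adding a constant to every logit leaves the probabilities unchanged, which
# is why one logit can be pinned to 0 without losing generality.
shifted = padded + 3.0
```

Both `softmax(padded)` and `softmax(shifted)` give the same probabilities, so pinning one logit removes a redundant degree of freedom.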

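The histogram printed at the end of `main` counts, for each item, how often each class appears among the posterior draws of `c`. A NumPy sketch of the same bookkeeping on a small made-up array of draws follows; the values and the `map_class` summary are illustrative assumptions.

```python
import numpy as np

# Hypothetical posterior draws of `c`: 6 samples (rows) for 3 items (columns).
c_samples = np.array(
    [[0, 2, 1], [0, 2, 1], [0, 3, 1], [1, 2, 1], [0, 2, 0], [0, 2, 1]]
)
num_classes = 4

# One histogram per item, padded to `num_classes` bins.
item_class = np.stack(
    [np.bincount(col, minlength=num_classes) for col in c_samples.T]
)
# The most frequent class per item is a simple point estimate.
map_class = item_class.argmax(axis=-1)
```

Each row of `item_class` sums to the number of draws, so dividing by that number gives an estimate of the posterior class probabilities for the item.
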
.. GENERATED FROM PYTHON SOURCE LINES 305-320

.. note::
    In the inference code above, we marginalized out the discrete latent variable `c`,
    so `mcmc.get_samples(...)` does not include samples of `c`. We then use
    `Predictive(..., infer_discrete=True)` to draw posterior samples of `c`, which
    are stored in `discrete_samples`. To merge those discrete samples into the `mcmc`
    instance, we can use the following pattern (which also requires `import jax`)::

        chain_discrete_samples = jax.tree_util.tree_map(
            lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]),
            discrete_samples)
        mcmc.get_samples().update(discrete_samples)
        mcmc.get_samples(group_by_chain=True).update(chain_discrete_samples)

    This is useful when we want to pass the `mcmc` instance to `arviz` through
    `arviz.from_numpyro(mcmc)`.
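
The reshape in the pattern above splits the flat sample axis into a chain axis and a draw axis. The NumPy sketch below shows the shape bookkeeping on a made-up array; the sizes are arbitrary, and it assumes that the flat samples are concatenated chain by chain, which should be checked against the NumPyro version in use.

```python
import numpy as np

num_chains, num_samples, num_items = 2, 3, 4

# Hypothetical flat draws of `c`, shape (num_chains * num_samples, num_items, 1).
flat = np.arange(num_chains * num_samples * num_items).reshape(-1, num_items, 1)

# Split the leading axis into (chain, draw) and keep the trailing axes unchanged.
grouped = flat.reshape((num_chains, num_samples) + flat.shape[1:])
```

Under that ordering assumption, `grouped[1, 0]` is the first draw of the second chain, that is, `flat[num_samples]`.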

.. GENERATED FROM PYTHON SOURCE LINES 320-343

.. code-block:: default



    if __name__ == "__main__":
        assert numpyro.__version__.startswith("0.13.2")
        parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
        parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
        parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
        parser.add_argument("--num-chains", nargs="?", default=1, type=int)
        parser.add_argument(
            "--model",
            nargs="?",
            default="ds",
            help='one of "mn" (multinomial), "ds" (dawid_skene), "mace",'
            ' "hds" (hierarchical_dawid_skene),'
            ' "id" (item_difficulty), "lre" (logistic_random_effects)',
        )
        parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
        args = parser.parse_args()

        numpyro.set_platform(args.device)
        numpyro.set_host_device_count(args.num_chains)

        main(args)


.. _sphx_glr_download_examples_annotation.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: annotation.py <annotation.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: annotation.ipynb <annotation.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
