{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Toy Mixture Model with Discrete Enumeration\n\nA toy mixture model that provides a simple example of discrete enumeration:\n\n    (A) -> [B] -> (C)\n\n``A`` is an observed Bernoulli variable with a Beta prior. ``B`` is a hidden variable which\nis a mixture of two Bernoulli distributions (with Beta priors), chosen by whether ``A`` is true or false.\n``C`` is observed and, like ``B``, is a mixture of two Bernoulli distributions (with Beta priors),\nchosen by whether ``B`` is true or false. A plate over the three variables represents ``num_obs``\nindependent observations of the data.\n\nBecause ``B`` is hidden and discrete, we marginalize it out of the model. This is done by:\n\n1. marking the model with ``@config_enumerate``\n2. marking the ``B`` sample site in the model with ``infer={\"enumerate\": \"parallel\"}`` (already implied by ``@config_enumerate``; the explicit ``infer`` just makes it visible at the site)\n3. passing ``TraceEnum_ELBO`` to ``SVI`` as the loss function\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "import numpyro\n",
        "from numpyro import handlers\n",
        "from numpyro.contrib.funsor import config_enumerate\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.distributions import constraints\n",
        "from numpyro.infer import SVI, TraceEnum_ELBO\n",
        "from numpyro.ops.indexing import Vindex\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Generate synthetic data, fit the model with SVI and compare the CPDs.\n",
        "\n",
        "    :param args: namespace with integer attributes ``num_obs`` and ``num_steps``.\n",
        "    \"\"\"\n",
        "    num_obs = args.num_obs\n",
        "    num_steps = args.num_steps\n",
        "    # Split the seed so data generation and SVI use independent PRNG keys\n",
        "    # instead of both consuming ``PRNGKey(0)``.\n",
        "    data_key, train_key = random.split(random.PRNGKey(0))\n",
        "    prior, CPDs, data = handlers.seed(generate_data, data_key)(num_obs)\n",
        "    posterior_params = train(prior, data, num_steps, num_obs, rng_key=train_key)\n",
        "    evaluate(CPDs, posterior_params)\n",
        "\n",
        "\n",
        "def generate_data(num_obs):\n",
        "    \"\"\"Draw ground-truth CPDs and ``num_obs`` joint samples of (A, B, C).\n",
        "\n",
        "    Calls ``numpyro.sample`` directly, so it is run under ``handlers.seed``\n",
        "    (see ``main``).\n",
        "\n",
        "    :returns: ``(prior, CPDs, data)``: the Beta concentrations used to draw the\n",
        "        CPDs, the sampled probabilities ``p_A``/``p_B``/``p_C``, and the sampled\n",
        "        values of ``A``/``B``/``C``.\n",
        "    \"\"\"\n",
        "    # domain = [False, True]; for \"B\" and \"C\", row i holds the Beta\n",
        "    # concentrations used when the parent variable equals i.\n",
        "    prior = {\n",
        "        \"A\": jnp.array([1.0, 10.0]),\n",
        "        \"B\": jnp.array([[10.0, 1.0], [1.0, 10.0]]),\n",
        "        \"C\": jnp.array([[10.0, 1.0], [1.0, 10.0]]),\n",
        "    }\n",
        "    CPDs = {\n",
        "        \"p_A\": numpyro.sample(\"p_A\", dist.Beta(prior[\"A\"][0], prior[\"A\"][1])),\n",
        "        \"p_B\": numpyro.sample(\"p_B\", dist.Beta(prior[\"B\"][:, 0], prior[\"B\"][:, 1])),\n",
        "        \"p_C\": numpyro.sample(\"p_C\", dist.Beta(prior[\"C\"][:, 0], prior[\"C\"][:, 1])),\n",
        "    }\n",
        "    data = {\"A\": numpyro.sample(\"A\", dist.Bernoulli(jnp.ones(num_obs) * CPDs[\"p_A\"]))}\n",
        "    # Each child's success probability is selected by its parent's sampled value.\n",
        "    data[\"B\"] = numpyro.sample(\"B\", dist.Bernoulli(CPDs[\"p_B\"][data[\"A\"]]))\n",
        "    data[\"C\"] = numpyro.sample(\"C\", dist.Bernoulli(CPDs[\"p_C\"][data[\"B\"]]))\n",
        "    return prior, CPDs, data\n",
        "\n",
        "\n",
        "@config_enumerate\n",
        "def model(prior, obs, num_obs):\n",
        "    \"\"\"Model with uniform Beta(1, 1) priors on the CPDs and the chain A -> B -> C.\n",
        "\n",
        "    ``B`` is latent and discrete and is enumerated out by ``TraceEnum_ELBO``.\n",
        "    ``prior`` is unused here; it is accepted because SVI passes the same\n",
        "    arguments to ``model`` and ``guide``.\n",
        "    \"\"\"\n",
        "    p_A = numpyro.sample(\"p_A\", dist.Beta(1, 1))\n",
        "    p_B = numpyro.sample(\"p_B\", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event(1))\n",
        "    p_C = numpyro.sample(\"p_C\", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event(1))\n",
        "    with numpyro.plate(\"data_plate\", num_obs):\n",
        "        A = numpyro.sample(\"A\", dist.Bernoulli(p_A), obs=obs[\"A\"])\n",
        "        # Vindex used to ensure proper indexing into the enumerated sample sites\n",
        "        B = numpyro.sample(\n",
        "            \"B\",\n",
        "            dist.Bernoulli(Vindex(p_B)[A]),\n",
        "            infer={\"enumerate\": \"parallel\"},\n",
        "        )\n",
        "        numpyro.sample(\"C\", dist.Bernoulli(Vindex(p_C)[B]), obs=obs[\"C\"])\n",
        "\n",
        "\n",
        "def guide(prior, obs, num_obs):\n",
        "    \"\"\"Beta guide over ``p_A``, ``p_B`` and ``p_C`` with learnable concentrations.\n",
        "\n",
        "    The positive-constrained params ``a``, ``b`` and ``c`` are initialised from\n",
        "    ``prior``. ``obs`` and ``num_obs`` are unused; they mirror ``model``.\n",
        "    \"\"\"\n",
        "    a = numpyro.param(\"a\", prior[\"A\"], constraint=constraints.positive)\n",
        "    numpyro.sample(\"p_A\", dist.Beta(a[0], a[1]))\n",
        "    b = numpyro.param(\"b\", prior[\"B\"], constraint=constraints.positive)\n",
        "    numpyro.sample(\"p_B\", dist.Beta(b[:, 0], b[:, 1]).to_event(1))\n",
        "    c = numpyro.param(\"c\", prior[\"C\"], constraint=constraints.positive)\n",
        "    numpyro.sample(\"p_C\", dist.Beta(c[:, 0], c[:, 1]).to_event(1))\n",
        "\n",
        "\n",
        "def train(prior, data, num_steps, num_obs, rng_key=None):\n",
        "    \"\"\"Fit ``guide`` to ``model`` with SVI + ``TraceEnum_ELBO`` and plot the loss.\n",
        "\n",
        "    :param rng_key: PRNG key for ``svi.run``; defaults to ``random.PRNGKey(0)``.\n",
        "    :returns: a copy of the fitted params, with ``\"a\"`` reshaped from ``(2,)``\n",
        "        to ``(1, 2)`` so that every entry has one row of concentrations.\n",
        "    \"\"\"\n",
        "    if rng_key is None:\n",
        "        rng_key = random.PRNGKey(0)\n",
        "    elbo = TraceEnum_ELBO()\n",
        "    svi = SVI(model, guide, optax.adam(learning_rate=0.01), loss=elbo)\n",
        "    svi_result = svi.run(rng_key, num_steps, prior, data, num_obs)\n",
        "    _, ax = plt.subplots()\n",
        "    ax.plot(svi_result.losses)\n",
        "    ax.set(title=\"SVI loss\", xlabel=\"step\", ylabel=\"loss\")\n",
        "    plt.show()\n",
        "    posterior_params = svi_result.params.copy()\n",
        "    # reshape to the same 2-D layout as \"b\" and \"c\"\n",
        "    posterior_params[\"a\"] = posterior_params[\"a\"][None, :]\n",
        "    return posterior_params\n",
        "\n",
        "\n",
        "def evaluate(CPDs, posterior_params):\n",
        "    \"\"\"Print each true CPD next to its posterior-mean estimate.\"\"\"\n",
        "    true_p_A, pred_p_A = get_true_pred_CPDs(CPDs[\"p_A\"], posterior_params[\"a\"])\n",
        "    true_p_B, pred_p_B = get_true_pred_CPDs(CPDs[\"p_B\"], posterior_params[\"b\"])\n",
        "    true_p_C, pred_p_C = get_true_pred_CPDs(CPDs[\"p_C\"], posterior_params[\"c\"])\n",
        "    print(\"\\np_A = True\")\n",
        "    print(\"actual:   \", true_p_A)\n",
        "    print(\"predicted:\", pred_p_A)\n",
        "    print(\"\\np_B = True | A = False/True\")\n",
        "    print(\"actual:   \", true_p_B)\n",
        "    print(\"predicted:\", pred_p_B)\n",
        "    print(\"\\np_C = True | B = False/True\")\n",
        "    print(\"actual:   \", true_p_C)\n",
        "    print(\"predicted:\", pred_p_C)\n",
        "\n",
        "\n",
        "def get_true_pred_CPDs(CPD, posterior_param):\n",
        "    \"\"\"Return ``(true_p, pred_p)`` for one CPD.\n",
        "\n",
        "    Each row ``(alpha, beta)`` of the 2-D ``posterior_param`` parameterises a\n",
        "    Beta distribution; ``pred_p`` is its mean ``alpha / (alpha + beta)``.\n",
        "    \"\"\"\n",
        "    true_p = CPD\n",
        "    pred_p = posterior_param[:, 0] / jnp.sum(posterior_param, axis=1)\n",
        "    return true_p, pred_p\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"Toy mixture model\")\n",
        "    parser.add_argument(\"-n\", \"--num-steps\", default=4000, type=int)\n",
        "    parser.add_argument(\"-o\", \"--num-obs\", default=10000, type=int)\n",
        "    # parse_known_args tolerates extra flags that an IPython/Jupyter kernel may\n",
        "    # leave in sys.argv; parse_args would exit with \"unrecognized arguments\".\n",
        "    args, _ = parser.parse_known_args()\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}