{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Enumerate Hidden Markov Model\n\nThis example is ported from [1], which shows how to marginalize out\ndiscrete model variables in Pyro.\n\nThis combines MCMC with a variable elimination algorithm, where we\nuse enumeration to exactly marginalize out some variables from the\njoint density.\n\nTo marginalize out discrete variables ``x``:\n\n1. Verify that the variable dependency structure in your model\n   admits tractable inference, i.e. the dependency graph among\n   enumerated variables should have narrow treewidth.\n2. Ensure your model can handle broadcasting of the sample values\n   of those variables.\n\nUnlike [1], which uses a Python loop, here we use\n:func:`~numpyro.contrib.control_flow.scan` to reduce the compilation\ntime of the model (only one step needs to be compiled). Under the\nhood, `scan` stacks all the priors' parameters and values into\nan additional time dimension. This allows us to compute the joint\ndensity in parallel. In addition, the stacked form allows us\nto use the parallel-scan algorithm in [2], which reduces parallel\ncomplexity from O(length) to O(log(length)).\n\nData are taken from [3]. However, the original source of the data\nseems to be the Institut fuer Algorithmen und Kognitive Systeme\nat Universitaet Karlsruhe.\n\n**References:**\n\n    1. *Pyro's Hidden Markov Model example*,\n       (https://pyro.ai/examples/hmm.html)\n    2. *Temporal Parallelization of Bayesian Smoothers*,\n       Simo Sarkka, Angel F. Garcia-Fernandez\n       (https://arxiv.org/abs/1905.13002)\n    3. *Modeling Temporal Dependencies in High-Dimensional Sequences:\n       Application to Polyphonic Music Generation and Transcription*,\n       Boulanger-Lewandowski, N., Bengio, Y. and Vincent, P.\n    4. *Tensor Variable Elimination for Plated Factor Graphs*,\n       Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,\n       Neeraj Pradhan, Alexander Rush, Noah Goodman (https://arxiv.org/abs/1902.03210)\n"
      ]
    },
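    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make exact marginalization concrete, the following is a minimal sketch, in plain JAX, of the forward algorithm for a generic discrete HMM with an explicit initial distribution. It is not NumPyro's enumeration machinery: the helper name `hmm_log_marginal` and the toy probabilities are made up for illustration. Summing out the hidden state one step at a time costs O(K^2) per time step for K hidden states, instead of summing over all K^T state paths.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\nimport jax.numpy as jnp\nfrom jax.scipy.special import logsumexp\n\n\n# Illustrative helper (hypothetical name, not part of NumPyro).\ndef hmm_log_marginal(log_init, log_trans, log_emit):\n    # log_init:  (K,)   log p(x[0] = k)\n    # log_trans: (K, K) log p(x[t] = j | x[t-1] = i), indexed [i, j]\n    # log_emit:  (T, K) log p(y[t] | x[t] = k) at the observed y[t]\n    def step(log_alpha, log_emit_t):\n        # Sum out the previous state i, then add the emission term for state j.\n        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + log_emit_t\n        return log_alpha, None\n\n    log_alpha, _ = jax.lax.scan(step, log_init + log_emit[0], log_emit[1:])\n    # Sum out the final hidden state.\n    return logsumexp(log_alpha)\n\n\n# Toy inputs: 2 hidden states, 3 time steps, made-up probabilities.\nlog_init = jnp.log(jnp.array([0.6, 0.4]))\nlog_trans = jnp.log(jnp.array([[0.7, 0.3], [0.2, 0.8]]))\nlog_emit = jnp.log(jnp.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.7]]))\nprint(hmm_log_marginal(log_init, log_trans, log_emit))"
      ]
    },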
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport logging\nimport os\nimport time\n\nfrom jax import random\nimport jax.numpy as jnp\n\nimport numpyro\nfrom numpyro.contrib.control_flow import scan\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import JSB_CHORALES, load_dataset\nfrom numpyro.handlers import mask\nfrom numpyro.infer import HMC, MCMC, NUTS\nfrom numpyro.ops.indexing import Vindex\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's start with a simple Hidden Markov Model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#     x[t-1] --> x[t] --> x[t+1]\n#        |        |         |\n#        V        V         V\n#     y[t-1]     y[t]     y[t+1]\n#\n# This model includes a plate for the data_dim = 44 keys on the piano. This\n# model has two \"style\" parameters probs_x and probs_y that we'll draw from a\n# prior. The latent state is x, and the observed state is y.\ndef model_1(sequences, lengths, args, include_prior=True):\n    num_sequences, max_length, data_dim = sequences.shape\n    with mask(mask=include_prior):\n        probs_x = numpyro.sample(\n            \"probs_x\", dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1)\n        )\n        probs_y = numpyro.sample(\n            \"probs_y\",\n            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),\n        )\n\n    def transition_fn(carry, y):\n        x_prev, t = carry\n        with numpyro.plate(\"sequences\", num_sequences, dim=-2):\n            with mask(mask=(t < lengths)[..., None]):\n                x = numpyro.sample(\n                    \"x\",\n                    dist.Categorical(probs_x[x_prev]),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                with numpyro.plate(\"tones\", data_dim, dim=-1):\n                    numpyro.sample(\"y\", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)\n        return (x, t + 1), None\n\n    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it\n    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next let's add a dependency of y[t] on y[t-1].\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#     x[t-1] --> x[t] --> x[t+1]\n#        |        |         |\n#        V        V         V\n#     y[t-1] --> y[t] --> y[t+1]\ndef model_2(sequences, lengths, args, include_prior=True):\n    num_sequences, max_length, data_dim = sequences.shape\n    with mask(mask=include_prior):\n        probs_x = numpyro.sample(\n            \"probs_x\", dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1)\n        )\n\n        probs_y = numpyro.sample(\n            \"probs_y\",\n            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3),\n        )\n\n    def transition_fn(carry, y):\n        x_prev, y_prev, t = carry\n        with numpyro.plate(\"sequences\", num_sequences, dim=-2):\n            with mask(mask=(t < lengths)[..., None]):\n                x = numpyro.sample(\n                    \"x\",\n                    dist.Categorical(probs_x[x_prev]),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                # Note the broadcasting tricks here: to index probs_y on tensors x and y,\n                # we also need a final tensor for the tones dimension. This is conveniently\n                # provided by the plate associated with that dimension.\n                with numpyro.plate(\"tones\", data_dim, dim=-1) as tones:\n                    y = numpyro.sample(\n                        \"y\", dist.Bernoulli(probs_y[x, y_prev, tones]), obs=y\n                    )\n        return (x, y, t + 1), None\n\n    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    y_init = jnp.zeros((num_sequences, data_dim), dtype=jnp.int32)\n    scan(transition_fn, (x_init, y_init, 0), jnp.swapaxes(sequences, 0, 1))"
      ]
    },
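    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The broadcasting remark in the code comment above can be checked in isolation. The following sketch uses plain `jax.numpy` with made-up toy sizes. The hand-built `x` only imitates the shape an enumerated variable would have (an extra dimension to the left of the two plate dimensions); that shape is an assumption for illustration, not output captured from NumPyro.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\n# Hypothetical toy sizes, chosen only for illustration.\nhidden_dim, num_sequences, data_dim = 4, 3, 5\n\nprobs_y = jnp.full((hidden_dim, 2, data_dim), 0.5)\n# Imitate an enumerated x: one extra leading dimension ranging over all hidden states.\nx = jnp.arange(hidden_dim).reshape(hidden_dim, 1, 1)\ny_prev = jnp.zeros((num_sequences, data_dim), dtype=jnp.int32)\ntones = jnp.arange(data_dim)\n\n# The three index arrays broadcast against each other, giving one Bernoulli\n# probability per (hidden state, sequence, tone) triple.\nprint(probs_y[x, y_prev, tones].shape)  # (4, 3, 5)"
      ]
    },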
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next consider a Factorial HMM with two hidden states.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#    w[t-1] ----> w[t] ---> w[t+1]\n#        \\ x[t-1] --\\-> x[t] --\\-> x[t+1]\n#         \\  /       \\  /       \\  /\n#          \\/         \\/         \\/\n#        y[t-1]      y[t]      y[t+1]\n#\n# Note that since the joint distribution of each y[t] depends on two variables,\n# those two variables become dependent. Therefore during enumeration, the\n# entire joint space of these variables w[t],x[t] needs to be enumerated.\n# For that reason, we set the dimension of each to the square root of the\n# target hidden dimension.\ndef model_3(sequences, lengths, args, include_prior=True):\n    num_sequences, max_length, data_dim = sequences.shape\n    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x\n    with mask(mask=include_prior):\n        probs_w = numpyro.sample(\n            \"probs_w\", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)\n        )\n        probs_x = numpyro.sample(\n            \"probs_x\", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)\n        )\n        # probs_y is indexed by (w, x, tone), so its two leading dimensions\n        # must both have size hidden_dim.\n        probs_y = numpyro.sample(\n            \"probs_y\",\n            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),\n        )\n\n    def transition_fn(carry, y):\n        w_prev, x_prev, t = carry\n        with numpyro.plate(\"sequences\", num_sequences, dim=-2):\n            with mask(mask=(t < lengths)[..., None]):\n                w = numpyro.sample(\n                    \"w\",\n                    dist.Categorical(probs_w[w_prev]),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                x = numpyro.sample(\n                    \"x\",\n                    dist.Categorical(probs_x[x_prev]),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                # Note the broadcasting tricks here: to index probs_y on tensors w and x,\n                # we also need a final tensor for the tones dimension. This is conveniently\n                # provided by the plate associated with that dimension.\n                with numpyro.plate(\"tones\", data_dim, dim=-1) as tones:\n                    numpyro.sample(\"y\", dist.Bernoulli(probs_y[w, x, tones]), obs=y)\n        return (w, x, t + 1), None\n\n    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "By adding a dependency of x on w, we generalize to a\nDynamic Bayesian Network.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#     w[t-1] ----> w[t] ---> w[t+1]\n#        |  \\       |  \\       |   \\\n#        | x[t-1] ----> x[t] ----> x[t+1]\n#        |   /      |   /      |   /\n#        V  /       V  /       V  /\n#     y[t-1]       y[t]      y[t+1]\n#\n# Note that message passing here has roughly the same cost as with the\n# Factorial HMM, but this model has more parameters.\ndef model_4(sequences, lengths, args, include_prior=True):\n    num_sequences, max_length, data_dim = sequences.shape\n    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x\n    with mask(mask=include_prior):\n        probs_w = numpyro.sample(\n            \"probs_w\", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)\n        )\n        probs_x = numpyro.sample(\n            \"probs_x\",\n            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1)\n            .expand_by([hidden_dim])\n            .to_event(2),\n        )\n        probs_y = numpyro.sample(\n            \"probs_y\",\n            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),\n        )\n\n    def transition_fn(carry, y):\n        w_prev, x_prev, t = carry\n        with numpyro.plate(\"sequences\", num_sequences, dim=-2):\n            with mask(mask=(t < lengths)[..., None]):\n                w = numpyro.sample(\n                    \"w\",\n                    dist.Categorical(probs_w[w_prev]),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                x = numpyro.sample(\n                    \"x\",\n                    dist.Categorical(Vindex(probs_x)[w, x_prev]),\n                    infer={\"enumerate\": \"parallel\"},\n                )\n                with numpyro.plate(\"tones\", data_dim, dim=-1) as tones:\n                    numpyro.sample(\"y\", dist.Bernoulli(probs_y[w, x, tones]), obs=y)\n        return (w, x, t + 1), None\n\n    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    x_init = jnp.zeros((num_sequences, 1), 
dtype=jnp.int32)\n    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next let's consider a second-order HMM\nin which x[t+1] depends on both x[t] and x[t-1].\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#                     _______>______\n#         _____>_____/______        \\\n#        /          /       \\        \\\n#     x[t-1] --> x[t] --> x[t+1] --> x[t+2]\n#        |        |          |          |\n#        V        V          V          V\n#     y[t-1]     y[t]     y[t+1]     y[t+2]\n#\n#  Note that in this model (in contrast to the previous model) we treat\n#  the transition and emission probabilities as parameters (so they have no prior).\n#\n# Note that this is the \"2HMM\" model in reference [4].\ndef model_6(sequences, lengths, args, include_prior=False):\n    num_sequences, max_length, data_dim = sequences.shape\n\n    with mask(mask=include_prior):\n        # Explicitly parameterize the full tensor of transition probabilities, which\n        # has hidden_dim cubed entries.\n        probs_x = numpyro.sample(\n            \"probs_x\",\n            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1)\n            .expand([args.hidden_dim, args.hidden_dim])\n            .to_event(2),\n        )\n\n        probs_y = numpyro.sample(\n            \"probs_y\",\n            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),\n        )\n\n    def transition_fn(carry, y):\n        x_prev, x_curr, t = carry\n        with numpyro.plate(\"sequences\", num_sequences, dim=-2):\n            with mask(mask=(t < lengths)[..., None]):\n                probs_x_t = Vindex(probs_x)[x_prev, x_curr]\n                x_prev, x_curr = x_curr, numpyro.sample(\n                    \"x\", dist.Categorical(probs_x_t), infer={\"enumerate\": \"parallel\"}\n                )\n                with numpyro.plate(\"tones\", data_dim, dim=-1):\n                    probs_y_t = probs_y[x_curr.squeeze(-1)]\n                    numpyro.sample(\"y\", dist.Bernoulli(probs_y_t), obs=y)\n        return (x_prev, x_curr, t + 1), None\n\n    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)\n    
scan(transition_fn, (x_prev, x_curr, 0), jnp.swapaxes(sequences, 0, 1), history=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, run MCMC inference on the selected model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "models = {\n    name[len(\"model_\") :]: model\n    for name, model in globals().items()\n    if name.startswith(\"model_\")\n}\n\n\ndef main(args):\n    model = models[args.model]\n\n    _, fetch = load_dataset(JSB_CHORALES, split=\"train\", shuffle=False)\n    lengths, sequences = fetch()\n    if args.num_sequences:\n        sequences = sequences[0 : args.num_sequences]\n        lengths = lengths[0 : args.num_sequences]\n\n    logger.info(\"-\" * 40)\n    logger.info(\"Training {} on {} sequences\".format(model.__name__, len(sequences)))\n\n    # find all the notes that are present at least once in the training set\n    present_notes = (sequences == 1).sum(0).sum(0) > 0\n    # remove notes that are never played (we remove 37/88 notes with default args)\n    sequences = sequences[:, :, present_notes]\n\n    if args.truncate:\n        lengths = lengths.clip(0, args.truncate)\n        sequences = sequences[:, : args.truncate]\n\n    logger.info(\"Each sequence has shape {}\".format(sequences[0].shape))\n    logger.info(\"Starting inference...\")\n    rng_key = random.PRNGKey(2)\n    start = time.time()\n    kernel = {\"nuts\": NUTS, \"hmc\": HMC}[args.kernel](model)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=False if \"NUMPYRO_SPHINXBUILD\" in os.environ else True,\n    )\n    mcmc.run(rng_key, sequences, lengths, args=args)\n    mcmc.print_summary()\n    logger.info(\"\\nMCMC elapsed time: {}\".format(time.time() - start))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"HMC for HMMs\")\n    parser.add_argument(\n        \"-m\",\n        \"--model\",\n        default=\"1\",\n        type=str,\n        help=\"one of: {}\".format(\", \".join(sorted(models.keys()))),\n    )\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"-d\", 
\"--hidden-dim\", default=16, type=int)\n    parser.add_argument(\"-t\", \"--truncate\", type=int)\n    parser.add_argument(\"--num-sequences\", type=int)\n    parser.add_argument(\"--kernel\", default=\"nuts\", type=str)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=500, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}