{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Hidden Markov Model\n\nIn this example, we will follow [1] to construct a semi-supervised Hidden Markov\nModel for a generative model whose observations are words and whose latent\nvariables are categories. Instead of automatically marginalizing all discrete latent\nvariables (as in [2]), we will use the \"forward algorithm\" (which exploits the\nconditional independence of a Markov model - see [3]) to iteratively do this\nmarginalization.\n\nThe semi-supervised problem is chosen instead of an unsupervised one because it\nis hard to make inference work for an unsupervised model (see the\ndiscussion [4]). On the other hand, this example also illustrates the usage of\nJAX's `lax.scan` primitive. The primitive will greatly improve compile time for the\nmodel.\n\n**References:**\n\n    1. https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html\n    2. http://pyro.ai/examples/hmm.html\n    3. https://en.wikipedia.org/wiki/Forward_algorithm\n    4. https://discourse.pymc.io/t/how-to-marginalized-markov-chain-with-categorical/2230\n\n<img src=\"file://../_static/img/examples/hmm.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\nimport time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom scipy.stats import gaussian_kde\n\nfrom jax import lax, random\nimport jax.numpy as jnp\nfrom jax.scipy.special import logsumexp\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.infer import MCMC, NUTS\n\n\ndef simulate_data(\n    rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data\n):\n    \"\"\"Simulate category and word sequences from a randomly drawn HMM.\n\n    Rows of the transition and emission matrices are sampled from Dirichlet\n    distributions. The category is drawn from the uniform start distribution\n    at ``t == 0`` and again at ``t == num_supervised_data``, so the supervised\n    and unsupervised segments are generated as two separate chains.\n\n    Returns the tuple ``(transition_prior, emission_prior, transition_prob,\n    emission_prob, supervised_categories, supervised_words, unsupervised_words)``.\n    \"\"\"\n    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)\n\n    transition_prior = jnp.ones(num_categories)\n    emission_prior = jnp.repeat(0.1, num_words)\n\n    transition_prob = dist.Dirichlet(transition_prior).sample(\n        key=rng_key_transition, sample_shape=(num_categories,)\n    )\n    emission_prob = dist.Dirichlet(emission_prior).sample(\n        key=rng_key_emission, sample_shape=(num_categories,)\n    )\n\n    start_prob = jnp.repeat(1.0 / num_categories, num_categories)\n    categories, words = [], []\n    for t in range(num_supervised_data + num_unsupervised_data):\n        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)\n        if t == 0 or t == num_supervised_data:\n            category = dist.Categorical(start_prob).sample(key=rng_key_transition)\n        else:\n            category = dist.Categorical(transition_prob[category]).sample(\n                key=rng_key_transition\n            )\n        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)\n        categories.append(category)\n        words.append(word)\n\n    # split into supervised data and unsupervised data\n    categories, words = jnp.stack(categories), jnp.stack(words)\n    supervised_categories = categories[:num_supervised_data]\n    supervised_words = words[:num_supervised_data]\n    unsupervised_words = words[num_supervised_data:]\n    return (\n        transition_prior,\n        emission_prior,\n        transition_prob,\n        emission_prob,\n        supervised_categories,\n        supervised_words,\n        unsupervised_words,\n    )\n\n\ndef forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):\n    \"\"\"Advance the forward algorithm by one observed word, in log space.\n\n    Computes ``prev_log_prob[i] + transition_log_prob[i, j]\n    + emission_log_prob[j, curr_word]`` and reduces over the previous\n    category ``i`` with ``logsumexp``.\n    \"\"\"\n    log_prob_tmp = jnp.expand_dims(prev_log_prob, axis=1) + transition_log_prob\n    log_prob = log_prob_tmp + emission_log_prob[:, curr_word]\n    return logsumexp(log_prob, axis=0)\n\n\ndef forward_log_prob(\n    init_log_prob, words, transition_log_prob, emission_log_prob, unroll_loop=False\n):\n    \"\"\"Fold ``forward_one_step`` over ``words`` starting from ``init_log_prob``.\n\n    Uses ``lax.scan`` by default; with ``unroll_loop=True`` a plain Python loop\n    is used instead and the final carried value is returned in both cases.\n    \"\"\"\n    # Note: The following naive implementation will make it very slow to compile\n    # and do inference. So we use lax.scan instead.\n    #\n    # >>> log_prob = init_log_prob\n    # >>> for word in words:\n    # ...     log_prob = forward_one_step(log_prob, word, transition_log_prob, emission_log_prob)\n    def scan_fn(log_prob, word):\n        return (\n            forward_one_step(log_prob, word, transition_log_prob, emission_log_prob),\n            None,  # we don't need to collect during scan\n        )\n\n    if unroll_loop:\n        log_prob = init_log_prob\n        for word in words:\n            log_prob = forward_one_step(\n                log_prob, word, transition_log_prob, emission_log_prob\n            )\n    else:\n        log_prob, _ = lax.scan(scan_fn, init_log_prob, words)\n    return log_prob\n\n\ndef semi_supervised_hmm(\n    transition_prior,\n    emission_prior,\n    supervised_categories,\n    supervised_words,\n    unsupervised_words,\n    unroll_loop=False,\n):\n    \"\"\"NumPyro model for the semi-supervised HMM.\n\n    Supervised transitions and emissions are observed ``Categorical`` sites.\n    The unsupervised words contribute their log probability, computed with\n    ``forward_log_prob``, through ``numpyro.factor``.\n    \"\"\"\n    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]\n    transition_prob = numpyro.sample(\n        \"transition_prob\",\n        dist.Dirichlet(\n            jnp.broadcast_to(transition_prior, (num_categories, num_categories))\n        ),\n    )\n    emission_prob = numpyro.sample(\n        \"emission_prob\",\n        dist.Dirichlet(jnp.broadcast_to(emission_prior, (num_categories, num_words))),\n    )\n\n    # models supervised data;\n    # here we don't make any assumption about the first supervised category, in other words,\n    # we place a flat/uniform prior on it.\n    numpyro.sample(\n        \"supervised_categories\",\n        dist.Categorical(transition_prob[supervised_categories[:-1]]),\n        obs=supervised_categories[1:],\n    )\n    numpyro.sample(\n        \"supervised_words\",\n        dist.Categorical(emission_prob[supervised_categories]),\n        obs=supervised_words,\n    )\n\n    # computes log prob of unsupervised data\n    transition_log_prob = jnp.log(transition_prob)\n    emission_log_prob = jnp.log(emission_prob)\n    init_log_prob = emission_log_prob[:, unsupervised_words[0]]\n    log_prob = forward_log_prob(\n        init_log_prob,\n        unsupervised_words[1:],\n        transition_log_prob,\n        emission_log_prob,\n        unroll_loop,\n    )\n    log_prob = logsumexp(log_prob, axis=0, keepdims=True)\n    # inject log_prob to potential function\n    numpyro.factor(\"forward_log_prob\", log_prob)\n\n\ndef print_results(posterior, transition_prob, emission_prob):\n    \"\"\"Print each true probability next to its 25/50/75% posterior quantiles.\"\"\"\n    header = semi_supervised_hmm.__name__ + \" - TRAIN\"\n    columns = [\"\", \"ActualProb\", \"Pred(p25)\", \"Pred(p50)\", \"Pred(p75)\"]\n    header_format = \"{:>20} {:>10} {:>10} {:>10} {:>10}\"\n    row_format = \"{:>20} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}\"\n    print(\"\\n\", \"=\" * 20 + header + \"=\" * 20, \"\\n\")\n    print(header_format.format(*columns))\n\n    quantiles = np.quantile(posterior[\"transition_prob\"], [0.25, 0.5, 0.75], axis=0)\n    for i in range(transition_prob.shape[0]):\n        for j in range(transition_prob.shape[1]):\n            idx = \"transition[{},{}]\".format(i, j)\n            print(\n                row_format.format(idx, transition_prob[i, j], *quantiles[:, i, j]), \"\\n\"\n            )\n\n    quantiles = np.quantile(posterior[\"emission_prob\"], [0.25, 0.5, 0.75], axis=0)\n    for i in range(emission_prob.shape[0]):\n        for j in range(emission_prob.shape[1]):\n            idx = \"emission[{},{}]\".format(i, j)\n            print(\n                row_format.format(idx, emission_prob[i, j], *quantiles[:, i, j]), \"\\n\"\n            )\n\n\ndef main(args):\n    \"\"\"Simulate data, run NUTS on ``semi_supervised_hmm`` and report the results.\n\n    Prints the quantile table and elapsed time, then saves a kernel density\n    plot of the transition-probability samples to ``hmm_plot.pdf``.\n    \"\"\"\n    print(\"Simulating data...\")\n    (\n        transition_prior,\n        emission_prior,\n        transition_prob,\n        emission_prob,\n        supervised_categories,\n        supervised_words,\n        unsupervised_words,\n    ) = simulate_data(\n        random.PRNGKey(1),\n        num_categories=args.num_categories,\n        num_words=args.num_words,\n        num_supervised_data=args.num_supervised,\n        num_unsupervised_data=args.num_unsupervised,\n    )\n    print(\"Starting inference...\")\n    rng_key = random.PRNGKey(2)\n    start = time.time()\n    kernel = NUTS(semi_supervised_hmm)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        # the progress bar is disabled only when NUMPYRO_SPHINXBUILD is set\n        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n    )\n    mcmc.run(\n        rng_key,\n        transition_prior,\n        emission_prior,\n        supervised_categories,\n        supervised_words,\n        unsupervised_words,\n        args.unroll_loop,\n    )\n    samples = mcmc.get_samples()\n    print_results(samples, transition_prob, emission_prob)\n    print(\"\\nMCMC elapsed time:\", time.time() - start)\n\n    # make plots\n    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n\n    x = np.linspace(0, 1, 101)\n    for i in range(transition_prob.shape[0]):\n        for j in range(transition_prob.shape[1]):\n            ax.plot(\n                x,\n                gaussian_kde(samples[\"transition_prob\"][:, i, j])(x),\n                label=\"trans_prob[{}, {}], true value = {:.2f}\".format(\n                    i, j, transition_prob[i, j]\n                ),\n            )\n    ax.set(\n        xlabel=\"Probability\",\n        ylabel=\"Frequency\",\n        title=\"Transition probability posterior\",\n    )\n    ax.legend()\n\n    plt.savefig(\"hmm_plot.pdf\")\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Semi-supervised Hidden Markov Model\")\n    parser.add_argument(\"--num-categories\", default=3, type=int)\n    parser.add_argument(\"--num-words\", default=10, type=int)\n    parser.add_argument(\"--num-supervised\", default=100, type=int)\n    parser.add_argument(\"--num-unsupervised\", default=500, type=int)\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=500, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--unroll-loop\", action=\"store_true\")\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    # Inside a Jupyter kernel sys.argv carries the kernel's own arguments\n    # (e.g. ``-f <connection file>``). parse_args() would exit on them with an\n    # unrecognized-arguments error, so unknown arguments are ignored and the\n    # defaults above are used when the cell is run from a notebook.\n    args, _ = parser.parse_known_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}