{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: ProdLDA with Flax and Haiku\n\nIn this example, we will follow [1] to implement the ProdLDA topic model from\nAutoencoding Variational Inference For Topic Models by Akash Srivastava and Charles\nSutton [2]. This model returns consistently better topics than vanilla LDA and trains\nmuch more quickly. Furthermore, it does not require a custom inference algorithm that\nrelies on complex mathematical derivations. This example also serves as an\nintroduction to Flax and Haiku modules in NumPyro.\n\nNote that unlike [1, 2], this implementation uses a Dirichlet prior directly rather\nthan approximating it with a softmax-normal distribution.\n\nFor the interested reader, a nice extension of this model is the CombinedTM model [3]\nwhich utilizes a pre-trained sentence transformer (like https://www.sbert.net/) to\ngenerate a better representation of the encoded latent vector.\n\n**References:**\n    1. http://pyro.ai/examples/prodlda.html\n    2. Akash Srivastava, & Charles Sutton. (2017). Autoencoding Variational Inference\n       For Topic Models.\n    3. Federico Bianchi, Silvia Terragni, and Dirk Hovy (2021), \"Pre-training is a Hot\n       Topic: Contextualized Document Embeddings Improve Topic Coherence\"\n       (https://arxiv.org/abs/2004.03974)\n\n<img src=\"file://../_static/img/examples/prodlda.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n\nimport matplotlib.pyplot as plt\nimport pandas as pd\nfrom sklearn.datasets import fetch_20newsgroups\nfrom sklearn.feature_extraction.text import CountVectorizer\nfrom wordcloud import WordCloud\n\nimport flax.linen as nn\nimport haiku as hk\nimport jax\nfrom jax import device_put, random\nimport jax.numpy as jnp\n\nimport numpyro\nfrom numpyro.contrib.module import flax_module, haiku_module\nimport numpyro.distributions as dist\nfrom numpyro.infer import SVI, TraceMeanField_ELBO\n\n\nclass HaikuEncoder:\n    def __init__(self, vocab_size, num_topics, hidden, dropout_rate):\n        self._vocab_size = vocab_size\n        self._num_topics = num_topics\n        self._hidden = hidden\n        self._dropout_rate = dropout_rate\n\n    def __call__(self, inputs, is_training):\n        dropout_rate = self._dropout_rate if is_training else 0.0\n\n        h = jax.nn.softplus(hk.Linear(self._hidden)(inputs))\n        h = jax.nn.softplus(hk.Linear(self._hidden)(h))\n        h = hk.dropout(hk.next_rng_key(), dropout_rate, h)\n        h = hk.Linear(self._num_topics)(h)\n\n        # NB: here we set `create_scale=False` and `create_offset=False` to reduce\n        # the number of learning parameters\n        log_concentration = hk.BatchNorm(\n            create_scale=False, create_offset=False, decay_rate=0.9\n        )(h, is_training)\n        return jnp.exp(log_concentration)\n\n\nclass HaikuDecoder:\n    def __init__(self, vocab_size, dropout_rate):\n        self._vocab_size = vocab_size\n        self._dropout_rate = dropout_rate\n\n    def __call__(self, inputs, is_training):\n        dropout_rate = self._dropout_rate if is_training else 0.0\n        h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)\n        h = hk.Linear(self._vocab_size, with_bias=False)(h)\n        return hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)(\n            h, is_training\n        )\n\n\nclass FlaxEncoder(nn.Module):\n    vocab_size: int\n    num_topics: int\n    hidden: int\n    dropout_rate: float\n\n    @nn.compact\n    def __call__(self, inputs, is_training):\n        h = nn.softplus(nn.Dense(self.hidden)(inputs))\n        h = nn.softplus(nn.Dense(self.hidden)(h))\n        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)\n        h = nn.Dense(self.num_topics)(h)\n\n        log_concentration = nn.BatchNorm(\n            use_bias=False,\n            use_scale=False,\n            momentum=0.9,\n            use_running_average=not is_training,\n        )(h)\n        return jnp.exp(log_concentration)\n\n\nclass FlaxDecoder(nn.Module):\n    vocab_size: int\n    dropout_rate: float\n\n    @nn.compact\n    def __call__(self, inputs, is_training):\n        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(inputs)\n        h = nn.Dense(self.vocab_size, use_bias=False)(h)\n        return nn.BatchNorm(\n            use_bias=False,\n            use_scale=False,\n            momentum=0.9,\n            use_running_average=not is_training,\n        )(h)\n\n\ndef model(docs, hyperparams, is_training=False, nn_framework=\"flax\"):\n    if nn_framework == \"flax\":\n        decoder = flax_module(\n            \"decoder\",\n            FlaxDecoder(hyperparams[\"vocab_size\"], hyperparams[\"dropout_rate\"]),\n            input_shape=(1, hyperparams[\"num_topics\"]),\n            # ensure PRNGKey is made available to dropout layers\n            apply_rng=[\"dropout\"],\n            # indicate mutable state due to BatchNorm layers\n            mutable=[\"batch_stats\"],\n            # to ensure proper initialisation of BatchNorm we must\n            # initialise with is_training=True\n            is_training=True,\n        )\n    elif nn_framework == \"haiku\":\n        decoder = haiku_module(\n            \"decoder\",\n            # use `transform_with_state` for BatchNorm\n            hk.transform_with_state(\n                HaikuDecoder(hyperparams[\"vocab_size\"], hyperparams[\"dropout_rate\"])\n            ),\n            input_shape=(1, hyperparams[\"num_topics\"]),\n            apply_rng=True,\n            # to ensure proper initialisation of BatchNorm we must\n            # initialise with is_training=True\n            is_training=True,\n        )\n    else:\n        raise ValueError(f\"Invalid choice {nn_framework} for argument nn_framework\")\n\n    with numpyro.plate(\n        \"documents\", docs.shape[0], subsample_size=hyperparams[\"batch_size\"]\n    ):\n        batch_docs = numpyro.subsample(docs, event_dim=1)\n        theta = numpyro.sample(\n            \"theta\", dist.Dirichlet(jnp.ones(hyperparams[\"num_topics\"]))\n        )\n\n        if nn_framework == \"flax\":\n            logits = decoder(theta, is_training, rngs={\"dropout\": numpyro.prng_key()})\n        elif nn_framework == \"haiku\":\n            logits = decoder(numpyro.prng_key(), theta, is_training)\n\n        total_count = batch_docs.sum(-1)\n        numpyro.sample(\n            \"obs\", dist.Multinomial(total_count, logits=logits), obs=batch_docs\n        )\n\n\ndef guide(docs, hyperparams, is_training=False, nn_framework=\"flax\"):\n    if nn_framework == \"flax\":\n        encoder = flax_module(\n            \"encoder\",\n            FlaxEncoder(\n                hyperparams[\"vocab_size\"],\n                hyperparams[\"num_topics\"],\n                hyperparams[\"hidden\"],\n                hyperparams[\"dropout_rate\"],\n            ),\n            input_shape=(1, hyperparams[\"vocab_size\"]),\n            # ensure PRNGKey is made available to dropout layers\n            apply_rng=[\"dropout\"],\n            # indicate mutable state due to BatchNorm layers\n            mutable=[\"batch_stats\"],\n            # to ensure proper initialisation of BatchNorm we must\n            # initialise with is_training=True\n            is_training=True,\n        )\n    elif nn_framework == \"haiku\":\n        encoder = haiku_module(\n            \"encoder\",\n            # use `transform_with_state` for BatchNorm\n            hk.transform_with_state(\n                HaikuEncoder(\n                    hyperparams[\"vocab_size\"],\n                    hyperparams[\"num_topics\"],\n                    hyperparams[\"hidden\"],\n                    hyperparams[\"dropout_rate\"],\n                )\n            ),\n            input_shape=(1, hyperparams[\"vocab_size\"]),\n            apply_rng=True,\n            # to ensure proper initialisation of BatchNorm we must\n            # initialise with is_training=True\n            is_training=True,\n        )\n    else:\n        raise ValueError(f\"Invalid choice {nn_framework} for argument nn_framework\")\n\n    with numpyro.plate(\n        \"documents\", docs.shape[0], subsample_size=hyperparams[\"batch_size\"]\n    ):\n        batch_docs = numpyro.subsample(docs, event_dim=1)\n\n        if nn_framework == \"flax\":\n            concentration = encoder(\n                batch_docs, is_training, rngs={\"dropout\": numpyro.prng_key()}\n            )\n        elif nn_framework == \"haiku\":\n            concentration = encoder(numpyro.prng_key(), batch_docs, is_training)\n\n        numpyro.sample(\"theta\", dist.Dirichlet(concentration))\n\n\ndef load_data():\n    news = fetch_20newsgroups(subset=\"all\")\n    vectorizer = CountVectorizer(max_df=0.5, min_df=20, stop_words=\"english\")\n    docs = jnp.array(vectorizer.fit_transform(news[\"data\"]).toarray())\n\n    vocab = pd.DataFrame(columns=[\"word\", \"index\"])\n    vocab[\"word\"] = vectorizer.get_feature_names_out()\n    vocab[\"index\"] = vocab.index\n\n    return docs, vocab\n\n\ndef run_inference(docs, args):\n    rng_key = random.PRNGKey(0)\n    docs = device_put(docs)\n\n    hyperparams = dict(\n        vocab_size=docs.shape[1],\n        num_topics=args.num_topics,\n        hidden=args.hidden,\n        dropout_rate=args.dropout_rate,\n        batch_size=args.batch_size,\n    )\n\n    optimizer = numpyro.optim.Adam(args.learning_rate)\n    svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO())\n\n    return svi.run(\n        rng_key,\n        args.num_steps,\n        docs,\n        hyperparams,\n        is_training=True,\n        progress_bar=not args.disable_progbar,\n        nn_framework=args.nn_framework,\n    )\n\n\ndef plot_word_cloud(b, ax, vocab, n):\n    indices = jnp.argsort(b)[::-1]\n    top20 = indices[:20]\n    df = pd.DataFrame(top20, columns=[\"index\"])\n    words = pd.merge(df, vocab[[\"index\", \"word\"]], how=\"left\", on=\"index\")[\n        \"word\"\n    ].values.tolist()\n    sizes = b[top20].tolist()\n    freqs = {words[i]: sizes[i] for i in range(len(words))}\n    wc = WordCloud(background_color=\"white\", width=800, height=500)\n    wc = wc.generate_from_frequencies(freqs)\n    ax.set_title(f\"Topic {n + 1}\")\n    ax.imshow(wc, interpolation=\"bilinear\")\n    ax.axis(\"off\")\n\n\ndef main(args):\n    docs, vocab = load_data()\n    print(f\"Dictionary size: {len(vocab)}\")\n    print(f\"Corpus size: {docs.shape}\")\n\n    svi_result = run_inference(docs, args)\n\n    if args.nn_framework == \"flax\":\n        beta = svi_result.params[\"decoder$params\"][\"Dense_0\"][\"kernel\"]\n    elif args.nn_framework == \"haiku\":\n        beta = svi_result.params[\"decoder$params\"][\"linear\"][\"w\"]\n\n    beta = jax.nn.softmax(beta)\n\n    # the number of plots depends on the chosen number of topics.\n    # add 2 to num topics to ensure we create a row for any remainder after division\n    nrows = (args.num_topics + 2) // 3\n    fig, axs = plt.subplots(nrows, 3, figsize=(14, 3 + 3 * nrows))\n    axs = axs.flatten()\n\n    for n in range(beta.shape[0]):\n        plot_word_cloud(beta[n], axs[n], vocab, n)\n\n    # hide any unused axes\n    for i in range(n, len(axs)):\n        axs[i].axis(\"off\")\n\n    fig.savefig(\"wordclouds.png\")\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(\n        description=\"Probabilistic topic modelling with Flax and Haiku\"\n    )\n    parser.add_argument(\"-n\", \"--num-steps\", nargs=\"?\", default=30_000, type=int)\n    parser.add_argument(\"-t\", \"--num-topics\", nargs=\"?\", default=12, type=int)\n    parser.add_argument(\"--batch-size\", nargs=\"?\", default=32, type=int)\n    parser.add_argument(\"--learning-rate\", nargs=\"?\", default=1e-3, type=float)\n    parser.add_argument(\"--hidden\", nargs=\"?\", default=200, type=int)\n    parser.add_argument(\"--dropout-rate\", nargs=\"?\", default=0.2, type=float)\n    parser.add_argument(\n        \"-dp\",\n        \"--disable-progbar\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to disable progress bar\",\n    )\n    parser.add_argument(\n        \"--device\", default=\"cpu\", type=str, help='use \"cpu\", \"gpu\" or \"tpu\".'\n    )\n    parser.add_argument(\n        \"--nn-framework\",\n        nargs=\"?\",\n        default=\"flax\",\n        help=(\n            \"The framework to use for constructing encoder / decoder. Options are \"\n            '\"flax\" or \"haiku\".'\n        ),\n    )\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}