{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "kS-pdbnl7X3X",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "905a68b8-2261-42e6-adbd-4b433f6a0f9b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fatal: destination path 'Multi-environment-Topic-Models' already exists and is not an empty directory.\n"
          ]
        }
      ],
      "source": [
        "!git clone https://github.com/anonymousindividual007/Multi-environment-Topic-Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "zlhkhCeBVIIL",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "01a470f0-0ce7-4f5a-fe74-9e8b688c218c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
            "[nltk_data]   Package punkt is already up-to-date!\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import itertools as it\n",
        "import math\n",
        "import csv\n",
        "import os\n",
        "import zipfile\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.distributions as dist\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "from torch.distributions import Normal, Distribution, HalfCauchy, Laplace\n",
        "\n",
        "import nltk\n",
        "nltk.download('punkt')\n",
        "from collections import Counter\n",
        "from nltk.stem import WordNetLemmatizer\n",
        "from nltk.tokenize import word_tokenize\n",
        "from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer\n",
        "from scipy.sparse import csr_matrix\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "from gensim.corpora import Dictionary\n",
        "from gensim.models.coherencemodel import CoherenceModel\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "_mKIpYH11tph"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Lptxo8yliDv"
      },
      "source": [
        "The political_stopwords.txt is used for preprocessing in all of our experiments.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "A4ebiLxw7i64"
      },
      "outputs": [],
      "source": [
        "file_path = \"/content/Multi-environment-Topic-Models/political_stopwords.txt\"\n",
        "\n",
        "with open(file_path, 'r') as file:\n",
        "    stopwords_list = file.readlines()\n",
        "\n",
        "all_stopwords = [word.strip() for word in stopwords_list]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "GKibgGW97j9V"
      },
      "outputs": [],
      "source": [
        "class LemmaTokenizer:\n",
        "\tdef __init__(self):\n",
        "\t\tself.wnl = WordNetLemmatizer()\n",
        "\tdef __call__(self, doc):\n",
        "\t\treturn [t for t in word_tokenize(doc) if str.isalpha(t)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uroxXKYoltC2"
      },
      "source": [
        "To use your own data replace the file path. Ensure there is a column called 'source' which indicates the environments of your dataset. In the cell below represents the data is for the Political Advertisements experiment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "bNC5QaY_ldZJ"
      },
      "outputs": [],
      "source": [
        "# file_path = '/content/Multi-environment-Topic-Models/local_channels.csv'\n",
        "\n",
        "# train_data = pd.read_csv(file_path)\n",
        "\n",
        "# test1 = train_data[train_data['source'] == 'right'].sample(frac=0.2, random_state=42)\n",
        "# test2 = train_data[train_data['source'] == 'left'].sample(frac=0.2, random_state=42)\n",
        "\n",
        "# # Drop the sampled rows from train_data\n",
        "# train_data = train_data.drop(test1.index)\n",
        "# train_data = train_data.drop(test2.index)\n",
        "\n",
        "# npmi_text=pd.concat([test1, test2])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sAuUYkMXl2Rm"
      },
      "source": [
        "The data in the cell below is for the ideology dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "id": "F3sELKmGl07S"
      },
      "outputs": [],
      "source": [
        "train_data= pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_train.csv')\n",
        "channels_ideology_test = pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_test.csv')\n",
        "test1 = channels_ideology_test[channels_ideology_test['source'] == 'Republican']\n",
        "test2 = channels_ideology_test[channels_ideology_test['source'] == 'Democratic']\n",
        "test3 = channels_ideology_test[channels_ideology_test['source'] == 'balanced']\n",
        "\n",
        "#npmi_test=test3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iGuWi9vomC5O"
      },
      "source": [
        "The code below represents the preprocessing for the Style dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "XZvm-ekOl8ID"
      },
      "outputs": [],
      "source": [
        "# # Specify the path to the zip file and the name of the CSV file inside it\n",
        "# zip_file_path = '/content/Multi-environment-Topic-Models/style_train_large.csv.zip'\n",
        "# csv_file_name = 'style_train_large.csv'  # Change this if the CSV file has a different name inside the zip\n",
        "\n",
        "# # Specify the temporary directory to extract the CSV file\n",
        "# temp_dir = '/content/temp_dir'\n",
        "\n",
        "# # Create a temporary directory if it doesn't exist\n",
        "# if not os.path.exists(temp_dir):\n",
        "#     os.makedirs(temp_dir)\n",
        "\n",
        "# # Extract the CSV file\n",
        "# with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
        "#     zip_ref.extract(csv_file_name, temp_dir)\n",
        "\n",
        "# # Full path to the extracted CSV file\n",
        "# csv_file_path = os.path.join(temp_dir, csv_file_name)\n",
        "\n",
        "# # Load the CSV file into a Pandas DataFrame\n",
        "# train_data = pd.read_csv(csv_file_path, encoding='ISO-8859-1')\n",
        "# style_test_df = pd.read_csv('/content/Multi-environment-Topic-Models/style_test.csv', encoding='ISO-8859-1')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "id": "RC0bsas37pSQ"
      },
      "outputs": [],
      "source": [
        "# # # 2. Map the 'source' values to environments\n",
        "# env_map = {'articles': 'env_0', 'speeches': 'env_1', 'tweets': 'env_2'}\n",
        "# style_test_df['source'] = style_test_df['source'].map(env_map)\n",
        "\n",
        "# # Count number of 'articles' and 'speeches' in train_data\n",
        "# num_articles = len(train_data[train_data['source'] == 'articles'])\n",
        "# num_speeches = len(train_data[train_data['source'] == 'speeches'])\n",
        "# num_tweets = len(train_data[train_data['source'] == 'tweets'])\n",
        "\n",
        "\n",
        "# # Determine the lesser count\n",
        "# min_count = min(num_articles, num_speeches, num_tweets)\n",
        "\n",
        "# # Randomly sample that many from both sources\n",
        "# sampled_articles = train_data[train_data['source'] == 'articles'].sample(min_count, random_state=42)\n",
        "# sampled_speeches = train_data[train_data['source'] == 'speeches'].sample(min_count, random_state=42)\n",
        "# sampled_tweets = train_data[train_data['source'] == 'tweets'].sample(min_count, random_state=42)\n",
        "\n",
        "# # Combine the two sampled dataframes to create a balanced train_data\n",
        "# # train_data = pd.concat([sampled_articles, sampled_speeches, sampled_tweets], ignore_index=True)\n",
        "\n",
        "# #call it combined for the ood test\n",
        "# # combined_data = pd.concat([sampled_articles, sampled_speeches, sampled_tweets], ignore_index=True)\n",
        "\n",
        "# # # no tweets, but test on tweets\n",
        "# train_data = pd.concat([sampled_articles, sampled_speeches], ignore_index=True)\n",
        "\n",
        "# # Now, map the 'source' values to the environments (assuming env_map is already defined)\n",
        "\n",
        "# train_data['source'] = train_data['source'].map(env_map)\n",
        "\n",
        "# test1 = style_test_df[style_test_df['source'] == 'env_0']\n",
        "# test2 = style_test_df[style_test_df['source'] == 'env_1']\n",
        "# test3 = style_test_df[style_test_df['source'] == 'env_2']\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_data['source'].value_counts()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 178
        },
        "id": "wrSJkK4c7CW0",
        "outputId": "605426e4-c607-48de-80a3-deb0753d98a5"
      },
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "source\n",
              "Republican    12941\n",
              "Democratic    12941\n",
              "Name: count, dtype: int64"
            ],
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>count</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>source</th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>Republican</th>\n",
              "      <td>12941</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Democratic</th>\n",
              "      <td>12941</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div><br><label><b>dtype:</b> int64</label>"
            ]
          },
          "metadata": {},
          "execution_count": 46
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Creating environment covariates"
      ],
      "metadata": {
        "id": "AWZ9F3Apw62f"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "env_mapping = {value: index for index, value in enumerate(train_data['source'].unique())}\n",
        "\n",
        "num_docs = len(train_data)\n",
        "num_envs = len(env_mapping)\n",
        "env_index_matrix = np.zeros((num_docs, num_envs), dtype=int)\n",
        "\n",
        "for doc_idx, source in enumerate(train_data['source']):\n",
        "    env_idx = env_mapping[source]\n",
        "    env_index_matrix[doc_idx, env_idx] = 1\n",
        "\n",
        "env_index_tensor = torch.from_numpy(env_index_matrix).float().to(device)"
      ],
      "metadata": {
        "id": "1npfJ3ozw2rU"
      },
      "execution_count": 47,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Preprocessing the ideology and channels dataset"
      ],
      "metadata": {
        "id": "cu-mibzmxHP9"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 51,
      "metadata": {
        "id": "H954flOWmLwU"
      },
      "outputs": [],
      "source": [
        "vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(), ngram_range=(1, 1), stop_words=all_stopwords, max_df=0.4, min_df=0.0006)\n",
        "\n",
        "docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "On comment the code below to preprocess the IID style data"
      ],
      "metadata": {
        "id": "SX4NsF_bwkIh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#style tok iid\n",
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(), ngram_range=(1, 1), stop_words=all_stopwords, max_df=0.5, min_df=0.006)\n",
        "\n",
        "\n",
        "# docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)"
      ],
      "metadata": {
        "id": "nHEeIADfwiqa"
      },
      "execution_count": 52,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DguMJNJ3tSTJ"
      },
      "source": [
        "On comment the code below for preprocessing OOD data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "id": "eIYXpM5gmWmM"
      },
      "outputs": [],
      "source": [
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(),\n",
        "#                              ngram_range=(1, 1),\n",
        "#                              stop_words=all_stopwords,\n",
        "#                              max_df=0.5,\n",
        "#                              min_df=0.008)\n",
        "\n",
        "\n",
        "# docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Uncomment the parameters learning with EB based on the experiment you would like to replicate. Hyperparameters can also be found in appendix B"
      ],
      "metadata": {
        "id": "ahO5vVKhsLBv"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {
        "id": "wCaQoRRBVGJm"
      },
      "outputs": [],
      "source": [
        "class EnvTM(nn.Module):\n",
        "    def __init__(self, num_topics, num_words, num_envs, device='cpu', empirical_bayes=True):\n",
        "        super(EnvTM, self).__init__()\n",
        "\n",
        "        def init_param(shape):\n",
        "            return nn.Parameter(torch.randn(shape, device=device))\n",
        "\n",
        "        def init_param_zeros(shape):\n",
        "            return nn.Parameter(torch.zeros(shape, device=device))\n",
        "\n",
        "        self.num_topics, self.num_words, self.num_envs = num_topics, num_words, num_envs\n",
        "\n",
        "        self.beta = init_param([num_topics, num_words])\n",
        "        self.beta_logvar = init_param_zeros([num_topics, num_words])\n",
        "        self.beta_prior = Normal(torch.zeros([num_topics, num_words], device=device), torch.ones([num_topics, num_words], device=device))\n",
        "\n",
        "        if empirical_bayes:\n",
        "            self.log_alpha_a = nn.Parameter(torch.tensor(1.0, device=device))\n",
        "            self.log_alpha_b = nn.Parameter(torch.tensor(1.0, device=device))\n",
        "        else:\n",
        "            # alpha_a_fixed = torch.tensor(4.0, device=device) #Ideology dataset\n",
        "            # alpha_b_fixed = torch.tensor(0.11, device=device) #Ideology dataset\n",
        "\n",
        "            # alpha_a_fixed = torch.tensor(3.8, device=device) #Channels\n",
        "            # alpha_b_fixed = torch.tensor(0.13, device=device) #Channels\n",
        "\n",
        "            # alpha_a_fixed = torch.tensor(3.7, device=device) #IID Style\n",
        "            # alpha_b_fixed = torch.tensor(0.34, device=device) #IID Style\n",
        "\n",
        "            # alpha_a_fixed = torch.tensor(2.87, device=device) #Style tr:ads, articles\n",
        "            # alpha_b_fixed = torch.tensor(0.25, device=device) #Style tr:ads, articles\n",
        "\n",
        "            alpha_a_fixed = torch.tensor(2.92, device=device) #Style tr: speeches,articles\n",
        "            alpha_b_fixed = torch.tensor(0.25, device=device) #Style tr: speeches,articles\n",
        "\n",
        "            self.log_alpha_a = alpha_a_fixed\n",
        "            self.log_alpha_b = alpha_b_fixed\n",
        "\n",
        "        self.sigma = torch.distributions.Gamma(torch.exp(self.log_alpha_a), torch.exp(self.log_alpha_b)).rsample([num_envs, num_topics, num_words])\n",
        "\n",
        "        # Initialize gamma with variance given by the inverse of sigma\n",
        "        self.gamma = init_param_zeros([num_envs, num_topics, num_words])\n",
        "        self.gamma_logvar = -torch.log(self.sigma).add(1e-8)\n",
        "        self.gamma_prior = Normal(torch.zeros_like(self.gamma), torch.sqrt(1.0/self.sigma).add(1e-8))\n",
        "\n",
        "        # Global Theta, θ_{d} ~ 𝒩(·,·)\n",
        "        self.theta_global_prior = Normal(torch.zeros(num_topics, device=device), torch.ones(num_topics, device=device))\n",
        "\n",
        "        self.theta_global_net = nn.Sequential(\n",
        "            nn.Linear(num_words, 50),\n",
        "            nn.BatchNorm1d(50),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(50, num_topics * 2)\n",
        "        )\n",
        "\n",
        "\n",
        "    def forward(self, bow, x_d):\n",
        "        batch_size, vocab_size = bow.size()\n",
        "\n",
        "        self.theta_global_params = self.theta_global_net(bow)\n",
        "        theta_global_mu, theta_global_logvar = self.theta_global_params.split(self.num_topics, dim=-1)\n",
        "        theta_global_logvar = theta_global_logvar.add(1e-8)\n",
        "        theta_sample = Normal(theta_global_mu, torch.exp(0.5 * theta_global_logvar).add(1e-8)).rsample()\n",
        "        theta_softmax = F.softmax(theta_sample, dim=-1)\n",
        "\n",
        "        beta_dist = Normal(self.beta, torch.exp(0.5 * self.beta_logvar).add(1e-8))\n",
        "        beta_sample = beta_dist.rsample()\n",
        "\n",
        "        gamma_dist = Normal(self.gamma, torch.exp(0.5 * self.gamma_logvar).add(1e-8))\n",
        "        gamma_sample = gamma_dist.rsample()\n",
        "        gamma_effect = torch.einsum('be,etv->btv', x_d, gamma_sample)\n",
        "\n",
        "        adjusted_beta = self.beta.unsqueeze(0) + gamma_effect\n",
        "        adjusted_beta_softmax = F.softmax(adjusted_beta, dim=-1)\n",
        "        eta_d = torch.einsum('bt,btv->bv', theta_softmax, adjusted_beta_softmax)\n",
        "\n",
        "        return eta_d"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {
        "id": "vZaYTDeR3qdD"
      },
      "outputs": [],
      "source": [
        "def calculate_kl_divergences(EnvTM, env, empirical_bayes=True):\n",
        "    theta_global_mu, theta_global_logvar = EnvTM.theta_global_params.split(EnvTM.num_topics, dim=-1)\n",
        "    theta_global_logvar = theta_global_logvar.add(1e-8)\n",
        "    theta_global = Normal(theta_global_mu, torch.exp(0.5 * theta_global_logvar).add(1e-8))\n",
        "    theta_global_kl = torch.distributions.kl.kl_divergence(theta_global, EnvTM.theta_global_prior).sum()\n",
        "\n",
        "    beta = Normal(EnvTM.beta, torch.exp(0.5 * EnvTM.beta_logvar))\n",
        "    beta_kl = torch.distributions.kl.kl_divergence(beta, EnvTM.beta_prior).sum()\n",
        "\n",
        "    if not empirical_bayes:\n",
        "        gamma = Normal(EnvTM.gamma, torch.exp(0.5 * EnvTM.gamma_logvar))\n",
        "        gamma_kl = torch.distributions.kl.kl_divergence(gamma, EnvTM.gamma_prior).sum()\n",
        "    else:\n",
        "        gamma_kl = 0\n",
        "\n",
        "    return theta_global_kl, beta_kl, gamma_kl\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 56,
      "metadata": {
        "id": "32MtSHiD3saH"
      },
      "outputs": [],
      "source": [
        "def bbvi_update(minibatch, env_index, EnvTM, optimizer, n_samples):\n",
        "    optimizer.zero_grad()\n",
        "    elbo_accumulator = torch.zeros(1, device=minibatch.device)\n",
        "\n",
        "    z = EnvTM(minibatch, env_index)\n",
        "\n",
        "    kl_theta, kl_beta, kl_gamma = calculate_kl_divergences(env_tm_model, env_index, empirical_bayes=False)\n",
        "\n",
        "    elbo = (minibatch * z.log()).sum(-1).mul(n_samples).sub(kl_theta + kl_beta + kl_gamma)\n",
        "    elbo_accumulator += elbo.sum()\n",
        "\n",
        "    (-elbo_accumulator).backward(retain_graph=True)\n",
        "    optimizer.step()\n",
        "\n",
        "    return elbo_accumulator.item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 57,
      "metadata": {
        "id": "YIVo7YjjNDyr"
      },
      "outputs": [],
      "source": [
        "def empirical_bayes_update(EnvTM, optimizer_hyper, empirical_bayes=True, num_epochs_hyper=2, kl_threshold=1e-5):\n",
        "    \"\"\"Empirical Bayes update for the hyperparameters of the Gamma distribution.\"\"\"\n",
        "\n",
        "    if not empirical_bayes:\n",
        "        EnvTM.log_alpha_a = torch.log(torch.tensor(3.1, device=EnvTM.log_alpha_a.device) - 1)\n",
        "        EnvTM.log_alpha_b = torch.log(torch.tensor(0.29, device=EnvTM.log_alpha_b.device) - 1)\n",
        "        return\n",
        "\n",
        "    previous_gamma_kl = float('inf')\n",
        "\n",
        "    for _ in range(num_epochs_hyper):\n",
        "        optimizer_hyper.zero_grad()\n",
        "\n",
        "        sigma_sample = torch.distributions.Gamma(torch.nn.functional.softplus(EnvTM.log_alpha_a), torch.nn.functional.softplus(EnvTM.log_alpha_b)).rsample([EnvTM.num_envs, EnvTM.num_topics, EnvTM.num_words])\n",
        "\n",
        "        gamma_prior = Normal(torch.zeros_like(EnvTM.gamma), torch.sqrt(1.0/sigma_sample).add(1e-8))\n",
        "\n",
        "        gamma = Normal(EnvTM.gamma, torch.exp(0.5 * EnvTM.gamma_logvar))\n",
        "        gamma_kl = torch.distributions.kl.kl_divergence(gamma, gamma_prior).sum()\n",
        "\n",
        "        delta_gamma_kl = torch.abs(gamma_kl - previous_gamma_kl).item()\n",
        "\n",
        "        if delta_gamma_kl < kl_threshold:\n",
        "            print(\"Early stopping of hyperparameter updates based on gamma KL divergence stability.\")\n",
        "            break\n",
        "\n",
        "        (-gamma_kl).backward(retain_graph=True)\n",
        "        optimizer_hyper.step()\n",
        "\n",
        "        previous_gamma_kl = gamma_kl.item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 58,
      "metadata": {
        "id": "NUaAWQtzKN_u"
      },
      "outputs": [],
      "source": [
        "def train_model(EnvTM, docs_word_matrix_tensor, env_index_tensor, num_epochs=80, minibatch_size=16, lr=0.01, empirical_bayes=True):\n",
        "    EnvTM = EnvTM.to(device)\n",
        "    optimizer = torch.optim.Adam(EnvTM.parameters(), lr=lr, betas=(0.9, 0.999))\n",
        "    optimizer_hyper = torch.optim.Adam([EnvTM.log_alpha_a, EnvTM.log_alpha_b], lr=lr, betas=(0.9, 0.999))\n",
        "\n",
        "    docs_word_matrix_tensor = docs_word_matrix_tensor.to(device)\n",
        "    env_index_tensor = env_index_tensor.to(device)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        elbo_accumulator = 0.0\n",
        "        permutation = torch.randperm(docs_word_matrix_tensor.size()[0])\n",
        "\n",
        "        for i in range(0, docs_word_matrix_tensor.size()[0], minibatch_size):\n",
        "            indices = permutation[i:i+minibatch_size]\n",
        "            minibatch = docs_word_matrix_tensor[indices]\n",
        "            minibatch_env_index = env_index_tensor[indices]\n",
        "\n",
        "            elbo = bbvi_update(minibatch, minibatch_env_index, EnvTM, optimizer, docs_word_matrix_tensor.size()[0])\n",
        "            elbo_accumulator += elbo\n",
        "        if empirical_bayes:\n",
        "            empirical_bayes_update(EnvTM, optimizer_hyper)\n",
        "\n",
        "        avg_elbo = elbo_accumulator / (docs_word_matrix_tensor.size()[0] / minibatch_size)\n",
        "\n",
        "        print(f'Epoch: {epoch+1}, Average ELBO: {avg_elbo}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 59,
      "metadata": {
        "id": "SbiRUzxOqShF"
      },
      "outputs": [],
      "source": [
        "empirical_bayes = False\n",
        "\n",
        "num_topics = 20\n",
        "num_envs = 2\n",
        "\n",
        "if empirical_bayes:\n",
        "    num_epochs = 15\n",
        "else:\n",
        "    num_epochs = 150"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 60,
      "metadata": {
        "id": "6H2e8NcS3FUs",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bcf08865-dbff-4133-a821-69b322987cb6"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1, Average ELBO: -4733672461.056178\n",
            "Epoch: 2, Average ELBO: -4504445247.678541\n",
            "Epoch: 3, Average ELBO: -4361736366.398888\n",
            "Epoch: 4, Average ELBO: -4257428466.6273084\n",
            "Epoch: 5, Average ELBO: -4176161152.7220464\n",
            "Epoch: 6, Average ELBO: -4112938936.626227\n",
            "Epoch: 7, Average ELBO: -4064822051.9440536\n",
            "Epoch: 8, Average ELBO: -4024826486.0199366\n",
            "Epoch: 9, Average ELBO: -3992870826.54138\n",
            "Epoch: 10, Average ELBO: -3967643649.582567\n",
            "Epoch: 11, Average ELBO: -3947414558.622672\n",
            "Epoch: 12, Average ELBO: -3930140046.886021\n",
            "Epoch: 13, Average ELBO: -3915158808.984777\n",
            "Epoch: 14, Average ELBO: -3902676021.5698943\n",
            "Epoch: 15, Average ELBO: -3890579666.6396723\n",
            "Epoch: 16, Average ELBO: -3880474100.012055\n",
            "Epoch: 17, Average ELBO: -3871449355.7703424\n",
            "Epoch: 18, Average ELBO: -3862639648.2052393\n",
            "Epoch: 19, Average ELBO: -3854298339.9687815\n",
            "Epoch: 20, Average ELBO: -3848262788.8565025\n",
            "Epoch: 21, Average ELBO: -3840645779.772197\n",
            "Epoch: 22, Average ELBO: -3834078774.7963834\n",
            "Epoch: 23, Average ELBO: -3828997225.4780927\n",
            "Epoch: 24, Average ELBO: -3823942217.70806\n",
            "Epoch: 25, Average ELBO: -3819393823.8293796\n",
            "Epoch: 26, Average ELBO: -3815384074.998841\n",
            "Epoch: 27, Average ELBO: -3812825208.4333515\n",
            "Epoch: 28, Average ELBO: -3809239384.999614\n",
            "Epoch: 29, Average ELBO: -3806480560.4562244\n",
            "Epoch: 30, Average ELBO: -3803862888.2318215\n",
            "Epoch: 31, Average ELBO: -3800885912.519898\n",
            "Epoch: 32, Average ELBO: -3798836029.8981533\n",
            "Epoch: 33, Average ELBO: -3796950389.881462\n",
            "Epoch: 34, Average ELBO: -3794878272.825902\n",
            "Epoch: 35, Average ELBO: -3792909536.6453905\n",
            "Epoch: 36, Average ELBO: -3792455879.2451897\n",
            "Epoch: 37, Average ELBO: -3790082861.7559695\n",
            "Epoch: 38, Average ELBO: -3788652021.1198516\n",
            "Epoch: 39, Average ELBO: -3787274017.965845\n",
            "Epoch: 40, Average ELBO: -3786074718.281431\n",
            "Epoch: 41, Average ELBO: -3785014231.248899\n",
            "Epoch: 42, Average ELBO: -3783493667.9242716\n",
            "Epoch: 43, Average ELBO: -3782291079.032532\n",
            "Epoch: 44, Average ELBO: -3781212024.057492\n",
            "Epoch: 45, Average ELBO: -3780266266.804729\n",
            "Epoch: 46, Average ELBO: -3780052004.9529405\n",
            "Epoch: 47, Average ELBO: -3778479452.402133\n",
            "Epoch: 48, Average ELBO: -3777488294.46627\n",
            "Epoch: 49, Average ELBO: -3777016567.9486904\n",
            "Epoch: 50, Average ELBO: -3775885663.013368\n",
            "Epoch: 51, Average ELBO: -3775735402.704582\n",
            "Epoch: 52, Average ELBO: -3774834782.7957654\n",
            "Epoch: 53, Average ELBO: -3774515691.703578\n",
            "Epoch: 54, Average ELBO: -3773062006.3957963\n",
            "Epoch: 55, Average ELBO: -3773027101.297272\n",
            "Epoch: 56, Average ELBO: -3771756239.197589\n",
            "Epoch: 57, Average ELBO: -3770901825.1819797\n",
            "Epoch: 58, Average ELBO: -3770624974.703037\n",
            "Epoch: 59, Average ELBO: -3770689910.593617\n",
            "Epoch: 60, Average ELBO: -3770427905.859516\n",
            "Epoch: 61, Average ELBO: -3769656522.172939\n",
            "Epoch: 62, Average ELBO: -3769328948.5214434\n",
            "Epoch: 63, Average ELBO: -3768623982.008191\n",
            "Epoch: 64, Average ELBO: -3768276437.6267676\n",
            "Epoch: 65, Average ELBO: -3767026290.696546\n",
            "Epoch: 66, Average ELBO: -3766862230.4032145\n",
            "Epoch: 67, Average ELBO: -3766199739.9891815\n",
            "Epoch: 68, Average ELBO: -3766257370.7503285\n",
            "Epoch: 69, Average ELBO: -3765403863.229117\n",
            "Epoch: 70, Average ELBO: -3765067455.055405\n",
            "Epoch: 71, Average ELBO: -3764544850.9067307\n",
            "Epoch: 72, Average ELBO: -3763914437.2274165\n",
            "Epoch: 73, Average ELBO: -3763763029.874044\n",
            "Epoch: 74, Average ELBO: -3762978772.9541764\n",
            "Epoch: 75, Average ELBO: -3762925449.307472\n",
            "Epoch: 76, Average ELBO: -3762339514.6242175\n",
            "Epoch: 77, Average ELBO: -3761871679.95549\n",
            "Epoch: 78, Average ELBO: -3761780063.2507534\n",
            "Epoch: 79, Average ELBO: -3761645226.0072637\n",
            "Epoch: 80, Average ELBO: -3761334131.3097906\n",
            "Epoch: 81, Average ELBO: -3760643220.5239162\n",
            "Epoch: 82, Average ELBO: -3760544450.616181\n",
            "Epoch: 83, Average ELBO: -3760501295.912217\n",
            "Epoch: 84, Average ELBO: -3759347371.589831\n",
            "Epoch: 85, Average ELBO: -3759839874.0474463\n",
            "Epoch: 86, Average ELBO: -3759993018.426397\n",
            "Epoch: 87, Average ELBO: -3759330923.0210958\n",
            "Epoch: 88, Average ELBO: -3759421972.5338073\n",
            "Epoch: 89, Average ELBO: -3758706221.7757516\n",
            "Epoch: 90, Average ELBO: -3758094971.0841513\n",
            "Epoch: 91, Average ELBO: -3758539563.5403757\n",
            "Epoch: 92, Average ELBO: -3757421704.0216365\n",
            "Epoch: 93, Average ELBO: -3757919855.61054\n",
            "Epoch: 94, Average ELBO: -3757669377.0286684\n",
            "Epoch: 95, Average ELBO: -3756926363.0717874\n",
            "Epoch: 96, Average ELBO: -3757533977.657368\n",
            "Epoch: 97, Average ELBO: -3756933526.4032145\n",
            "Epoch: 98, Average ELBO: -3756701724.6444635\n",
            "Epoch: 99, Average ELBO: -3756258590.6820183\n",
            "Epoch: 100, Average ELBO: -3756659919.5536666\n",
            "Epoch: 101, Average ELBO: -3755878156.7594466\n",
            "Epoch: 102, Average ELBO: -3756030457.946681\n",
            "Epoch: 103, Average ELBO: -3755294627.578085\n",
            "Epoch: 104, Average ELBO: -3755739594.7861834\n",
            "Epoch: 105, Average ELBO: -3755297610.4004326\n",
            "Epoch: 106, Average ELBO: -3755326359.9462175\n",
            "Epoch: 107, Average ELBO: -3755682014.667182\n",
            "Epoch: 108, Average ELBO: -3755182115.9242716\n",
            "Epoch: 109, Average ELBO: -3754638614.294413\n",
            "Epoch: 110, Average ELBO: -3753717681.8211884\n",
            "Epoch: 111, Average ELBO: -3754215443.4655747\n",
            "Epoch: 112, Average ELBO: -3754360639.5598483\n",
            "Epoch: 113, Average ELBO: -3754229980.293331\n",
            "Epoch: 114, Average ELBO: -3753881418.004791\n",
            "Epoch: 115, Average ELBO: -3753653346.0400276\n",
            "Epoch: 116, Average ELBO: -3753246038.863148\n",
            "Epoch: 117, Average ELBO: -3753466555.0198593\n",
            "Epoch: 118, Average ELBO: -3753428943.0986786\n",
            "Epoch: 119, Average ELBO: -3752973470.177575\n",
            "Epoch: 120, Average ELBO: -3752896481.456456\n",
            "Epoch: 121, Average ELBO: -3752389913.143034\n",
            "Epoch: 122, Average ELBO: -3753018058.0542464\n",
            "Epoch: 123, Average ELBO: -3752121811.0946603\n",
            "Epoch: 124, Average ELBO: -3751763044.4138784\n",
            "Epoch: 125, Average ELBO: -3751912496.9804497\n",
            "Epoch: 126, Average ELBO: -3751297503.834325\n",
            "Epoch: 127, Average ELBO: -3751479997.86848\n",
            "Epoch: 128, Average ELBO: -3751755807.6513405\n",
            "Epoch: 129, Average ELBO: -3751765196.705046\n",
            "Epoch: 130, Average ELBO: -3751131043.103315\n",
            "Epoch: 131, Average ELBO: -3751113900.7371917\n",
            "Epoch: 132, Average ELBO: -3750280386.734874\n",
            "Epoch: 133, Average ELBO: -3751093907.2182984\n",
            "Epoch: 134, Average ELBO: -3751063068.6840277\n",
            "Epoch: 135, Average ELBO: -3750439514.48327\n",
            "Epoch: 136, Average ELBO: -3750070811.813616\n",
            "Epoch: 137, Average ELBO: -3750672453.1581793\n",
            "Epoch: 138, Average ELBO: -3750121529.9219537\n",
            "Epoch: 139, Average ELBO: -3750301823.554903\n",
            "Epoch: 140, Average ELBO: -3749694892.6382813\n",
            "Epoch: 141, Average ELBO: -3749190651.331427\n",
            "Epoch: 142, Average ELBO: -3749740852.917085\n",
            "Epoch: 143, Average ELBO: -3749694046.914458\n",
            "Epoch: 144, Average ELBO: -3750096707.674523\n",
            "Epoch: 145, Average ELBO: -3749791867.281972\n",
            "Epoch: 146, Average ELBO: -3749203724.60119\n",
            "Epoch: 147, Average ELBO: -3749222857.2036166\n",
            "Epoch: 148, Average ELBO: -3749474211.8945985\n",
            "Epoch: 149, Average ELBO: -3749594469.2645082\n",
            "Epoch: 150, Average ELBO: -3749168021.8888803\n"
          ]
        }
      ],
      "source": [
        "env_tm_model = EnvTM(num_topics=num_topics, num_words=len(vectorizer.get_feature_names_out()), num_envs=num_envs, device=device, empirical_bayes=empirical_bayes)\n",
        "\n",
        "train_model(env_tm_model, docs_word_matrix_tensor, env_index_tensor, num_epochs=num_epochs, minibatch_size=1024, lr=0.01)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 64,
      "metadata": {
        "id": "C9nJfzMfee_Y"
      },
      "outputs": [],
      "source": [
        "def softplus(x):\n",
        "    return math.log(1 + math.exp(x))\n",
        "\n",
        "if not empirical_bayes:\n",
        "    alpha_a = env_tm_model.log_alpha_a.item()\n",
        "    alpha_b = env_tm_model.log_alpha_b.item()\n",
        "\n",
        "else:\n",
        "    alpha_a_softplus = softplus(env_tm_model.log_alpha_a.item())\n",
        "    alpha_b_softplus = softplus(env_tm_model.log_alpha_b.item())\n",
        "    print(f\"After Training (softplus): log_alpha_a = {alpha_a_softplus}, log_alpha_b = {alpha_b_softplus}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 71,
      "metadata": {
        "id": "W6wpUBux3HoP"
      },
      "outputs": [],
      "source": [
        "test_data_word_matrix_raw = vectorizer.transform(test1['text'])\n",
        "test_data_word_matrix_tensor = torch.from_numpy(test_data_word_matrix_raw.toarray()).float().to(device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 72,
      "metadata": {
        "id": "zDFxIWUV3Hqq"
      },
      "outputs": [],
      "source": [
        "def evaluate_model(env_tm_model, test_data_word_matrix_tensor):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    env_tm_model.to(device)\n",
        "    env_tm_model.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        theta_test_params = env_tm_model.theta_global_net(test_data_word_matrix_tensor)\n",
        "        theta_test_mu, theta_test_logvar = theta_test_params.split(env_tm_model.num_topics, dim=-1)\n",
        "        theta_test_dist = Normal(theta_test_mu, torch.exp(0.5 * theta_test_logvar).add(1e-8))\n",
        "        theta_test = theta_test_dist.rsample()\n",
        "        theta_test_softmax = F.softmax(theta_test, dim=-1)\n",
        "        beta_test_softmax = F.softmax(env_tm_model.beta.to(device), dim=-1)\n",
        "\n",
        "        likelihood = torch.mm(theta_test_softmax, beta_test_softmax)\n",
        "        N = torch.sum(test_data_word_matrix_tensor)\n",
        "        log_perplex = -torch.sum(torch.log(likelihood) * test_data_word_matrix_tensor) / N\n",
        "        perplexity = torch.exp(log_perplex)\n",
        "\n",
        "    return perplexity, theta_test_softmax\n",
        "\n",
        "def evaluate_model_with_gamma_per_env(env_tm_model, test_data_word_matrix_tensor):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    env_tm_model.to(device)\n",
        "    env_tm_model.eval()\n",
        "\n",
        "\n",
        "    with torch.no_grad():\n",
        "        theta_test_params = env_tm_model.theta_global_net(test_data_word_matrix_tensor)\n",
        "        theta_test_mu, theta_test_logvar = theta_test_params.split(env_tm_model.num_topics, dim=-1)\n",
        "        theta_test_dist = Normal(theta_test_mu, torch.exp(0.5 * theta_test_logvar).add(1e-8))\n",
        "        theta_test = theta_test_dist.rsample()\n",
        "        theta_test_softmax = F.softmax(theta_test, dim=-1)\n",
        "\n",
        "        gamma_learned = env_tm_model.gamma[0]\n",
        "\n",
        "        beta_gamma_test_softmax = F.softmax(env_tm_model.beta.to(device) + gamma_learned, dim=-1)\n",
        "        log_likelihood = torch.mm(theta_test_softmax, beta_gamma_test_softmax)\n",
        "        N = torch.sum(test_data_word_matrix_tensor)\n",
        "        log_perplex = -torch.sum(torch.log(log_likelihood) * test_data_word_matrix_tensor) / N\n",
        "        perplexity = torch.exp(log_perplex)\n",
        "\n",
        "    return perplexity\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 73,
      "metadata": {
        "id": "oPg2aVJymm1U",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "31961664-a44c-4eae-bc1e-118d7ec38094"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Perplexity for environment 0 effects: 518.00537109375\n",
            "Test Perplexity: 636.7605590820312\n"
          ]
        }
      ],
      "source": [
        "perplexity, theta_test_softmax = evaluate_model(env_tm_model, test_data_word_matrix_tensor)\n",
        "perplexities_by_env = evaluate_model_with_gamma_per_env(env_tm_model, test_data_word_matrix_tensor)\n",
        "\n",
        "print(f'Perplexity for environment {0} effects: {perplexities_by_env}')\n",
        "\n",
        "print(f'Test Perplexity: {perplexity}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6z_Dtm4v3iSf"
      },
      "outputs": [],
      "source": [
        "def print_top_words(env_tm_model, vectorizer, num_top_words):\n",
        "    global_beta = torch.nn.functional.softmax(env_tm_model.beta, dim=1)  # Convert to probabilities\n",
        "    gamma = env_tm_model.gamma\n",
        "\n",
        "    # Print top words for global beta\n",
        "    print(\"Top words for global beta:\")\n",
        "    for i, topic in enumerate(global_beta):\n",
        "        top_words = topic.topk(num_top_words).indices\n",
        "        print(f'Topic {i+1}: {[vectorizer.get_feature_names_out()[i] for i in top_words]}')\n",
        "\n",
        "    # Print top words for gamma\n",
        "    print(\"\\nTop words for gamma:\")\n",
        "    for env_index, env_gamma in enumerate(gamma):\n",
        "        print(f\"Environment {env_index+1}:\")\n",
        "        for i, topic in enumerate(env_gamma):\n",
        "            top_words = topic.topk(num_top_words).indices\n",
        "            print(f'Topic {i+1}: {[vectorizer.get_feature_names_out()[i] for i in top_words]}')\n",
        "        print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JMNSvCWQ3nmx"
      },
      "outputs": [],
      "source": [
        "print_top_words(env_tm_model, vectorizer, num_top_words=11)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q96oqQll3qVy"
      },
      "outputs": [],
      "source": [
        "def normalize(data):\n",
        "    return (data - np.min(data)) / (np.max(data) - np.min(data))\n",
        "\n",
        "def get_top_indices_values(arr, top_n=8):\n",
        "    indices = np.argsort(-arr)[:top_n]\n",
        "    values = arr[indices]\n",
        "    return indices, values\n",
        "\n",
        "def get_words(vocabulary, indices):\n",
        "    return [vocabulary[i] for i in indices]\n",
        "\n",
        "def plot_gamma_beta_heatmaps(gamma_data, beta_data, words, title):\n",
        "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))\n",
        "\n",
        "    # Defining the color scale between 0 and 1\n",
        "    im1 = ax1.imshow(gamma_data.T, cmap='hot', interpolation='nearest', vmin=0, vmax=1)\n",
        "    im2 = ax2.imshow(beta_data.T.reshape(-1, 1), cmap='hot', interpolation='nearest', vmin=0, vmax=1)\n",
        "\n",
        "    num_environments = gamma_data.shape[0]\n",
        "    environments = [f'Environment {i}' for i in range(num_environments)]\n",
        "\n",
        "    # Settings for gamma heatmap\n",
        "    ax1.set_yticks(np.arange(len(words)))\n",
        "    ax1.set_xticks(np.arange(num_environments))\n",
        "    ax1.set_yticklabels(words)\n",
        "    ax1.set_xticklabels(environments)\n",
        "    plt.setp(ax1.get_xticklabels(), rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n",
        "    cbar1 = fig.colorbar(im1, ax=ax1)\n",
        "    cbar1.ax.set_ylabel(\"Normalized Gamma\", rotation=-90, va=\"bottom\")\n",
        "\n",
        "    # Settings for beta grid\n",
        "    ax2.set_yticks(np.arange(len(words)))\n",
        "    ax2.set_xticks([0])\n",
        "    ax2.set_yticklabels(words)\n",
        "    ax2.set_xticklabels(['Beta'])\n",
        "    plt.setp(ax2.get_xticklabels(), rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n",
        "    cbar2 = fig.colorbar(im2, ax=ax2)\n",
        "    cbar2.ax.set_ylabel(\"Normalized Beta\", rotation=-90, va=\"bottom\")\n",
        "\n",
        "    ax1.set_title(title)\n",
        "    fig.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "def analyze_topic(lda, vocabulary, topic_index, top_n=8):\n",
        "    # Normalize the entire beta array for the specific topic\n",
        "    beta_values = normalize(lda.beta[topic_index, :].cpu().detach().numpy())\n",
        "\n",
        "    # Normalize the entire gamma arrays for the specific topic in all environments\n",
        "    num_environments = lda.gamma.shape[0]\n",
        "    gamma_values = [normalize(lda.gamma[i, topic_index, :].cpu().detach().numpy()) for i in range(num_environments)]\n",
        "\n",
        "    # Get the top beta indices and values\n",
        "    beta_indices, _ = get_top_indices_values(beta_values, top_n)\n",
        "\n",
        "    for env_index, gamma_value in enumerate(gamma_values):\n",
        "        # Get the top gamma indices and values\n",
        "        gamma_indices, _ = get_top_indices_values(gamma_value, top_n)\n",
        "\n",
        "        # Get the corresponding words from the vocabulary\n",
        "        gamma_words = get_words(vocabulary, gamma_indices)\n",
        "        beta_words = get_words(vocabulary, beta_indices)\n",
        "\n",
        "        # Print the top words\n",
        "        print(f\"Top words in gamma environment {env_index}:\", gamma_words)\n",
        "        print(\"Top words in beta:               \", beta_words)\n",
        "\n",
        "        # Get the gamma and beta values for top words\n",
        "        gamma_values_top_words = [gamma_values[i][gamma_indices] for i in range(num_environments)]\n",
        "        beta_values_top_words = beta_values[gamma_indices]\n",
        "\n",
        "        # Plot the heatmaps\n",
        "        plot_gamma_beta_heatmaps(np.array(gamma_values_top_words), beta_values_top_words, gamma_words, f\"Environment {env_index}: Top Words\")\n",
        "\n",
        "    # Gamma and Beta values for top words in beta\n",
        "    gamma_values_beta = [gamma_values[i][beta_indices] for i in range(num_environments)]\n",
        "    beta_values_beta = beta_values[beta_indices]\n",
        "    beta_words = get_words(vocabulary, beta_indices)\n",
        "\n",
        "    # Plot the heatmaps for the top words in beta\n",
        "    plot_gamma_beta_heatmaps(np.array(gamma_values_beta), beta_values_beta, beta_words, \"Beta: Top Words\")\n",
        "\n",
        "\n",
        "vocabulary = list(vectorizer.get_feature_names_out())\n",
        "\n",
        "# Analyzing topic 4 with 8 top words\n",
        "analyze_topic(env_tm_model, vocabulary, topic_index=10, top_n=8)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "58IWyzWk3wDS"
      },
      "outputs": [],
      "source": [
        "def analyze_gamma_per_environment(model, threshold=1e-2):\n",
        "    gamma_values = model.gamma.detach().cpu().numpy()\n",
        "\n",
        "    for env_index in range(gamma_values.shape[0]):\n",
        "        print(f\"Environment {env_index}:\")\n",
        "        gamma_env_values = gamma_values[env_index]\n",
        "\n",
        "        close_to_zero = np.abs(gamma_env_values) < threshold\n",
        "        sparsity_percentage = 100 * np.sum(close_to_zero) / gamma_env_values.size\n",
        "\n",
        "        print(f\"Sparsity Percentage: {sparsity_percentage}%\")\n",
        "        print(f\"Mean of Gamma: {np.mean(gamma_env_values)}\")\n",
        "        print(f\"Standard Deviation of Gamma: {np.std(gamma_env_values)}\")\n",
        "        plt.hist(gamma_env_values.flatten(), bins=50)\n",
        "        plt.title(f\"Histogram of Gamma Values for Environment {env_index}\")\n",
        "        plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aLcTjGDt3zSf"
      },
      "outputs": [],
      "source": [
        "analyze_gamma_per_environment(env_tm_model)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PQWueNKeif63"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Step 1: Extract the topic-word distribution from EnvTM model\n",
        "with torch.no_grad():\n",
        "    beta_dist = torch.softmax(env_tm_model.beta, dim=-1)  # Apply softmax to get probabilities\n",
        "    beta_values = beta_dist.cpu().numpy()  # Convert to numpy for easy handling\n",
        "\n",
        "# Step 2: Get the top 10 words from each topic\n",
        "top_n = 10  # Number of words to take from each topic\n",
        "top_words_per_topic = []\n",
        "\n",
        "# Assuming you have a vectorizer from sklearn\n",
        "feature_names = vectorizer.get_feature_names_out()\n",
        "\n",
        "# For each topic, take the top `top_n` words based on probability\n",
        "for topic_idx in range(env_tm_model.num_topics):\n",
        "    top_word_indices = beta_values[topic_idx].argsort()[-top_n:][::-1]  # Get indices of top `n` words\n",
        "    top_words = [feature_names[word_idx] for word_idx in top_word_indices]\n",
        "    top_words_per_topic.append(top_words)\n",
        "\n",
        "# Step 3: Tokenize and filter test corpus using the training vocabulary\n",
        "words_train = set(feature_names)  # Set of words in the training vocabulary\n",
        "\n",
        "# Tokenize the test data\n",
        "corpus_test = [doc.split() for doc in npmi_test['text']]\n",
        "\n",
        "# Filter test documents to keep only words that are in the training vocabulary\n",
        "corpus_test_filtered = [[word for word in doc if word in words_train] for doc in corpus_test]\n",
        "\n",
        "# Step 4: Create a Gensim dictionary from the filtered test corpus\n",
        "dictionary = Dictionary(corpus_test_filtered)\n",
        "\n",
        "# Step 5: Calculate NPMI using the CoherenceModel\n",
        "coherence_model = CoherenceModel(\n",
        "    topics=top_words_per_topic,  # List of top words per topic from the EnvTM model\n",
        "    texts=corpus_test_filtered,  # Tokenized and filtered test documents\n",
        "    dictionary=dictionary,  # Gensim dictionary created from the filtered test corpus\n",
        "    coherence='c_npmi',  # Use NPMI\n",
        "    window_size=10\n",
        ")\n",
        "\n",
        "# Step 6: Get the NPMI score\n",
        "npmi_score = coherence_model.get_coherence()\n",
        "print(f\"NPMI Score for EnvTM model: {npmi_score}\")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}