{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kS-pdbnl7X3X"
      },
      "outputs": [],
      "source": [
        "# Clone the project repository; later cells read its data files from\n",
        "# /content/Multi-environment-Topic-Models (i.e. they assume the working directory is /content).\n",
        "!git clone https://github.com/anonymousindividual007/Multi-environment-Topic-Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zlhkhCeBVIIL"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import itertools as it\n",
        "import math\n",
        "import csv\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.distributions as dist\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "from torch.distributions import Normal, Distribution, HalfCauchy, Laplace\n",
        "\n",
        "import nltk\n",
        "nltk.download('punkt')\n",
        "from collections import Counter\n",
        "from nltk.stem import WordNetLemmatizer\n",
        "from nltk.tokenize import word_tokenize\n",
        "from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer\n",
        "from scipy.sparse import csr_matrix\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_mKIpYH11tph"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")  # checks whether a GPU is available and chooses the GPU if it is"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A4ebiLxw7i64"
      },
      "outputs": [],
      "source": [
        "# Stop-word list shipped with the cloned repository.\n",
        "file_path = \"/content/Multi-environment-Topic-Models/political_stopwords.txt\"\n",
        "\n",
        "# Keep the raw lines (line endings included) ...\n",
        "with open(file_path, 'r') as file:\n",
        "    stopwords_list = list(file)\n",
        "\n",
        "# ... and a cleaned copy with surrounding whitespace stripped from every line.\n",
        "all_stopwords = list(map(str.strip, stopwords_list))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GKibgGW97j9V"
      },
      "outputs": [],
      "source": [
        "class LemmaTokenizer:\n",
        "    \"\"\"Callable tokenizer that keeps only purely alphabetic NLTK tokens.\n",
        "\n",
        "    A WordNetLemmatizer is created and stored on the instance, but __call__\n",
        "    never applies it, so tokens are returned exactly as word_tokenize\n",
        "    produced them.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.wnl = WordNetLemmatizer()\n",
        "\n",
        "    def __call__(self, doc):\n",
        "        \"\"\"Tokenize doc with word_tokenize and drop tokens that fail str.isalpha.\"\"\"\n",
        "        alphabetic_tokens = []\n",
        "        for token in word_tokenize(doc):\n",
        "            if str.isalpha(token):\n",
        "                alphabetic_tokens.append(token)\n",
        "        return alphabetic_tokens"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xfZzeSFSOX6C"
      },
      "source": [
        "To use your own data, replace the file path. The cell below loads the data for the Political Advertisements experiment.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oSTqmTh8OYaE"
      },
      "outputs": [],
      "source": [
        "file_path = '/content/Multi-environment-Topic-Models/local_channels.csv'\n",
        "\n",
        "train_data = pd.read_csv(file_path)\n",
        "\n",
        "# Hold out 20% of the rows of each source as test sets (fixed random_state for a repeatable split).\n",
        "source_column = train_data['source']\n",
        "test1 = train_data.loc[source_column == 'right'].sample(frac=0.2, random_state=42)\n",
        "test2 = train_data.loc[source_column == 'left'].sample(frac=0.2, random_state=42)\n",
        "\n",
        "# Remove every held-out row from the training frame in a single drop.\n",
        "held_out_index = test1.index.append(test2.index)\n",
        "train_data = train_data.drop(held_out_index)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sgZB5LJrNam1"
      },
      "source": [
        "The data in the cell below is for the ideology dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9D9Hq4zh7mRZ"
      },
      "outputs": [],
      "source": [
        "# train_data= pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_train.csv')\n",
        "# channels_ideology_test = pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_test.csv')\n",
        "# test1 = channels_ideology_test[channels_ideology_test['source'] == 'Republican']\n",
        "# test2 = channels_ideology_test[channels_ideology_test['source'] == 'Democratic']\n",
        "# test3 = channels_ideology_test[channels_ideology_test['source'] == 'balanced']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2S27XfCnOKei"
      },
      "outputs": [],
      "source": [
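        "# Style dataset (optional): uncommenting this cell also requires `import os` and `import zipfile`,\n",
        "# which the import cell at the top of the notebook does not provide.\n",
        "# Specify the path to the zip file and the name of the CSV file inside it\n",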
        "# zip_file_path = '/content/Multi-environment-Topic-Models/style_train_large.csv.zip'\n",
        "# csv_file_name = 'style_train_large.csv'  # Change this if the CSV file has a different name inside the zip\n",
        "\n",
        "# # Specify the temporary directory to extract the CSV file\n",
        "# temp_dir = '/content/temp_dir'\n",
        "\n",
        "# # Create a temporary directory if it doesn't exist\n",
        "# if not os.path.exists(temp_dir):\n",
        "#     os.makedirs(temp_dir)\n",
        "\n",
        "# # Extract the CSV file\n",
        "# with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
        "#     zip_ref.extract(csv_file_name, temp_dir)\n",
        "\n",
        "# # Full path to the extracted CSV file\n",
        "# csv_file_path = os.path.join(temp_dir, csv_file_name)\n",
        "\n",
        "# # Load the CSV file into a Pandas DataFrame\n",
        "# train_data = pd.read_csv(csv_file_path, encoding='ISO-8859-1')\n",
        "# style_test_df = pd.read_csv('/content/Multi-environment-Topic-Models/style_test.csv', encoding='ISO-8859-1')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XkPjLBo-HZlZ"
      },
      "outputs": [],
      "source": [
        "# Give every distinct source label an integer id, in order of first appearance.\n",
        "env_mapping = {}\n",
        "for source_label in train_data['source'].unique():\n",
        "    env_mapping[source_label] = len(env_mapping)\n",
        "\n",
        "num_docs = len(train_data)\n",
        "num_envs = len(env_mapping)\n",
        "\n",
        "# One-hot environment indicator: one row per document, one column per environment id.\n",
        "env_index_matrix = np.zeros((num_docs, num_envs), dtype=int)\n",
        "doc_env_ids = [env_mapping[source] for source in train_data['source']]\n",
        "env_index_matrix[np.arange(num_docs), doc_env_ids] = 1\n",
        "\n",
        "env_index_tensor = torch.from_numpy(env_index_matrix).float().to(device)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xkvypt5-SQ3J"
      },
      "source": [
        "Ads and ideology data preprocessing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7l0DupZwSXBm"
      },
      "outputs": [],
      "source": [
        "# Bag-of-words vocabulary: unigrams only, custom tokenizer and stop words, document-frequency cut-offs.\n",
        "vectorizer = CountVectorizer(\n",
        "    tokenizer=LemmaTokenizer(),\n",
        "    ngram_range=(1, 1),\n",
        "    stop_words=all_stopwords,\n",
        "    max_df=0.4,\n",
        "    min_df=0.0006,\n",
        ")\n",
        "\n",
        "# Fit on the training text, then turn the sparse counts into a dense float tensor on `device`.\n",
        "docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "docs_word_matrix_dense = docs_word_matrix_raw.toarray()\n",
        "docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_dense).float().to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QzbjoKQTQ9Ar"
      },
      "source": [
        "Style tokenizer IID"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BpVcT3cTQ-qn"
      },
      "outputs": [],
      "source": [
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(), ngram_range=(1, 1), stop_words=all_stopwords, max_df=0.5, min_df=0.006)\n",
        "# docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KvIWx9XhRDW1"
      },
      "source": [
        "Style tokenizer OOD data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nb0BYGruRCRa"
      },
      "outputs": [],
      "source": [
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(),\n",
        "#                              ngram_range=(1, 1),\n",
        "#                              stop_words=all_stopwords,\n",
        "#                              max_df=0.5,\n",
        "#                              min_df=0.006)\n",
        "\n",
        "# vectorizer.fit(train_data['text'])\n",
        "# docs_word_matrix_raw = vectorizer.transform(train_data['text'])\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wCaQoRRBVGJm"
      },
      "outputs": [],
      "source": [
        "class EnvTM(nn.Module):\n",
        "    def __init__(self, num_topics, num_words, num_envs, device='cpu'):\n",
        "        super(EnvTM, self).__init__()\n",
        "\n",
        "        def init_param(shape):\n",
        "            return nn.Parameter(torch.randn(shape, device=device))\n",
        "\n",
        "        def init_param_zeros(shape):\n",
        "            return nn.Parameter(torch.zeros(shape, device=device))\n",
        "\n",
        "        self.num_topics, self.num_words, self.num_envs = num_topics, num_words, num_envs\n",
        "\n",
        "        # Global Beta, β_{0,k} ~ 𝒩(·,·)\n",
        "        self.beta = init_param([num_topics, num_words])\n",
        "        self.beta_logvar = init_param_zeros([num_topics, num_words])\n",
        "        self.beta_prior = Normal(torch.zeros([num_topics, num_words], device=device), torch.ones([num_topics, num_words], device=device))\n",
        "\n",
        "        # Lambda parameters, λ_{e,k} ~ Half-Cauchy(0,_)\n",
        "        self.lambda_ek = torch.distributions.HalfCauchy(scale=torch.tensor(0.4, device=device)).rsample([num_envs, num_topics])\n",
        "\n",
        "        # Expand lambda_ek to have the same shape across words\n",
        "        self.lambda_ek = self.lambda_ek.unsqueeze(-1).expand(-1, -1, num_words)\n",
        "\n",
        "        # Tau parameter, τ ~ Half-Cauchy(0, _)\n",
        "        self.tau = torch.distributions.HalfCauchy(scale=torch.tensor(0.5, device=device)).rsample()\n",
        "\n",
        "        # Gamma parameters, γ_{e,k} ~ 𝒩(0, λ_{e,k}^2 τ^2) --> hMTM\n",
        "        self.gamma = init_param_zeros([num_envs, num_topics, num_words])\n",
        "        self.gamma_logvar = init_param_zeros([num_envs, num_topics, num_words])\n",
        "        gamma_prior_variance = (self.lambda_ek ** 2) * (self.tau ** 2)\n",
        "        self.gamma_prior = Normal(torch.zeros_like(gamma_prior_variance), torch.sqrt(gamma_prior_variance).add(1e-8))\n",
        "\n",
        "        # Gamma parameters, γ_{e,k} ~ 𝒩(0, 1) --> nEATM\n",
        "        # self.gamma = init_param([num_envs, num_topics, num_words])  # Initialize with normal distribution\n",
        "        # self.gamma_logvar = init_param_zeros([num_envs, num_topics, num_words]) # Initialize log variance\n",
        "        # self.gamma_prior = Normal(torch.zeros([num_envs, num_topics, num_words], device=device), torch.ones([num_envs, num_topics, num_words], device=device))\n",
        "\n",
        "        self.theta_global_prior = Normal(torch.zeros(num_topics, device=device), torch.ones(num_topics, device=device))\n",
        "\n",
        "        self.theta_global_net = nn.Sequential(\n",
        "            nn.Linear(num_words, 50),\n",
        "            nn.BatchNorm1d(50),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(50, num_topics * 2)\n",
        "        )\n",
        "\n",
        "    def forward(self, bow, x_d):\n",
        "        batch_size, vocab_size = bow.size()\n",
        "        self.theta_global_params = self.theta_global_net(bow)\n",
        "        theta_global_mu, theta_global_logvar = self.theta_global_params.split(self.num_topics, dim=-1)\n",
        "        theta_global_logvar = theta_global_logvar.add(1e-8)\n",
        "        theta_sample = Normal(theta_global_mu, torch.exp(0.5 * theta_global_logvar).add(1e-8)).rsample()\n",
        "        theta_softmax = F.softmax(theta_sample, dim=-1)\n",
        "\n",
        "        beta_dist = Normal(self.beta, torch.exp(0.5 * self.beta_logvar).add(1e-8))\n",
        "        beta_sample = beta_dist.rsample()\n",
        "\n",
        "        gamma_dist = Normal(self.gamma, torch.exp(0.5 * self.gamma_logvar).add(1e-8))\n",
        "        gamma_sample = gamma_dist.rsample()\n",
        "        gamma_effect = torch.einsum('be,etv->btv', x_d, gamma_sample)\n",
        "\n",
        "        adjusted_beta = self.beta.unsqueeze(0) + gamma_effect\n",
        "        adjusted_beta_softmax = F.softmax(adjusted_beta, dim=-1)\n",
        "        eta_d = torch.einsum('bt,btv->bv', theta_softmax, adjusted_beta_softmax)\n",
        "\n",
        "        return eta_d\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vZaYTDeR3qdD"
      },
      "outputs": [],
      "source": [
        "def calculate_kl_divergences(EnvTM, x_d):\n",
        "    theta_global_mu, theta_global_logvar = EnvTM.theta_global_params.split(EnvTM.num_topics, dim=-1)\n",
        "    theta_global_logvar = theta_global_logvar.add(1e-8)\n",
        "    theta_global = Normal(theta_global_mu, torch.exp(0.5 * theta_global_logvar).add(1e-8))\n",
        "    theta_global_kl = torch.distributions.kl.kl_divergence(theta_global, EnvTM.theta_global_prior).sum()\n",
        "\n",
        "    beta = Normal(EnvTM.beta, torch.exp(0.5 * EnvTM.beta_logvar))\n",
        "    beta_kl = torch.distributions.kl.kl_divergence(beta, EnvTM.beta_prior).sum()\n",
        "\n",
        "    gamma = Normal(EnvTM.gamma, torch.exp(0.5 * EnvTM.gamma_logvar).add(1e-8))\n",
        "    gamma_kl = torch.distributions.kl.kl_divergence(gamma, EnvTM.gamma_prior).sum()\n",
        "\n",
        "    return theta_global_kl, beta_kl, gamma_kl"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "32MtSHiD3saH"
      },
      "outputs": [],
      "source": [
        "def bbvi_update(minibatch, env_index, EnvTM, optimizer, n_samples):\n",
        "    \"\"\"Take one optimizer step on the negated ELBO of a minibatch.\n",
        "\n",
        "    Args:\n",
        "        minibatch: [batch, num_words] word counts.\n",
        "        env_index: [batch, num_envs] environment indicator rows.\n",
        "        EnvTM: the model instance being trained.\n",
        "        optimizer: optimizer over EnvTM's parameters.\n",
        "        n_samples: factor applied to every document's log-likelihood\n",
        "            (train_model passes the total number of training documents).\n",
        "\n",
        "    Returns:\n",
        "        The minibatch ELBO as a Python float.\n",
        "    \"\"\"\n",
        "    optimizer.zero_grad()\n",
        "    z = EnvTM(minibatch, env_index)\n",
        "\n",
        "    theta_global_kl, beta_kl, gamma_kl = calculate_kl_divergences(EnvTM, env_index)\n",
        "\n",
        "    # Floor the probabilities before the log: if a word probability underflows to 0,\n",
        "    # 0 * log(0) would be NaN and would poison the whole update.\n",
        "    log_z = z.clamp_min(torch.finfo(z.dtype).tiny).log()\n",
        "    log_likelihood = (minibatch * log_z).sum(-1)\n",
        "\n",
        "    # NOTE(review): the total KL is subtracted from every document's term before the sum,\n",
        "    # so it is effectively weighted by the batch size while the likelihood is weighted by\n",
        "    # n_samples. Weighting is kept exactly as before -- confirm this is intended.\n",
        "    elbo = log_likelihood.mul(n_samples).sub(theta_global_kl + beta_kl + gamma_kl).sum()\n",
        "\n",
        "    (-elbo).backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    return elbo.item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NUaAWQtzKN_u"
      },
      "outputs": [],
      "source": [
        "def train_model(EnvTM, docs_word_matrix_tensor, env_index_tensor, num_epochs=80, minibatch_size=1024, lr=0.01):\n",
        "    \"\"\"Train the model with Adam on shuffled minibatches and print the mean ELBO per epoch.\n",
        "\n",
        "    Uses the notebook-level `device`; the model and both tensors are moved there.\n",
        "\n",
        "    Args:\n",
        "        EnvTM: model instance to train (updated in place).\n",
        "        docs_word_matrix_tensor: [num_docs, num_words] word counts.\n",
        "        env_index_tensor: [num_docs, num_envs] environment indicator rows.\n",
        "        num_epochs: number of passes over the data.\n",
        "        minibatch_size: documents per update.\n",
        "        lr: Adam learning rate.\n",
        "    \"\"\"\n",
        "    EnvTM = EnvTM.to(device)\n",
        "    # Make sure BatchNorm uses batch statistics even if an evaluation cell\n",
        "    # switched the model to eval mode before this (re-)training run.\n",
        "    EnvTM.train()\n",
        "    optimizer = torch.optim.Adam(EnvTM.parameters(), lr=lr, betas=(0.9, 0.999))\n",
        "\n",
        "    docs_word_matrix_tensor = docs_word_matrix_tensor.to(device)\n",
        "    env_index_tensor = env_index_tensor.to(device)\n",
        "    num_docs = docs_word_matrix_tensor.size()[0]\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        elbo_accumulator = 0.0\n",
        "        num_batches = 0\n",
        "        permutation = torch.randperm(num_docs)\n",
        "\n",
        "        for i in range(0, num_docs, minibatch_size):\n",
        "            indices = permutation[i:i+minibatch_size]\n",
        "            if indices.numel() < 2:\n",
        "                # The encoder's BatchNorm1d cannot compute batch statistics from a\n",
        "                # single row in training mode; the document is reshuffled next epoch.\n",
        "                continue\n",
        "            minibatch = docs_word_matrix_tensor[indices]\n",
        "            minibatch_env_index = env_index_tensor[indices]\n",
        "\n",
        "            elbo = bbvi_update(minibatch, minibatch_env_index, EnvTM, optimizer, num_docs)\n",
        "            elbo_accumulator += elbo\n",
        "            num_batches += 1\n",
        "\n",
        "        # Average over the updates actually performed (not a fractional batch count).\n",
        "        avg_elbo = elbo_accumulator / max(num_batches, 1)\n",
        "        print(f'Epoch: {epoch+1}, Average ELBO: {avg_elbo}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6H2e8NcS3FUs"
      },
      "outputs": [],
      "source": [
        "num_topics = 20\n",
        "num_epoch = 150\n",
        "\n",
        "# Take the model dimensions from the training tensors instead of hard-coding them,\n",
        "# so the cell keeps working when the dataset has a different number of environments.\n",
        "num_envs = env_index_tensor.shape[1]\n",
        "num_words = docs_word_matrix_tensor.shape[1]\n",
        "\n",
        "# Seed torch so the Half-Cauchy draws, parameter initialisation and minibatch order repeat across runs.\n",
        "torch.manual_seed(42)\n",
        "\n",
        "env_tm_model = EnvTM(num_topics=num_topics, num_words=num_words, num_envs=num_envs, device=device)\n",
        "\n",
        "train_model(env_tm_model, docs_word_matrix_tensor, env_index_tensor, num_epochs=num_epoch, minibatch_size=1024, lr=0.01)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W6wpUBux3HoP"
      },
      "outputs": [],
      "source": [
        "# Encode the held-out test1 documents with the vocabulary fitted on the training data.\n",
        "test_data_word_matrix_raw = vectorizer.transform(test1['text'])\n",
        "test_data_word_matrix_dense = test_data_word_matrix_raw.toarray()\n",
        "test_data_word_matrix_tensor = torch.from_numpy(test_data_word_matrix_dense).float().to(device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zDFxIWUV3Hqq"
      },
      "outputs": [],
      "source": [
        "def evaluate_model(env_tm_model, test_data_word_matrix_tensor):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    env_tm_model.to(device)\n",
        "    env_tm_model.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        theta_test_params = env_tm_model.theta_global_net(test_data_word_matrix_tensor)\n",
        "        theta_test_mu, theta_test_logvar = theta_test_params.split(env_tm_model.num_topics, dim=-1)\n",
        "        theta_test_dist = Normal(theta_test_mu, torch.exp(0.5 * theta_test_logvar).add(1e-8))\n",
        "        theta_test = theta_test_dist.rsample()\n",
        "        theta_test_softmax = F.softmax(theta_test, dim=-1)\n",
        "        beta_test_softmax = F.softmax(env_tm_model.beta.to(device), dim=-1)\n",
        "\n",
        "        likelihood = torch.mm(theta_test_softmax, beta_test_softmax)\n",
        "        N = torch.sum(test_data_word_matrix_tensor)\n",
        "        log_perplex = -torch.sum(torch.log(likelihood) * test_data_word_matrix_tensor) / N\n",
        "        perplexity = torch.exp(log_perplex)\n",
        "\n",
        "    return perplexity, theta_test_softmax\n",
        "\n",
        "def evaluate_model_with_gamma_per_env(env_tm_model, test_data_word_matrix_tensor, env_index=0):\n",
        "    \"\"\"Compute test perplexity using beta plus the learned gamma of one environment.\n",
        "\n",
        "    Topic proportions are sampled from the encoder's Normal, so the value varies\n",
        "    between calls unless the torch RNG is seeded. The model is moved to CUDA when\n",
        "    available (CPU otherwise) and left in eval mode.\n",
        "\n",
        "    Args:\n",
        "        env_tm_model: trained model exposing theta_global_net, num_topics, beta and gamma.\n",
        "        test_data_word_matrix_tensor: [num_docs, num_words] word-count tensor.\n",
        "        env_index: index into the first axis of env_tm_model.gamma selecting the\n",
        "            environment whose offsets are added to beta. Defaults to 0, which is\n",
        "            the environment this function always used before.\n",
        "\n",
        "    Returns:\n",
        "        Scalar perplexity tensor.\n",
        "    \"\"\"\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    env_tm_model.to(device)\n",
        "    env_tm_model.eval()\n",
        "    # Keep the input on the same device the model was just moved to.\n",
        "    test_data_word_matrix_tensor = test_data_word_matrix_tensor.to(device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        theta_test_params = env_tm_model.theta_global_net(test_data_word_matrix_tensor)\n",
        "        theta_test_mu, theta_test_logvar = theta_test_params.split(env_tm_model.num_topics, dim=-1)\n",
        "        theta_test_dist = Normal(theta_test_mu, torch.exp(0.5 * theta_test_logvar).add(1e-8))\n",
        "        theta_test = theta_test_dist.rsample()\n",
        "        theta_test_softmax = F.softmax(theta_test, dim=-1)\n",
        "\n",
        "        gamma_learned = env_tm_model.gamma[env_index]\n",
        "\n",
        "        beta_gamma_test_softmax = F.softmax(env_tm_model.beta.to(device) + gamma_learned, dim=-1)\n",
        "        likelihood = torch.mm(theta_test_softmax, beta_gamma_test_softmax)\n",
        "        # Floor the probabilities before the log so 0 * log(0) cannot produce NaN.\n",
        "        log_likelihood = likelihood.clamp_min(torch.finfo(likelihood.dtype).tiny).log()\n",
        "        N = torch.sum(test_data_word_matrix_tensor)\n",
        "        log_perplex = -torch.sum(log_likelihood * test_data_word_matrix_tensor) / N\n",
        "        perplexity = torch.exp(log_perplex)\n",
        "\n",
        "    return perplexity\n",
        "\n",
        "perplexity, theta_test_softmax = evaluate_model(env_tm_model, test_data_word_matrix_tensor)\n",
        "\n",
        "# test1 holds documents of a single source; look up that source's environment id instead of\n",
        "# assuming it is 0 (ids in env_mapping follow the order in which labels appear in train_data).\n",
        "test_env_label = test1['source'].iloc[0]\n",
        "test_env_index = env_mapping[test_env_label]\n",
        "perplexities_by_env = evaluate_model_with_gamma_per_env(env_tm_model, test_data_word_matrix_tensor, env_index=test_env_index)\n",
        "\n",
        "print(f'Perplexity for environment {test_env_index}: {perplexities_by_env}')\n",
        "\n",
        "print(f'Test Perplexity: {perplexity}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6z_Dtm4v3iSf"
      },
      "outputs": [],
      "source": [
        "def print_top_words(env_tm_model, vectorizer, num_top_words):\n",
        "    global_beta = torch.nn.functional.softmax(env_tm_model.beta, dim=1)  # Convert to probabilities\n",
        "    gamma = env_tm_model.gamma\n",
        "\n",
        "    # Print top words for global beta\n",
        "    print(\"Top words for global beta:\")\n",
        "    for i, topic in enumerate(global_beta):\n",
        "        top_words = topic.topk(num_top_words).indices\n",
        "        print(f'Topic {i+1}: {[vectorizer.get_feature_names_out()[i] for i in top_words]}')\n",
        "\n",
        "    # Print top words for gamma\n",
        "    print(\"\\nTop words for gamma:\")\n",
        "    for env_index, env_gamma in enumerate(gamma):\n",
        "        print(f\"Environment {env_index+1}:\")\n",
        "        for i, topic in enumerate(env_gamma):\n",
        "            top_words = topic.topk(num_top_words).indices\n",
        "            print(f'Topic {i+1}: {[vectorizer.get_feature_names_out()[i] for i in top_words]}')\n",
        "        print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JMNSvCWQ3nmx"
      },
      "outputs": [],
      "source": [
        "# List the 12 top-ranked words per topic for the global beta and for each environment's gamma.\n",
        "print_top_words(env_tm_model, vectorizer, num_top_words=12)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R839DagzHRO1"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
