{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kS-pdbnl7X3X"
      },
      "outputs": [],
      "source": [
        "# Clone the repository that holds the stopword list and the CSV datasets read by the cells below.\n",
        "!git clone https://github.com/anonymousindividual007/Multi-environment-Topic-Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zlhkhCeBVIIL"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import math\n",
        "import csv\n",
        "import itertools as it\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.distributions as dist\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "from torch.distributions import Normal, Distribution, HalfCauchy, Laplace\n",
        "\n",
        "import nltk\n",
        "nltk.download('punkt')\n",
        "# Recent NLTK releases (3.9+) make word_tokenize look up the 'punkt_tab' resource instead of\n",
        "# the pickled 'punkt' package, so fetch both to keep tokenization working across versions.\n",
        "nltk.download('punkt_tab')\n",
        "from collections import Counter\n",
        "from nltk.stem import WordNetLemmatizer\n",
        "from nltk.tokenize import word_tokenize\n",
        "from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer\n",
        "from scipy.sparse import csr_matrix\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jnm3xYhm1iNm"
      },
      "outputs": [],
      "source": [
        "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A4ebiLxw7i64"
      },
      "outputs": [],
      "source": [
        "# Stopword list shipped with the cloned repository; every line is treated as one stopword.\n",
        "file_path = \"/content/Multi-environment-Topic-Models/political_stopwords.txt\"\n",
        "\n",
        "# An explicit encoding keeps the result independent of the platform's default locale.\n",
        "with open(file_path, 'r', encoding='utf-8') as file:\n",
        "    stopwords_list = file.readlines()\n",
        "\n",
        "# Strip surrounding whitespace and drop blank lines so no empty string ends up as a stopword.\n",
        "all_stopwords = [word.strip() for word in stopwords_list if word.strip()]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GKibgGW97j9V"
      },
      "outputs": [],
      "source": [
        "class LemmaTokenizer:\n",
        "    \"\"\"Callable tokenizer that keeps only purely alphabetic tokens.\n",
        "\n",
        "    A ``WordNetLemmatizer`` is created and stored on ``self.wnl``, but ``__call__``\n",
        "    never applies it: tokens are returned exactly as ``word_tokenize`` produced them.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.wnl = WordNetLemmatizer()\n",
        "\n",
        "    def __call__(self, doc):\n",
        "        \"\"\"Tokenize ``doc`` and return the tokens that consist solely of letters.\"\"\"\n",
        "        alphabetic_tokens = []\n",
        "        for token in word_tokenize(doc):\n",
        "            if token.isalpha():\n",
        "                alphabetic_tokens.append(token)\n",
        "        return alphabetic_tokens"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xfZzeSFSOX6C"
      },
      "source": [
        "To use your own data, replace the file path. The cell below loads the data for the Political Advertisements experiment.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oSTqmTh8OYaE"
      },
      "outputs": [],
      "source": [
        "# Load the full corpus from the cloned repository.\n",
        "file_path = '/content/Multi-environment-Topic-Models/local_channels.csv'\n",
        "train_data = pd.read_csv(file_path)\n",
        "\n",
        "# Hold out a fixed-seed 20% sample of each source as a test set.\n",
        "test1 = train_data.loc[train_data['source'] == 'right'].sample(frac=0.2, random_state=42)\n",
        "test2 = train_data.loc[train_data['source'] == 'left'].sample(frac=0.2, random_state=42)\n",
        "\n",
        "# The held-out rows must not remain in the training frame.\n",
        "train_data = train_data.drop(test1.index).drop(test2.index)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sgZB5LJrNam1"
      },
      "source": [
        "The data in the cell below is for the ideology dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9D9Hq4zh7mRZ"
      },
      "outputs": [],
      "source": [
        "# train_data= pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_train.csv')\n",
        "# channels_ideology_test = pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_test.csv')\n",
        "# test1 = channels_ideology_test[channels_ideology_test['source'] == 'Republican']\n",
        "# test2 = channels_ideology_test[channels_ideology_test['source'] == 'Democratic']\n",
        "# test3 = channels_ideology_test[channels_ideology_test['source'] == 'balanced']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "soetY3cuOJfd"
      },
      "source": [
        "The code below represents the preprocessing for the Style dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2S27XfCnOKei"
      },
      "outputs": [],
      "source": [
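        "# Specify the path to the zip file and the name of the CSV file inside it\n",
        "# (uncommenting this cell also requires `import os` and `import zipfile`, which the import cell does not provide)\n",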
        "# zip_file_path = '/content/Multi-environment-Topic-Models/style_train_large.csv.zip'\n",
        "# csv_file_name = 'style_train_large.csv'  # Change this if the CSV file has a different name inside the zip\n",
        "\n",
        "# # Specify the temporary directory to extract the CSV file\n",
        "# temp_dir = '/content/temp_dir'\n",
        "\n",
        "# # Create a temporary directory if it doesn't exist\n",
        "# if not os.path.exists(temp_dir):\n",
        "#     os.makedirs(temp_dir)\n",
        "\n",
        "# # Extract the CSV file\n",
        "# with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
        "#     zip_ref.extract(csv_file_name, temp_dir)\n",
        "\n",
        "# # Full path to the extracted CSV file\n",
        "# csv_file_path = os.path.join(temp_dir, csv_file_name)\n",
        "\n",
        "# # Load the CSV file into a Pandas DataFrame\n",
        "# train_data = pd.read_csv(csv_file_path, encoding='ISO-8859-1')\n",
        "# style_test_df = pd.read_csv('/content/Multi-environment-Topic-Models/style_test.csv', encoding='ISO-8859-1')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dT2EMyjqORRd"
      },
      "source": [
        "Preprocessing the ideology and channels dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lHQBcKHHOR4O"
      },
      "outputs": [],
      "source": [
        "# Bag-of-words vocabulary: unigrams only, stopwords removed, terms filtered by\n",
        "# document frequency (max_df=0.4, min_df=0.0006).\n",
        "vectorizer = CountVectorizer(\n",
        "    tokenizer=LemmaTokenizer(),\n",
        "    ngram_range=(1, 1),\n",
        "    stop_words=all_stopwords,\n",
        "    max_df=0.4,\n",
        "    min_df=0.0006,\n",
        ")\n",
        "\n",
        "# Learn the vocabulary on the training texts and count term occurrences per document.\n",
        "docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "\n",
        "# Densify the sparse counts and move them to the compute device as a float tensor.\n",
        "docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cwekT4Q5OgfF"
      },
      "source": [
        "The code below represents the preprocessing for the iid Style experiment.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9E_AjFxLOgpL"
      },
      "outputs": [],
      "source": [
        "#style tok iid\n",
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(), ngram_range=(1, 1), stop_words=all_stopwords, max_df=0.5, min_df=0.006)\n",
        "\n",
        "\n",
        "# docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aUXS3ZABOuqX"
      },
      "source": [
        "Uncomment the code below to preprocess the OOD data.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DZW7QtvPOvHc"
      },
      "outputs": [],
      "source": [
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(),\n",
        "#                              ngram_range=(1, 1),\n",
        "#                              stop_words=all_stopwords,\n",
        "#                              max_df=0.5,\n",
        "#                              min_df=0.006)\n",
        "\n",
        "# vectorizer.fit(train_data['text'])\n",
        "\n",
        "# docs_word_matrix_raw = vectorizer.transform(train_data['text'])\n",
        "\n",
        "# env_mapping = {value: index for index, value in enumerate(train_data['source'].unique())}\n",
        "# env_index = train_data['source'].apply(lambda x: env_mapping[x])\n",
        "\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)\n",
        "# env_index_tensor = torch.from_numpy(env_index.to_numpy()).long().to(device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wCaQoRRBVGJm"
      },
      "outputs": [],
      "source": [
        "class LDA(nn.Module):\n",
        "    def __init__(self, num_topics, num_words, hidden_dim, device='cpu'):\n",
        "        super(LDA, self).__init__()\n",
        "\n",
        "        self.num_topics = num_topics\n",
        "        self.num_words = num_words\n",
        "\n",
        "        self.beta = nn.Parameter(torch.randn([num_topics, num_words], device=device))\n",
        "        self.beta_logvar = nn.Parameter(torch.zeros([num_topics, num_words], device=device))\n",
        "\n",
        "        self.beta_prior = Normal(torch.zeros(num_topics, num_words, device=device), torch.ones(num_topics, num_words, device=device))\n",
        "        self.theta_prior = Normal(torch.zeros(num_topics, device=device), torch.ones(num_topics, device=device))\n",
        "\n",
        "        self.theta_network = nn.Sequential(\n",
        "            nn.Linear(num_words, 50),\n",
        "            nn.BatchNorm1d(50),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(50, num_topics * 2)\n",
        "        )\n",
        "\n",
        "\n",
        "    def forward(self, bow):\n",
        "        batch_size, vocab_size = bow.size()\n",
        "\n",
        "        theta_params = self.theta_network(bow)\n",
        "        theta_mu, theta_logvar = theta_params.split(self.num_topics, dim=-1)\n",
        "        theta = Normal(theta_mu, torch.exp(0.5 * theta_logvar).add(1e-8)).rsample()\n",
        "\n",
        "        beta_dist = Normal(self.beta, torch.exp(0.5 * self.beta_logvar).add(1e-8))\n",
        "        beta_sample = beta_dist.rsample()\n",
        "\n",
        "        theta_softmax = F.softmax(theta, dim=-1)\n",
        "        beta_softmax = F.softmax(beta_sample, dim=-1)\n",
        "\n",
        "        z = theta_softmax @ beta_softmax\n",
        "        return z"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Stz1oW0Dt2v7"
      },
      "outputs": [],
      "source": [
        "def calculate_kl_divergences(model, minibatch):\n",
        "    # Compute ELBO\n",
        "    z = model(minibatch)\n",
        "    theta_mu, theta_logvar = model.theta_network(minibatch).split(model.num_topics, dim=-1)\n",
        "    theta = Normal(theta_mu, torch.exp(0.5 * theta_logvar))\n",
        "    beta = Normal(model.beta, torch.exp(0.5 * model.beta_logvar))\n",
        "\n",
        "    theta_kl = torch.distributions.kl.kl_divergence(theta, model.theta_prior).sum()\n",
        "    beta_kl = torch.distributions.kl.kl_divergence(beta, model.beta_prior).sum()\n",
        "\n",
        "    return theta_kl, beta_kl"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7_ZKYRzDvVP7"
      },
      "outputs": [],
      "source": [
        "def bbvi_update(minibatch, lda_model, optimizer, total_samples):\n",
        "    \"\"\"Take one optimizer step on the negative ELBO of ``minibatch``.\n",
        "\n",
        "    The ELBO is the count-weighted log of the model's word probabilities, summed\n",
        "    over the minibatch, minus the theta and beta KL terms. ``total_samples`` is\n",
        "    accepted but not used in the computation.\n",
        "\n",
        "    Returns:\n",
        "        The minibatch ELBO as a Python float.\n",
        "    \"\"\"\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    word_probs = lda_model(minibatch)\n",
        "    theta_kl, beta_kl = calculate_kl_divergences(lda_model, minibatch)\n",
        "\n",
        "    log_likelihood = torch.sum(minibatch * torch.log(word_probs))\n",
        "    elbo = log_likelihood - theta_kl - beta_kl\n",
        "\n",
        "    # Maximising the ELBO is done by minimising its negation.\n",
        "    loss = -elbo\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    return elbo.item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vZaYTDeR3qdD"
      },
      "outputs": [],
      "source": [
        "def train_model(lda_model, docs_word_matrix_tensor, num_epochs=150, minibatch_size=1024, lr=0.01, device='cpu'):\n",
        "    \"\"\"Fit ``lda_model`` with Adam on shuffled minibatches, printing the ELBO each epoch.\n",
        "\n",
        "    Args:\n",
        "        lda_model: Model to train; it is moved to ``device`` and put in training mode.\n",
        "        docs_word_matrix_tensor: Document-by-word count matrix, one row per document.\n",
        "        num_epochs: Number of passes over the data.\n",
        "        minibatch_size: Number of documents per optimizer step.\n",
        "        lr: Adam learning rate.\n",
        "        device: Device used for the model and the data.\n",
        "\n",
        "    The printed average is the sum of the minibatch ELBOs divided by\n",
        "    ``num_docs / minibatch_size``.\n",
        "    \"\"\"\n",
        "    lda = lda_model.to(device)\n",
        "    # A model left in eval mode (e.g. by evaluate_model) would otherwise train with\n",
        "    # frozen BatchNorm running statistics.\n",
        "    lda.train()\n",
        "    optimizer = torch.optim.Adam(lda.parameters(), lr=lr, betas=(0.9, 0.999))\n",
        "\n",
        "    docs_word_matrix_tensor = docs_word_matrix_tensor.to(device)\n",
        "    num_docs = docs_word_matrix_tensor.size()[0]\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        elbo_accumulator = 0.0\n",
        "        permutation = torch.randperm(num_docs)\n",
        "\n",
        "        for start in range(0, num_docs, minibatch_size):\n",
        "            indices = permutation[start:start + minibatch_size]\n",
        "\n",
        "            # BatchNorm1d raises in training mode when given a single row, so skip a\n",
        "            # trailing one-document remainder. The permutation changes every epoch,\n",
        "            # so a different document is left out each time.\n",
        "            if minibatch_size > 1 and start > 0 and indices.numel() == 1:\n",
        "                continue\n",
        "\n",
        "            minibatch = docs_word_matrix_tensor[indices]\n",
        "            elbo_accumulator += bbvi_update(minibatch, lda, optimizer, num_docs)\n",
        "\n",
        "        avg_elbo = elbo_accumulator / (num_docs / minibatch_size)\n",
        "        print(f'Epoch: {epoch+1}, Average ELBO: {avg_elbo}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O881YREcr4Kl"
      },
      "outputs": [],
      "source": [
        "# Seed PyTorch so parameter initialisation, minibatch shuffling and the reparameterised\n",
        "# sampling start from the same state on every run.\n",
        "seed = 42\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "num_topics = 20\n",
        "num_epoch = 150\n",
        "hidden = 50\n",
        "\n",
        "# The vocabulary size comes from the fitted vectorizer.\n",
        "lda_model = LDA(num_topics=num_topics, num_words=len(vectorizer.get_feature_names_out()), hidden_dim = hidden, device=device)\n",
        "\n",
        "train_model(lda_model, docs_word_matrix_tensor, num_epochs=num_epoch, minibatch_size=1024, lr=0.01, device=device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I5hsw_turyKL"
      },
      "outputs": [],
      "source": [
        "# Vectorize the test1 documents with the already-fitted vectorizer (transform only, no refit),\n",
        "# then convert the sparse counts to a dense float tensor on the compute device.\n",
        "test_data_word_matrix_raw = vectorizer.transform(test1['text'])\n",
        "test_data_word_matrix_tensor = torch.from_numpy(test_data_word_matrix_raw.toarray()).float().to(device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ztlJhJIVpcZ"
      },
      "outputs": [],
      "source": [
        "def evaluate_model(env_tm_model, test_data_word_matrix_tensor):\n",
        "    lda_model.to(device)\n",
        "    lda_model.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        theta_test_params = lda_model.theta_network(test_data_word_matrix_tensor)\n",
        "        theta_test_mu, theta_test_logvar = theta_test_params.split(lda_model.num_topics, dim=-1)\n",
        "        theta_test_dist = Normal(theta_test_mu, torch.exp(0.5 * theta_test_logvar).add(1e-8))\n",
        "        theta_test = theta_test_dist.rsample()\n",
        "        theta_test_softmax = F.softmax(theta_test, dim=-1)\n",
        "\n",
        "        beta_test_softmax = F.softmax(lda_model.beta.to(device), dim=-1)\n",
        "        likelihood = torch.mm(theta_test_softmax, beta_test_softmax)\n",
        "\n",
        "        # Compute perplexity of the test data\n",
        "        N = torch.sum(test_data_word_matrix_tensor)\n",
        "        log_perplex = -torch.sum(torch.log(likelihood) * test_data_word_matrix_tensor) / N\n",
        "        perplexity = torch.exp(log_perplex)\n",
        "        return perplexity"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uYKanpD7jKSP"
      },
      "outputs": [],
      "source": [
        "# Perplexity of the trained model on the test matrix built above.\n",
        "perplexity = evaluate_model(lda_model, test_data_word_matrix_tensor)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IwxDsgyi3jbi"
      },
      "outputs": [],
      "source": [
        "# Show the perplexity value returned by evaluate_model.\n",
        "print(perplexity)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DEZNTHta3wie"
      },
      "outputs": [],
      "source": [
        "def print_top_words(lda, vectorizer, num_top_words):\n",
        "    beta = torch.nn.functional.softmax(lda.beta, dim=1)  # Convert to probabilities\n",
        "    for i, topic in enumerate(beta):\n",
        "        top_words = topic.topk(num_top_words).indices\n",
        "        print(f'Topic {i+1}: {[vectorizer.get_feature_names_out()[i] for i in top_words]}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mKrLkg_L6r5y"
      },
      "outputs": [],
      "source": [
        "# List the 12 highest-weighted words of every learned topic.\n",
        "print_top_words(lda_model, vectorizer, num_top_words=12)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5iqhcSBJ0zov"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
