{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/anonymousindividual007/Multi-environment-Topic-Models"
      ],
      "metadata": {
        "id": "NaJ2quyikcTy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install torch pyro-ppl"
      ],
      "metadata": {
        "id": "cLb7g81cgoiR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PCJkael8gJ9I"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "from sklearn.datasets import fetch_20newsgroups\n",
        "from sklearn.feature_extraction.text import CountVectorizer\n",
        "import nltk\n",
        "# word_tokenize relies on the Punkt sentence models. NLTK 3.9+ looks them up as\n",
        "# 'punkt_tab' while older releases use 'punkt', so fetch both resources.\n",
        "nltk.download('punkt')\n",
        "nltk.download('punkt_tab')\n",
        "from collections import Counter\n",
        "from nltk.stem import WordNetLemmatizer\n",
        "from nltk.tokenize import word_tokenize\n",
        "from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer\n",
        "from scipy.sparse import csr_matrix\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import os\n",
        "import pyro\n",
        "import pyro.distributions as dist\n",
        "from pyro.infer import MCMC, NUTS\n",
        "import torch\n",
        "import math\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from pyro.infer import SVI, TraceMeanField_ELBO\n",
        "from tqdm import trange"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n"
      ],
      "metadata": {
        "id": "ZPdVqPt2lutS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "smoke_test = 'CI' in os.environ\n"
      ],
      "metadata": {
        "id": "2KIE8189g9_Z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): this path sits under /content/Env_Adjusted_TM, whereas the clone cell\n",
        "# fetches Multi-environment-Topic-Models (the directory the CSV path uses) - confirm\n",
        "# that the stopword file really lives here.\n",
        "file_path = \"/content/Env_Adjusted_TM/data/political_stopwords.txt\"\n",
        "\n",
        "# Read every line of the stopword file; entries still carry their trailing newline.\n",
        "with open(file_path, 'r') as file:\n",
        "    stopwords_list = file.readlines()\n"
      ],
      "metadata": {
        "id": "vn2u_IJjkhvz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "all_stopwords = [word.strip() for word in stopwords_list]"
      ],
      "metadata": {
        "id": "sUm5R5EjkkUq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "len(all_stopwords)"
      ],
      "metadata": {
        "id": "yFWm6m9XkmyL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class LemmaTokenizer:\n",
        "    \"\"\"Callable tokenizer handed to CountVectorizer.\n",
        "\n",
        "    Note: despite the name, no lemmatization is applied; the WordNetLemmatizer\n",
        "    is instantiated but never used by __call__.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.wnl = WordNetLemmatizer()\n",
        "\n",
        "    def __call__(self, doc):\n",
        "        \"\"\"Split doc with word_tokenize and keep only purely alphabetic tokens.\"\"\"\n",
        "        alphabetic_tokens = []\n",
        "        for token in word_tokenize(doc):\n",
        "            if token.isalpha():\n",
        "                alphabetic_tokens.append(token)\n",
        "        return alphabetic_tokens"
      ],
      "metadata": {
        "id": "rOEqoaQ5kovJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "file_path = '/content/Multi-environment-Topic-Models/local_channels.csv'\n",
        "train_data = pd.read_csv(file_path)\n",
        "\n",
        "# Hold out 20% of the rows of each of the 'right' (test1) and 'left' (test2)\n",
        "# sources, sampled with a fixed seed from the full table.\n",
        "is_right = train_data['source'] == 'right'\n",
        "is_left = train_data['source'] == 'left'\n",
        "test1 = train_data[is_right].sample(frac=0.2, random_state=42)\n",
        "test2 = train_data[is_left].sample(frac=0.2, random_state=42)\n",
        "\n",
        "# Whatever was not sampled stays in the training set.\n",
        "train_data = train_data.drop(test1.index).drop(test2.index)"
      ],
      "metadata": {
        "id": "1WKtSuB8poei"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# train_data= pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_train.csv')\n",
        "# channels_ideology_test = pd.read_csv('/content/Multi-environment-Topic-Models/channels_ideology_test.csv')\n",
        "# test1 = channels_ideology_test[channels_ideology_test['source'] == 'Republican']\n",
        "# test2 = channels_ideology_test[channels_ideology_test['source'] == 'Democratic']\n",
        "# test3 = channels_ideology_test[channels_ideology_test['source'] == 'balanced']"
      ],
      "metadata": {
        "id": "iNMxRsDOpqIM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# Specify the path to the zip file and the name of the CSV file inside it\n",
        "# zip_file_path = '/content/Multi-environment-Topic-Models/style_train_large.csv.zip'\n",
        "# csv_file_name = 'style_train_large.csv'  # Change this if the CSV file has a different name inside the zip\n",
        "\n",
        "# # Specify the temporary directory to extract the CSV file\n",
        "# temp_dir = '/content/temp_dir'\n",
        "\n",
        "# # Create a temporary directory if it doesn't exist\n",
        "# if not os.path.exists(temp_dir):\n",
        "#     os.makedirs(temp_dir)\n",
        "\n",
        "# # Extract the CSV file\n",
        "# with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
        "#     zip_ref.extract(csv_file_name, temp_dir)\n",
        "\n",
        "# # Full path to the extracted CSV file\n",
        "# csv_file_path = os.path.join(temp_dir, csv_file_name)\n",
        "\n",
        "# # Load the CSV file into a Pandas DataFrame\n",
        "# train_data = pd.read_csv(csv_file_path, encoding='ISO-8859-1')\n",
        "# style_test_df = pd.read_csv('/content/Multi-environment-Topic-Models/style_test.csv', encoding='ISO-8859-1')"
      ],
      "metadata": {
        "id": "cLjv_sB1prwC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "\n",
        "vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(), ngram_range=(1, 1), stop_words=all_stopwords, max_df=0.4, min_df=0.0006)\n",
        "\n",
        "docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)"
      ],
      "metadata": {
        "id": "L_14FlGaptdO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "#style tok iid\n",
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(), ngram_range=(1, 1), stop_words=all_stopwords, max_df=0.5, min_df=0.006)\n",
        "\n",
        "\n",
        "# docs_word_matrix_raw = vectorizer.fit_transform(train_data['text'])\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)"
      ],
      "metadata": {
        "id": "hYnU0xKspu38"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# vectorizer = CountVectorizer(tokenizer=LemmaTokenizer(),\n",
        "#                              ngram_range=(1, 1),\n",
        "#                              stop_words=all_stopwords,\n",
        "#                              max_df=0.5,\n",
        "#                              min_df=0.006)\n",
        "\n",
        "# vectorizer.fit(train_data['text'])\n",
        "\n",
        "# docs_word_matrix_raw = vectorizer.transform(train_data['text'])\n",
        "\n",
        "# env_mapping = {value: index for index, value in enumerate(train_data['source'].unique())}\n",
        "# env_index = train_data['source'].apply(lambda x: env_mapping[x])\n",
        "\n",
        "# docs_word_matrix_tensor = torch.from_numpy(docs_word_matrix_raw.toarray()).float().to(device)\n",
        "# env_index_tensor = torch.from_numpy(env_index.to_numpy()).long().to(device)\n"
      ],
      "metadata": {
        "id": "XirAsvwFpwXK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "vocab = pd.DataFrame(columns=['word', 'index'])\n",
        "vocab['word'] = vectorizer.get_feature_names_out()\n",
        "vocab['index'] = vocab.index"
      ],
      "metadata": {
        "id": "SzQjDOPGgPPf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print('Dictionary size: %d' % len(vocab))\n",
        "# `docs` is not defined at this point of the notebook; report the shape of the\n",
        "# document-term tensor built from the vectorizer output instead.\n",
        "print('Corpus size: {}'.format(docs_word_matrix_tensor.shape))"
      ],
      "metadata": {
        "id": "3ZIyL7w8gQy8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class Encoder(nn.Module):\n",
        "    \"\"\"Inference network used by the guide.\n",
        "\n",
        "    Maps a batch of rows with `vocab_size` features to the location and scale\n",
        "    of a Normal over the `num_topics` unnormalised log topic proportions.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, num_topics, hidden, dropout):\n",
        "        super().__init__()\n",
        "        # Dropout and the batch norms below are there to avoid component collapse.\n",
        "        self.drop = nn.Dropout(dropout)\n",
        "        self.fc1 = nn.Linear(vocab_size, hidden)\n",
        "        self.fc2 = nn.Linear(hidden, hidden)\n",
        "        self.fcmu = nn.Linear(hidden, num_topics)\n",
        "        self.fclv = nn.Linear(hidden, num_topics)\n",
        "        # `affine=False` drops the learnable scale/shift of BatchNorm1d and so\n",
        "        # reduces the number of parameters to learn, see\n",
        "        # https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html\n",
        "        self.bnmu = nn.BatchNorm1d(num_topics, affine=False)\n",
        "        self.bnlv = nn.BatchNorm1d(num_topics, affine=False)\n",
        "\n",
        "    def forward(self, inputs):\n",
        "        \"\"\"Return `(loc, scale)` of the Normal over log topic proportions.\"\"\"\n",
        "        first_hidden = F.softplus(self.fc1(inputs))\n",
        "        second_hidden = F.softplus(self.fc2(first_hidden))\n",
        "        features = self.drop(second_hidden)\n",
        "        loc = self.bnmu(self.fcmu(features))\n",
        "        log_var = self.bnlv(self.fclv(features))\n",
        "        # exp(0.5 * log-variance) is the standard deviation, positive by construction.\n",
        "        scale = (0.5 * log_var).exp()\n",
        "        return loc, scale\n",
        "\n",
        "\n",
        "class Decoder(nn.Module):\n",
        "    \"\"\"Generative network used by the model.\n",
        "\n",
        "    Turns topic proportions into a distribution over the vocabulary; the\n",
        "    weights of the bias-free linear layer `beta` play the role of the\n",
        "    topic-word matrix.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, num_topics, dropout):\n",
        "        super().__init__()\n",
        "        self.beta = nn.Linear(num_topics, vocab_size, bias=False)\n",
        "        self.bn = nn.BatchNorm1d(vocab_size, affine=False)\n",
        "        self.drop = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, inputs):\n",
        "        \"\"\"Return softmax(batch_norm(beta(dropout(inputs)))) taken over dim 1.\"\"\"\n",
        "        word_logits = self.bn(self.beta(self.drop(inputs)))\n",
        "        return F.softmax(word_logits, dim=1)\n",
        "\n",
        "\n",
        "class ProdLDA(nn.Module):\n",
        "    \"\"\"ProdLDA topic model: a Pyro model/guide pair built on Encoder and Decoder.\"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, num_topics, hidden, dropout):\n",
        "        super().__init__()\n",
        "        self.vocab_size = vocab_size\n",
        "        self.num_topics = num_topics\n",
        "        self.encoder = Encoder(vocab_size, num_topics, hidden, dropout)\n",
        "        self.decoder = Decoder(vocab_size, num_topics, dropout)\n",
        "\n",
        "    def model(self, docs):\n",
        "        \"\"\"Generative side: standard Normal prior on `logtheta`, Multinomial likelihood on `docs`.\"\"\"\n",
        "        pyro.module(\"decoder\", self.decoder)\n",
        "        num_docs = docs.shape[0]\n",
        "        with pyro.plate(\"documents\", num_docs):\n",
        "            # The Dirichlet prior over topic proportions is replaced by a\n",
        "            # logistic-normal one: zero location, unit scale per topic.\n",
        "            prior_shape = (num_docs, self.num_topics)\n",
        "            prior_loc = docs.new_zeros(prior_shape)\n",
        "            prior_scale = docs.new_ones(prior_shape)\n",
        "            logtheta = pyro.sample(\n",
        "                \"logtheta\", dist.Normal(prior_loc, prior_scale).to_event(1))\n",
        "\n",
        "            # Word probabilities: decoder applied to the softmax-ed topic proportions.\n",
        "            word_probs = self.decoder(F.softmax(logtheta, -1))\n",
        "\n",
        "            # PyTorch's Multinomial needs one homogeneous `total_count`, yet document\n",
        "            # lengths vary, so the largest document length is used. The result is\n",
        "            # unaffected because Multinomial.log_prob does not need `total_count`\n",
        "            # to evaluate the log probability.\n",
        "            longest_doc = int(docs.sum(-1).max())\n",
        "            pyro.sample('obs', dist.Multinomial(longest_doc, word_probs), obs=docs)\n",
        "\n",
        "    def guide(self, docs):\n",
        "        \"\"\"Variational side: Normal over `logtheta` parameterised by the encoder.\"\"\"\n",
        "        pyro.module(\"encoder\", self.encoder)\n",
        "        with pyro.plate(\"documents\", docs.shape[0]):\n",
        "            posterior_loc, posterior_scale = self.encoder(docs)\n",
        "            pyro.sample(\n",
        "                \"logtheta\", dist.Normal(posterior_loc, posterior_scale).to_event(1))\n",
        "\n",
        "    def beta(self):\n",
        "        \"\"\"Return the transposed, detached CPU copy of the decoder's linear-layer weights.\"\"\"\n",
        "        decoder_weights = self.decoder.beta.weight\n",
        "        return decoder_weights.cpu().detach().T"
      ],
      "metadata": {
        "id": "kUzdRpalgULL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# setting global variables\n",
        "seed = 0\n",
        "torch.manual_seed(seed)\n",
        "pyro.set_rng_seed(seed)\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "num_topics = 20 if not smoke_test else 3\n",
        "# `docs` is never assigned earlier in the notebook; the training corpus is the\n",
        "# document-term tensor produced from the vectorizer output.\n",
        "docs = docs_word_matrix_tensor.float().to(device)\n",
        "batch_size = 32\n",
        "learning_rate = 1e-3\n",
        "num_epochs = 50 if not smoke_test else 1"
      ],
      "metadata": {
        "id": "fasf0aClgW-R"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# training\n",
        "pyro.clear_param_store()\n",
        "\n",
        "prodLDA = ProdLDA(\n",
        "    vocab_size=docs.shape[1],\n",
        "    num_topics=num_topics,\n",
        "    hidden=100 if not smoke_test else 10,\n",
        "    dropout=0.2\n",
        ")\n",
        "prodLDA.to(device)\n",
        "\n",
        "optimizer = pyro.optim.Adam({\"lr\": learning_rate})\n",
        "svi = SVI(prodLDA.model, prodLDA.guide, optimizer, loss=TraceMeanField_ELBO())\n",
        "# In smoke-test mode only the first mini-batch of each epoch is used.\n",
        "num_batches = int(math.ceil(docs.shape[0] / batch_size)) if not smoke_test else 1\n",
        "\n",
        "bar = trange(num_epochs)\n",
        "for epoch in bar:\n",
        "    # Sum over mini-batches of the per-document loss of each batch.\n",
        "    running_loss = 0.0\n",
        "    for i in range(num_batches):\n",
        "        batch_docs = docs[i * batch_size:(i + 1) * batch_size, :]\n",
        "        # BatchNorm1d in training mode rejects a batch holding a single document,\n",
        "        # which happens for the last batch when the corpus size is 1 modulo batch_size.\n",
        "        if batch_docs.size(0) < 2:\n",
        "            continue\n",
        "        loss = svi.step(batch_docs)\n",
        "        running_loss += loss / batch_docs.size(0)\n",
        "\n",
        "    bar.set_postfix(epoch_loss='{:.2e}'.format(running_loss))"
      ],
      "metadata": {
        "id": "-rGa_2_fgvKt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# `test4` is not defined anywhere in this notebook; evaluate on the held-out\n",
        "# 'right' (test1) and 'left' (test2) samples drawn in the split cell instead.\n",
        "test_data = pd.concat([test1, test2])\n",
        "test_data_word_matrix_raw = vectorizer.transform(test_data['text'])\n",
        "test_data_word_matrix_tensor = torch.from_numpy(test_data_word_matrix_raw.toarray()).float().to(device)\n"
      ],
      "metadata": {
        "id": "kY5Ra_n0rrfZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "test_data_word_matrix_tensor.shape"
      ],
      "metadata": {
        "id": "9ixGZMO1wCEu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.distributions import Normal\n",
        "\n",
        "def calculate_perplexity(model, docs):\n",
        "    \"\"\"\n",
        "    Calculate the perplexity of the ProdLDA model on the given documents.\n",
        "\n",
        "    One sample of the log topic proportions is drawn per document from the\n",
        "    Normal parameterised by the encoder, so the value depends on the torch\n",
        "    RNG state. The model is switched to evaluation mode for the computation\n",
        "    and put back into its previous mode afterwards.\n",
        "\n",
        "    Args:\n",
        "        model (ProdLDA): The trained ProdLDA model.\n",
        "        docs (torch.Tensor): Tensor of documents (word counts) with shape [num_docs, vocab_size].\n",
        "\n",
        "    Returns:\n",
        "        float: The perplexity of the model on the documents.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If the documents contain no words at all.\n",
        "    \"\"\"\n",
        "    # Check the word count first so that an empty corpus fails before any forward pass.\n",
        "    total_word_count = torch.sum(docs)\n",
        "    if total_word_count == 0:\n",
        "        raise ValueError(\"The total word count in the documents is zero. Cannot compute perplexity.\")\n",
        "\n",
        "    # Remember the current mode: leaving the model in eval mode would silently\n",
        "    # disable dropout and freeze batch-norm statistics in any later training.\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    try:\n",
        "        # No need for gradients\n",
        "        with torch.no_grad():\n",
        "            # Encoder output: parameters of the Normal over log topic proportions\n",
        "            logtheta_loc, logtheta_scale = model.encoder(docs)\n",
        "\n",
        "            # Draw one sample per document and map it onto the simplex\n",
        "            logtheta = Normal(logtheta_loc, logtheta_scale).sample()\n",
        "            theta = torch.softmax(logtheta, -1)\n",
        "\n",
        "            # Decode the samples to get the word distribution\n",
        "            count_param = model.decoder(theta)\n",
        "\n",
        "            # Log likelihood of all observed words; the epsilon guards log(0)\n",
        "            log_likelihood = torch.sum(docs * (count_param + 1e-10).log())\n",
        "\n",
        "            # Perplexity is exp of the average negative log likelihood per word\n",
        "            avg_neg_log_likelihood = -log_likelihood / total_word_count\n",
        "            perplexity = torch.exp(avg_neg_log_likelihood).item()\n",
        "    finally:\n",
        "        model.train(was_training)\n",
        "\n",
        "    return perplexity\n"
      ],
      "metadata": {
        "id": "aYOJleEKwbWh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "perplexity = calculate_perplexity(prodLDA, test_data_word_matrix_tensor)\n",
        "print(f'Perplexity: {perplexity}')\n"
      ],
      "metadata": {
        "id": "8S7TsfLYz8Kl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-wt-ki6F1Noo"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}