{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment line below to install exlib\n",
    "# !pip install exlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tqdm\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "import sentence_transformers\n",
    "\n",
    "import sys\n",
    "sys.path.append('../../src')\n",
    "import exlib\n",
    "from exlib.datasets.multilingual_politeness import PolitenessDataset, PolitenessClassifier, PolitenessFixScore, get_politeness_scores\n",
    "from exlib.datasets.politeness_helper import load_lexica\n",
    "\n",
    "from exlib.features.text import *\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load datasets and pre-trained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample inference on dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at anonymized-model and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "  0%|                                                                                          | 0/143 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Text: The intro mentions the ISO 8601 international standard adopted in most western countries. What does this even mean? Who are we suggesting has done the adoption?\n",
      "Politeness: 0.21642041206359863\n",
      "\n",
      "Text: I'm a user on PrettyCure.org, and somebody on the site said they are making a fourth season of PreCure. It's a rumuor, but is it true? That person said it's more like Tokyo Mew Mew, a group of girls.\n",
      "Politeness: 0.21494880318641663\n",
      "\n",
      "Text: Hello fellow Wikipedians, I have just added archive links to on Essen. Please take a moment to review my edit. If necessary, add after the link to keep me from modifying it.\n",
      "Politeness: 0.12163643538951874\n",
      "\n",
      "Text: I saw the template citing this issue and since there was no section here discussing it I've decided to start one. I'm a Canadian and most of our television programs are also aired in the US so my knowledge of what's on TV outside of North America is limited. So I'm not sure of how much help I can be, but I do have some ideas on how to improve this section and I'm open to feedback.\n",
      "Politeness: 0.09295740723609924\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(1234)\n",
    "\n",
    "dataset = PolitenessDataset(\"test\")\n",
    "dataloader = DataLoader(dataset, batch_size=4, shuffle=False)\n",
    "model = PolitenessClassifier()\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "for batch in tqdm(dataloader): \n",
    "    input_ids = batch['input_ids'].to(device)\n",
    "    attention_mask = batch['attention_mask'].to(device)\n",
    "    output = model(input_ids, attention_mask)\n",
    "    utterances = [dataset.tokenizer.decode(input_id, skip_special_tokens=True) for input_id in input_ids]\n",
    "    for utterance, label in zip(utterances, output):\n",
    "        print(\"Text: {}\\nPoliteness: {}\\n\".format(utterance, label.item()))\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at anonymized-model and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- word Level Groups ----\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 143/143 [01:22<00:00,  1.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- phrase Level Groups ----\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 143/143 [02:05<00:00,  1.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- sentence Level Groups ----\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 143/143 [01:53<00:00,  1.27it/s]\n"
     ]
    }
   ],
   "source": [
    "all_baselines_scores = get_politeness_scores(baselines = ['word', 'phrase', 'sentence'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BASELINE word mean score: 0.6839182740178517\n",
      "BASELINE phrase mean score: 0.6350535143089757\n",
      "BASELINE sentence mean score: 0.6108726882043238\n"
     ]
    }
   ],
   "source": [
    "for name in all_baselines_scores:\n",
    "    metric = torch.tensor(all_baselines_scores[name])\n",
    "    mean_metric = metric.nanmean()\n",
    "    print(f'BASELINE {name} mean score: {mean_metric}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERTopic (Clustering)\n",
    "\n",
    "create topics from the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first, save all the utterances\n",
    "\n",
    "dataset = PolitenessDataset(\"test\")\n",
    "utterances = [' '.join(dataset[i]['word_list']) for i in range(len(dataset))]\n",
    "# torch.save(utterances, '../../fix/utterances/multilingual_politeness_test.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# then, install bertopic + use them on the utterances\n",
    "\n",
    "!pip install bertopic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at anonymized-model and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- clustering Level Groups ----\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 143/143 [00:17<00:00,  8.06it/s]\n"
     ]
    }
   ],
   "source": [
    "all_baselines_scores = get_politeness_scores(baselines = ['clustering'])\n",
    "\n",
    "# make sure utterances_path is set to where utteraces is saved in your directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BASELINE clustering mean score: 0.6679689280296627\n"
     ]
    }
   ],
   "source": [
    "for name, score in all_baselines_scores.items():\n",
    "    print(f'BASELINE {name} mean score: {score.mean()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
