{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment line below to install exlib\n",
    "# !pip install exlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "import numpy as np\n",
    "import tqdm\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "import torch.nn as nn\n",
    "import sentence_transformers\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, \"../../src\")\n",
    "import exlib\n",
    "from exlib.datasets.emotion_helper import project_points_onto_axes, load_emotions\n",
    "from exlib.datasets.emotion import load_data, load_model, EmotionDataset, EmotionClassifier, EmotionFixScore, get_emotion_scores\n",
    "\n",
    "from exlib.features.text import *\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load datasets and pre-trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SamLowe/roberta-base-go_emotions\n"
     ]
    }
   ],
   "source": [
    "dataset = EmotionDataset(\"test\")\n",
    "dataloader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "model = EmotionClassifier().eval().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/2714 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Text: I’m really sorry about your situation :( Although I love the names Sapphira, Cirilla, and Scarlett!\n",
      "Emotion: remorse\n",
      "\n",
      "Text: It's wonderful because it's awful. At not with.\n",
      "Emotion: admiration\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for batch in tqdm(dataloader): \n",
    "    input_ids = batch['input_ids'].to(device)\n",
    "    attention_mask = batch['attention_mask'].to(device)\n",
    "    output = model(input_ids, attention_mask)\n",
    "    utterances = [dataset.tokenizer.decode(input_id, skip_special_tokens=True) for input_id in input_ids]\n",
    "    for utterance, label in zip(utterances, output.logits):\n",
    "        id_str = model.model.config.id2label[label.argmax().item()]\n",
    "        print(\"Text: {}\\nEmotion: {}\\n\".format(utterance, id_str))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SamLowe/roberta-base-go_emotions\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [00:43<00:00, 31.44it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [02:00<00:00, 11.29it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [01:21<00:00, 16.62it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [01:34<00:00, 14.43it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [00:47<00:00, 28.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [02:59<00:00,  7.57it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [18:36<00:00,  1.21it/s]\n"
     ]
    }
   ],
   "source": [
    "all_baseline_scores = get_emotion_scores([\n",
    "    \"identity\", \"random\", \"word\", \"phrase\", \"sentence\", \"clustering\", \"archipelago\"    \n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BASELINE identity mean score: 0.010318498686651098\n",
      "BASELINE random mean score: 0.030460640761845705\n",
      "BASELINE word mean score: 0.11819195071168308\n",
      "BASELINE phrase mean score: 0.019752760732233695\n",
      "BASELINE sentence mean score: 0.0119969120149827\n",
      "BASELINE clustering mean score: 0.08897856287357343\n",
      "BASELINE archipelago mean score: 0.052713106135909224\n"
     ]
    }
   ],
   "source": [
    "for name, score in all_baseline_scores.items():\n",
    "    print(f'BASELINE {name} mean score: {score.mean()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
