{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sentiment Labelling and Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# For data manipulation and analysis\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# For text preprocessing\n",
    "import re\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "import datetime\n",
    "import string\n",
    "\n",
    "# For multilabel classification\n",
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "import os\n",
    "\n",
    "# For neural networks\n",
    "\n",
    "\n",
    "\n",
    "# For model evaluation\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) Reading in the tags dataframe, get the unique set, subset appropriately\n",
    "- NEEDS TO CONTAIN GLOVE_VEC because moDAL queries based on glove_vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "tags_read = pd.read_csv(\"../dataset/tags_withglovevec.csv\")\n",
    "\n",
    "tags = tags_read\n",
    "\n",
    "tags['tag'] = tags['tag'].astype('str')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/tk/x3sjpph95kz4ghcssq1py2zc0000gn/T/ipykernel_47722/3865019710.py:12: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  tags['glove_vec'] = tags['glove_vec'].apply(string_to_vector)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "tags = tags[tags['has_glove_vec'] == True] # needs to contain glove vector\n",
    "\n",
    "\n",
    "def string_to_vector(s):\n",
    "    # Remove unwanted characters and split by spaces\n",
    "    numbers = s.strip('[]').split()\n",
    "    \n",
    "    # Convert to float and then to numpy array\n",
    "    return np.array([float(num) for num in numbers])\n",
    "\n",
    "tags['glove_vec'] = tags['glove_vec'].apply(string_to_vector)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Unnamed: 0.1</th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>tag</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>un-lemmatised</th>\n",
       "      <th>glove_vec</th>\n",
       "      <th>has_glove_vec</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7</td>\n",
       "      <td>266</td>\n",
       "      <td>318</td>\n",
       "      <td>260</td>\n",
       "      <td>s</td>\n",
       "      <td>2015-02-20 22:42:49</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[0.18209, 0.88297, -0.49805, 0.53137, -0.36084...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8</td>\n",
       "      <td>267</td>\n",
       "      <td>318</td>\n",
       "      <td>115149</td>\n",
       "      <td>action</td>\n",
       "      <td>2015-02-21 15:58:30</td>\n",
       "      <td>action</td>\n",
       "      <td>[0.02024, 0.84992, -0.7815, -0.82769, 0.43115,...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>15</td>\n",
       "      <td>274</td>\n",
       "      <td>320</td>\n",
       "      <td>2762</td>\n",
       "      <td>twist</td>\n",
       "      <td>2006-04-25 11:33:52</td>\n",
       "      <td>twist</td>\n",
       "      <td>[-0.095859, -0.17472, -0.034692, -0.37307, 0.3...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>16</td>\n",
       "      <td>275</td>\n",
       "      <td>320</td>\n",
       "      <td>2959</td>\n",
       "      <td>twist</td>\n",
       "      <td>2006-04-25 11:30:58</td>\n",
       "      <td>twist</td>\n",
       "      <td>[-0.095859, -0.17472, -0.034692, -0.37307, 0.3...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>17</td>\n",
       "      <td>276</td>\n",
       "      <td>320</td>\n",
       "      <td>3996</td>\n",
       "      <td>overrate</td>\n",
       "      <td>2006-04-25 11:32:28</td>\n",
       "      <td>overrated</td>\n",
       "      <td>[0.28151, -0.42171, -0.38275, 0.15364, -0.7648...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50186</th>\n",
       "      <td>109306</td>\n",
       "      <td>390955</td>\n",
       "      <td>138280</td>\n",
       "      <td>116797</td>\n",
       "      <td>history</td>\n",
       "      <td>2015-01-30 23:07:25</td>\n",
       "      <td>history</td>\n",
       "      <td>[0.045847, 0.074334, 0.015092, -0.26392, 0.155...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50187</th>\n",
       "      <td>109307</td>\n",
       "      <td>390956</td>\n",
       "      <td>138280</td>\n",
       "      <td>116797</td>\n",
       "      <td>informatics</td>\n",
       "      <td>2015-01-30 23:07:35</td>\n",
       "      <td>informatics</td>\n",
       "      <td>[0.17728, 0.15395, 0.77811, 0.16527, 1.2438, 0...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50188</th>\n",
       "      <td>109308</td>\n",
       "      <td>390957</td>\n",
       "      <td>138280</td>\n",
       "      <td>116797</td>\n",
       "      <td>mathematics</td>\n",
       "      <td>2015-01-30 23:07:17</td>\n",
       "      <td>mathematics</td>\n",
       "      <td>[1.0033, 0.38874, 0.64312, -0.6863, 0.93268, 0...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50189</th>\n",
       "      <td>109310</td>\n",
       "      <td>390959</td>\n",
       "      <td>138280</td>\n",
       "      <td>117871</td>\n",
       "      <td>image</td>\n",
       "      <td>2015-01-30 23:09:16</td>\n",
       "      <td>image</td>\n",
       "      <td>[0.011091, 0.48461, 0.019142, 0.083725, 0.5027...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50190</th>\n",
       "      <td>109311</td>\n",
       "      <td>390960</td>\n",
       "      <td>138280</td>\n",
       "      <td>117871</td>\n",
       "      <td>story</td>\n",
       "      <td>2015-01-30 23:09:25</td>\n",
       "      <td>story</td>\n",
       "      <td>[-0.35058, 0.58245, -0.065584, -0.41768, 0.224...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>49114 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       Unnamed: 0  Unnamed: 0.1  userId  movieId          tag  \\\n",
       "0               7           266     318      260            s   \n",
       "1               8           267     318   115149       action   \n",
       "2              15           274     320     2762        twist   \n",
       "3              16           275     320     2959        twist   \n",
       "4              17           276     320     3996     overrate   \n",
       "...           ...           ...     ...      ...          ...   \n",
       "50186      109306        390955  138280   116797      history   \n",
       "50187      109307        390956  138280   116797  informatics   \n",
       "50188      109308        390957  138280   116797  mathematics   \n",
       "50189      109310        390959  138280   117871        image   \n",
       "50190      109311        390960  138280   117871        story   \n",
       "\n",
       "                 timestamp un-lemmatised  \\\n",
       "0      2015-02-20 22:42:49           NaN   \n",
       "1      2015-02-21 15:58:30        action   \n",
       "2      2006-04-25 11:33:52         twist   \n",
       "3      2006-04-25 11:30:58         twist   \n",
       "4      2006-04-25 11:32:28     overrated   \n",
       "...                    ...           ...   \n",
       "50186  2015-01-30 23:07:25       history   \n",
       "50187  2015-01-30 23:07:35   informatics   \n",
       "50188  2015-01-30 23:07:17   mathematics   \n",
       "50189  2015-01-30 23:09:16         image   \n",
       "50190  2015-01-30 23:09:25         story   \n",
       "\n",
       "                                               glove_vec  has_glove_vec  \n",
       "0      [0.18209, 0.88297, -0.49805, 0.53137, -0.36084...           True  \n",
       "1      [0.02024, 0.84992, -0.7815, -0.82769, 0.43115,...           True  \n",
       "2      [-0.095859, -0.17472, -0.034692, -0.37307, 0.3...           True  \n",
       "3      [-0.095859, -0.17472, -0.034692, -0.37307, 0.3...           True  \n",
       "4      [0.28151, -0.42171, -0.38275, 0.15364, -0.7648...           True  \n",
       "...                                                  ...            ...  \n",
       "50186  [0.045847, 0.074334, 0.015092, -0.26392, 0.155...           True  \n",
       "50187  [0.17728, 0.15395, 0.77811, 0.16527, 1.2438, 0...           True  \n",
       "50188  [1.0033, 0.38874, 0.64312, -0.6863, 0.93268, 0...           True  \n",
       "50189  [0.011091, 0.48461, 0.019142, 0.083725, 0.5027...           True  \n",
       "50190  [-0.35058, 0.58245, -0.065584, -0.41768, 0.224...           True  \n",
       "\n",
       "[49114 rows x 9 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(49114, 9)\n"
     ]
    }
   ],
   "source": [
    "print(tags.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "small_sample = tags.sample(1000)\n",
    "small_sample.drop_duplicates(subset=['tag'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/tk/x3sjpph95kz4ghcssq1py2zc0000gn/T/ipykernel_47722/2127020140.py:1: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  tags.drop_duplicates(subset=['tag'], inplace=True)\n"
     ]
    }
   ],
   "source": [
    "tags.drop_duplicates(subset=['tag'], inplace=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Unique tags', 2845)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"Unique tags\",  len(tags) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Subset the dataset\n",
    "- No. of unique tags: 2639\n",
    "\n",
    "    Taking a random subset of 40% of 2639 samples for both fine-tuning and testing we get total: 1056 samples to use\n",
    "\n",
    "We use MoDAL to query this 40% first, then \n",
    "\n",
    "    - 75% for fine-tuning = 792 samples \n",
    "    - 25% for testing = 264 samples for testing\n",
    "\n",
    "+-----------------------------+-------------+-----------+\n",
    "|            Subset           | Sample Size | Percentage|\n",
    "+-----------------------------+-------------+-----------+\n",
    "| No. of unique tags          |   2639      |   100%    |\n",
    "+-----------------------------+-------------+-----------+\n",
    "| Subset (40% of Total)       |   1056      |   40%     |\n",
    "+-----------------------------+-------------+-----------+\n",
    "| Manually labelled           |   1056      |   40%     |\n",
    "+-----------------------------+-------------+-----------+\n",
    "| Fine-tuning (75% of Subset) |    792      |   75%     |\n",
    "+-----------------------------+-------------+-----------+\n",
    "| Testing (25% of Subset)     |    264      |   25%     |\n",
    "+-----------------------------+-------------+-----------+\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Labelling the fine-tuning and testing portions -> need to manually label \n",
    "- use MoDAL to query\n",
    "- AutoLabeller to label rest of subset "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**MoDAL**\n",
    " Method 2) Active learning: label a few samples, use these labels to suggest other unlabelled examples. \n",
    "\n",
    " - ##### RUN ONCE, THEN DON'T RUN AGAIN because this will produce random initial samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1138"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subset = tags.sample(frac=0.40, random_state=42)\n",
    "len(subset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name: modAL-python\r\n",
      "Version: 0.4.2\r\n",
      "Summary: A modular active learning framework for Python3\r\n",
      "Home-page: https://modAL-python.github.io/\r\n",
      "Author: Tivadar Danka\r\n",
      "Author-email: 85a5187a@opayq.com\r\n",
      "License: MIT\r\n",
      "Location: /Users/jiayi/anaconda3/lib/python3.11/site-packages\r\n",
      "Requires: numpy, pandas, scikit-learn, scipy, skorch\r\n",
      "Required-by: \r\n"
     ]
    }
   ],
   "source": [
    "# !pip install git+https://github.com/modAL-python/modAL.git\n",
    "# !pip show modAL-python\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please manually label the initial set of 10 samples.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from modAL.models import ActiveLearner\n",
    "from modAL.uncertainty import uncertainty_sampling\n",
    "\n",
    "# Initial labeled set\n",
    "initial_size = 10\n",
    "labeled_df = subset.sample(n=initial_size)\n",
    "subset.drop(labeled_df.index, inplace=True)\n",
    "\n",
    "# Wait for initial labeled data to be annotated manually  ------------- THIS HAS BEEN DONE\n",
    "print(f\"Please manually label the initial set of {initial_size} samples.\")\n",
    "labeled_df.to_csv(\"../dataset/to_label/initial_samples_to_label.csv\")\n",
    "\n",
    "# # # Pause here until the initial labeled samples are done\n",
    "# # input(f\"Once you've labeled the initial set of samples and saved the file, type 'done' to continue: \")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please label iteration 7 samples.\n",
      "Once you've labeled the samples and saved the file, type 'done' to continue: done\n",
      "Please label iteration 8 samples.\n",
      "Once you've labeled the samples and saved the file, type 'done' to continue: done\n",
      "Please label iteration 9 samples.\n",
      "Once you've labeled the samples and saved the file, type 'done' to continue: done\n",
      "Please label iteration 10 samples.\n",
      "Once you've labeled the samples and saved the file, type 'done' to continue: sonw\n"
     ]
    }
   ],
   "source": [
    "# Read back annotated data\n",
    "labeled_df = pd.read_csv(\"../dataset/to_label/initial_samples_to_label_done.csv\")\n",
    "# After reading the annotated data\n",
    "labeled_df['glove_vec'] = labeled_df['glove_vec'].apply(string_to_vector)\n",
    "\n",
    "\n",
    "# Data preparation\n",
    "X_train = np.stack(labeled_df['glove_vec'].to_numpy())\n",
    "y_train = labeled_df['manual_label'].to_numpy()\n",
    "\n",
    "# Active learner initialization\n",
    "learner = ActiveLearner(\n",
    "    estimator=RandomForestClassifier(),\n",
    "    query_strategy=uncertainty_sampling,\n",
    "    X_training=X_train,\n",
    "    y_training=y_train\n",
    ")\n",
    "num_iterations = 10\n",
    "for i in range(10):\n",
    "    X_unlabeled = pd.DataFrame(subset['glove_vec'].tolist()).to_numpy()\n",
    "    X_unlabeled = X_unlabeled[~np.isnan(X_unlabeled).any(axis=1)]\n",
    "    \n",
    "    query_idx, query_sample = learner.query(X_unlabeled, n_instances=10)\n",
    "    \n",
    "    sample_to_label = subset.iloc[query_idx]\n",
    "    tags_to_remove = sample_to_label['tag'].tolist()  # Get the list of 'tag' values to remove\n",
    "    \n",
    "    sample_to_label.to_csv(f\"../dataset/to_label/samples_to_label_{i + 1}.csv\")\n",
    "    \n",
    "    # Pause for manual labeling\n",
    "    print(f\"Please label iteration {i + 1} samples.\")\n",
    "    \n",
    "    input(\"Once you've labeled the samples and saved the file, type 'done' to continue: \")\n",
    "\n",
    "    annotated_samples = pd.read_csv(f\"../dataset/to_label/samples_to_label_{i + 1}_done.csv\")\n",
    "    annotated_samples['glove_vec'] = annotated_samples['glove_vec'].apply(string_to_vector)\n",
    "\n",
    "    new_labels = annotated_samples['manual_label'].values\n",
    "    learner.teach(query_sample, new_labels)\n",
    "    \n",
    "    # Remove labeled samples from the subset for the next iteration\n",
    "    subset = subset[~subset['tag'].isin(tags_to_remove)]  # Drop by 'tag'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Combining the labelled samples into one dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Define the number of files you have (for this example, let's say you have 5 files)\n",
    "num_files = 10\n",
    "\n",
    "# Create an empty list to store individual dataframes\n",
    "dfs = []\n",
    "\n",
    "# Loop through each file, read its content, and append it to the dfs list\n",
    "for i in range(num_files):\n",
    "    file_path = f\"../dataset/to_label/samples_to_label_{i + 1}_done.csv\"\n",
    "    df = pd.read_csv(file_path)\n",
    "    dfs.append(df)\n",
    "\n",
    "# Concatenate all dataframes in the dfs list\n",
    "annotated_samples = pd.concat(dfs, ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0.2</th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Unnamed: 0.1</th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>tag</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>un-lemmatised</th>\n",
       "      <th>glove_vec</th>\n",
       "      <th>has_glove_vec</th>\n",
       "      <th>manual_label</th>\n",
       "      <th>Unnamed: 11</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5036</td>\n",
       "      <td>10069</td>\n",
       "      <td>26471</td>\n",
       "      <td>7434</td>\n",
       "      <td>98809</td>\n",
       "      <td>exaggerated</td>\n",
       "      <td>2013/5/21 21:02</td>\n",
       "      <td>exaggerated</td>\n",
       "      <td>[ 0.40403   -0.59134    0.063463  -0.12254    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>16671</td>\n",
       "      <td>35520</td>\n",
       "      <td>96184</td>\n",
       "      <td>27898</td>\n",
       "      <td>89030</td>\n",
       "      <td>trivial</td>\n",
       "      <td>2014/11/2 20:16</td>\n",
       "      <td>trivial</td>\n",
       "      <td>[ 0.35444    0.70451    0.21284   -0.47087    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3123</td>\n",
       "      <td>5700</td>\n",
       "      <td>19104</td>\n",
       "      <td>4550</td>\n",
       "      <td>2353</td>\n",
       "      <td>nonsensical</td>\n",
       "      <td>2014/7/15 07:52</td>\n",
       "      <td>nonsensical</td>\n",
       "      <td>[ 0.70851   0.036539  0.40354   0.035616  0.35...</td>\n",
       "      <td>True</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>33810</td>\n",
       "      <td>74148</td>\n",
       "      <td>232951</td>\n",
       "      <td>82747</td>\n",
       "      <td>6666</td>\n",
       "      <td>meaningless</td>\n",
       "      <td>2013/4/16 09:09</td>\n",
       "      <td>meaningless</td>\n",
       "      <td>[ 4.5875e-01  5.4953e-01 -1.4093e-01  1.1588e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>396</td>\n",
       "      <td>617</td>\n",
       "      <td>10068</td>\n",
       "      <td>2062</td>\n",
       "      <td>63072</td>\n",
       "      <td>pointless</td>\n",
       "      <td>2010/9/21 20:21</td>\n",
       "      <td>pointless</td>\n",
       "      <td>[-0.045273   0.20127   -0.060488   0.16788    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>366</td>\n",
       "      <td>566</td>\n",
       "      <td>10017</td>\n",
       "      <td>2062</td>\n",
       "      <td>1981</td>\n",
       "      <td>plot</td>\n",
       "      <td>2010/10/5 00:12</td>\n",
       "      <td>plot</td>\n",
       "      <td>[-3.0523e-01 -6.0794e-01 -4.5547e-02 -2.5703e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>18750</td>\n",
       "      <td>40027</td>\n",
       "      <td>126696</td>\n",
       "      <td>42290</td>\n",
       "      <td>50</td>\n",
       "      <td>sensational</td>\n",
       "      <td>2006/2/8 21:06</td>\n",
       "      <td>sensational</td>\n",
       "      <td>[-5.7019e-01  3.3709e-01 -8.3342e-01 -7.3405e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>1517</td>\n",
       "      <td>2852</td>\n",
       "      <td>15409</td>\n",
       "      <td>4087</td>\n",
       "      <td>112183</td>\n",
       "      <td>pretentious</td>\n",
       "      <td>2015/2/6 23:19</td>\n",
       "      <td>pretentious</td>\n",
       "      <td>[ 7.6172e-01 -1.9364e-01 -2.0250e-01 -1.5498e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>9110</td>\n",
       "      <td>19474</td>\n",
       "      <td>59884</td>\n",
       "      <td>16724</td>\n",
       "      <td>54272</td>\n",
       "      <td>simpson</td>\n",
       "      <td>2008/6/18 01:06</td>\n",
       "      <td>simpsons</td>\n",
       "      <td>[-4.0774e-01  3.2613e-01 -6.3297e-01 -1.9871e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>4256</td>\n",
       "      <td>8250</td>\n",
       "      <td>24444</td>\n",
       "      <td>6988</td>\n",
       "      <td>2467</td>\n",
       "      <td>whodunit</td>\n",
       "      <td>2009/9/24 00:18</td>\n",
       "      <td>whodunit</td>\n",
       "      <td>[ 0.086844  -0.35796    0.29989   -0.29937    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    Unnamed: 0.2  Unnamed: 0  Unnamed: 0.1  userId  movieId          tag  \\\n",
       "0           5036       10069         26471    7434    98809  exaggerated   \n",
       "1          16671       35520         96184   27898    89030      trivial   \n",
       "2           3123        5700         19104    4550     2353  nonsensical   \n",
       "3          33810       74148        232951   82747     6666  meaningless   \n",
       "4            396         617         10068    2062    63072    pointless   \n",
       "..           ...         ...           ...     ...      ...          ...   \n",
       "95           366         566         10017    2062     1981         plot   \n",
       "96         18750       40027        126696   42290       50  sensational   \n",
       "97          1517        2852         15409    4087   112183  pretentious   \n",
       "98          9110       19474         59884   16724    54272      simpson   \n",
       "99          4256        8250         24444    6988     2467     whodunit   \n",
       "\n",
       "          timestamp un-lemmatised  \\\n",
       "0   2013/5/21 21:02   exaggerated   \n",
       "1   2014/11/2 20:16       trivial   \n",
       "2   2014/7/15 07:52   nonsensical   \n",
       "3   2013/4/16 09:09   meaningless   \n",
       "4   2010/9/21 20:21     pointless   \n",
       "..              ...           ...   \n",
       "95  2010/10/5 00:12          plot   \n",
       "96   2006/2/8 21:06   sensational   \n",
       "97   2015/2/6 23:19   pretentious   \n",
       "98  2008/6/18 01:06      simpsons   \n",
       "99  2009/9/24 00:18      whodunit   \n",
       "\n",
       "                                            glove_vec  has_glove_vec  \\\n",
       "0   [ 0.40403   -0.59134    0.063463  -0.12254    ...           True   \n",
       "1   [ 0.35444    0.70451    0.21284   -0.47087    ...           True   \n",
       "2   [ 0.70851   0.036539  0.40354   0.035616  0.35...           True   \n",
       "3   [ 4.5875e-01  5.4953e-01 -1.4093e-01  1.1588e-...           True   \n",
       "4   [-0.045273   0.20127   -0.060488   0.16788    ...           True   \n",
       "..                                                ...            ...   \n",
       "95  [-3.0523e-01 -6.0794e-01 -4.5547e-02 -2.5703e-...           True   \n",
       "96  [-5.7019e-01  3.3709e-01 -8.3342e-01 -7.3405e-...           True   \n",
       "97  [ 7.6172e-01 -1.9364e-01 -2.0250e-01 -1.5498e-...           True   \n",
       "98  [-4.0774e-01  3.2613e-01 -6.3297e-01 -1.9871e-...           True   \n",
       "99  [ 0.086844  -0.35796    0.29989   -0.29937    ...           True   \n",
       "\n",
       "   manual_label  Unnamed: 11  \n",
       "0      negative          NaN  \n",
       "1      negative          NaN  \n",
       "2      negative          NaN  \n",
       "3      negative          NaN  \n",
       "4      negative          NaN  \n",
       "..          ...          ...  \n",
       "95      neutral          NaN  \n",
       "96     negative          NaN  \n",
       "97      neutral          NaN  \n",
       "98      neutral          NaN  \n",
       "99      neutral          NaN  \n",
       "\n",
       "[100 rows x 12 columns]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "annotated_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "annotated_samples.drop(columns=['Unnamed: 0.2','Unnamed: 0.1', 'Unnamed: 0', 'userId', 'movieId', 'timestamp', 'un-lemmatised', 'glove_vec', 'has_glove_vec'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "annotated_samples.to_csv('../dataset/annotated_samples_comb.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "rest_of_subset_label = subset.loc[~subset['tag'].isin(annotated_samples['tag'])]\n",
    "# this is the rest of the rows that need to be labelled\n",
    "\n",
    "\n",
    "rest_of_subset_label.to_csv(\"../dataset/rest_of_subset_label.csv\") # this is what will be labelled in the autolabelleer part"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Label the rest of the dataset - manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# reading in the file\n",
    "rest_of_subset_labelled = pd.read_csv(\"../dataset/rest_of_subset_label_done.csv\")\n",
    "rest_of_subset_labelled = rest_of_subset_labelled.rename(columns={'label' : 'manual_label'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# join to annotated_samples\n",
    "annotated_samples = pd.read_csv(\"../dataset/annotated_samples_comb.csv\")\n",
    "\n",
    "df_labelled = pd.concat([rest_of_subset_labelled, annotated_samples], ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>tag</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>un-lemmatised</th>\n",
       "      <th>glove_vec</th>\n",
       "      <th>has_glove_vec</th>\n",
       "      <th>manual_label</th>\n",
       "      <th>Unnamed: 11</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1339.0</td>\n",
       "      <td>4087.0</td>\n",
       "      <td>91128.0</td>\n",
       "      <td>caribbean</td>\n",
       "      <td>2012/11/3 20:42</td>\n",
       "      <td>caribbean</td>\n",
       "      <td>[ 0.010196   0.096712   0.011479  -0.72511    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>25963.0</td>\n",
       "      <td>58612.0</td>\n",
       "      <td>41997.0</td>\n",
       "      <td>athens</td>\n",
       "      <td>2007/3/18 13:52</td>\n",
       "      <td>athens</td>\n",
       "      <td>[ 3.5467e-01  1.6093e-01  1.2579e-01 -2.0789e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>6549.0</td>\n",
       "      <td>9815.0</td>\n",
       "      <td>5878.0</td>\n",
       "      <td>compassionate</td>\n",
       "      <td>2013/12/12 22:22</td>\n",
       "      <td>compassionate</td>\n",
       "      <td>[-0.69978   -0.031807  -0.14433    0.40873    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>positive</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5383.0</td>\n",
       "      <td>8771.0</td>\n",
       "      <td>380.0</td>\n",
       "      <td>buddy</td>\n",
       "      <td>2006/1/21 05:59</td>\n",
       "      <td>buddy</td>\n",
       "      <td>[ 4.7541e-01 -2.7164e-01  8.2356e-02 -1.0867e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>positive</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8213.0</td>\n",
       "      <td>14947.0</td>\n",
       "      <td>1968.0</td>\n",
       "      <td>clumsy</td>\n",
       "      <td>2010/11/2 22:27</td>\n",
       "      <td>clumsy</td>\n",
       "      <td>[ 0.47254   -0.29096   -0.27617   -0.26583   -...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1113</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>plot</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1114</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>sensational</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1115</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>pretentious</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1116</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>simpson</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1117</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>whodunit</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1118 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      Unnamed: 0   userId  movieId            tag         timestamp  \\\n",
       "0         1339.0   4087.0  91128.0      caribbean   2012/11/3 20:42   \n",
       "1        25963.0  58612.0  41997.0         athens   2007/3/18 13:52   \n",
       "2         6549.0   9815.0   5878.0  compassionate  2013/12/12 22:22   \n",
       "3         5383.0   8771.0    380.0          buddy   2006/1/21 05:59   \n",
       "4         8213.0  14947.0   1968.0         clumsy   2010/11/2 22:27   \n",
       "...          ...      ...      ...            ...               ...   \n",
       "1113         NaN      NaN      NaN           plot               NaN   \n",
       "1114         NaN      NaN      NaN    sensational               NaN   \n",
       "1115         NaN      NaN      NaN    pretentious               NaN   \n",
       "1116         NaN      NaN      NaN        simpson               NaN   \n",
       "1117         NaN      NaN      NaN       whodunit               NaN   \n",
       "\n",
       "      un-lemmatised                                          glove_vec  \\\n",
       "0         caribbean  [ 0.010196   0.096712   0.011479  -0.72511    ...   \n",
       "1            athens  [ 3.5467e-01  1.6093e-01  1.2579e-01 -2.0789e-...   \n",
       "2     compassionate  [-0.69978   -0.031807  -0.14433    0.40873    ...   \n",
       "3             buddy  [ 4.7541e-01 -2.7164e-01  8.2356e-02 -1.0867e-...   \n",
       "4            clumsy  [ 0.47254   -0.29096   -0.27617   -0.26583   -...   \n",
       "...             ...                                                ...   \n",
       "1113            NaN                                                NaN   \n",
       "1114            NaN                                                NaN   \n",
       "1115            NaN                                                NaN   \n",
       "1116            NaN                                                NaN   \n",
       "1117            NaN                                                NaN   \n",
       "\n",
       "     has_glove_vec manual_label  Unnamed: 11  \n",
       "0             True      neutral          NaN  \n",
       "1             True      neutral          NaN  \n",
       "2             True     positive          NaN  \n",
       "3             True     positive          NaN  \n",
       "4             True      neutral          NaN  \n",
       "...            ...          ...          ...  \n",
       "1113           NaN      neutral          NaN  \n",
       "1114           NaN     negative          NaN  \n",
       "1115           NaN      neutral          NaN  \n",
       "1116           NaN      neutral          NaN  \n",
       "1117           NaN      neutral          NaN  \n",
       "\n",
       "[1118 rows x 10 columns]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_labelled"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sentiment Model - cardiffnlp/twitter-roberta-base-sentiment-latest using the \"fine_tune\" set\n",
    "\n",
    "- Positive, Negative, Neutral\n",
    "\n",
    "\n",
    "Only run the below if you need to rerun the sentiment again (takes approx 211minutes)\n",
    "- Assigns a sentiment label, number to each tag from the tags dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.4.0\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# !pip install torch torchvision torchaudio\n",
    "import torch\n",
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())  # 如果安装了 CUDA，可以检查它是否可用\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jiayi/anaconda3/lib/python3.11/site-packages/transformers/utils/generic.py:260: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
      "  torch.utils._pytree._register_pytree_node(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ba291e82fa94ac8a2a21b1d7b64b718",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading pytorch_model.bin:   0%|          | 0.00/501M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jiayi/anaconda3/lib/python3.11/site-packages/transformers/modeling_utils.py:479: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  return torch.load(checkpoint_file, map_location=map_location)\n",
      "Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    }
   ],
   "source": [
    "# code ref: https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment-latest\n",
    "\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers import TFAutoModelForSequenceClassification\n",
    "from transformers import AutoTokenizer, AutoConfig\n",
    "import numpy as np\n",
    "from scipy.special import softmax\n",
    "\n",
    "MODEL = f\"cardiffnlp/twitter-roberta-base-sentiment-latest\"\n",
    "# MODEL = TFAutoModelForSequenceClassification.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "config = AutoConfig.from_pretrained(MODEL)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL)\n",
    "\n",
    "\n",
    "def roberta_sentiment(tag):\n",
    "    enc_inp = tokenizer(tag, return_tensors='pt', padding=True, truncation=True)\n",
    "    sentiment = model(**enc_inp)\n",
    "\n",
    "    scores = sentiment.logits.detach().numpy()[0]\n",
    "    scores = softmax(scores)\n",
    "\n",
    "    \n",
    "\n",
    "    # label and scores:\n",
    "    ranking = np.argsort(scores)\n",
    "    ranking = ranking[::-1]\n",
    "    top_label = config.id2label[ranking[0]]\n",
    "    top_score = np.round(float(scores[ranking[0]]), 4)\n",
    "    \n",
    "    return (top_label, top_score)\n",
    "\n",
    "\n",
    "tags['sentiment_roberta'] = tags['tag'].apply(roberta_sentiment)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Output file of tag, sentiment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "tags[[\"tag\", \"sentiment_roberta\"]].to_csv(\"../dataset/robertaTest.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scaling the Sentiment Scores \n",
    "\n",
    "Raw sentiment score labels with ranges:\n",
    "\n",
    "- Neutral (0,1)\n",
    "- Positive (0,1)\n",
    "- Negative (0,1)\n",
    "\n",
    "Need to change these so that all are on a 0,1 scale\n",
    "\n",
    "Where:\n",
    "\n",
    "- Neutral = 0.5\n",
    "- Positive = 1\n",
    "- Negative = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# loading back in the sentiment results\n",
    "sentiment_roberta = pd.read_csv(\"../dataset/robertaTest.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>tag</th>\n",
       "      <th>sentiment_roberta</th>\n",
       "      <th>sentiment_label</th>\n",
       "      <th>sentiment_value</th>\n",
       "      <th>scaled_sentiment_value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>10</td>\n",
       "      <td>hero</td>\n",
       "      <td>('positive', 0.6957)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.6957</td>\n",
       "      <td>0.84785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>60</th>\n",
       "      <td>60</td>\n",
       "      <td>humorous</td>\n",
       "      <td>('positive', 0.5959)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.5959</td>\n",
       "      <td>0.79795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>64</td>\n",
       "      <td>classic</td>\n",
       "      <td>('positive', 0.5635)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.5635</td>\n",
       "      <td>0.78175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>68</th>\n",
       "      <td>68</td>\n",
       "      <td>classic</td>\n",
       "      <td>('positive', 0.5635)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.5635</td>\n",
       "      <td>0.78175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>93</th>\n",
       "      <td>93</td>\n",
       "      <td>nature</td>\n",
       "      <td>('positive', 0.5058)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.5058</td>\n",
       "      <td>0.75290</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1855</th>\n",
       "      <td>1855</td>\n",
       "      <td>classic</td>\n",
       "      <td>('positive', 0.5635)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.5635</td>\n",
       "      <td>0.78175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1856</th>\n",
       "      <td>1856</td>\n",
       "      <td>classic</td>\n",
       "      <td>('positive', 0.5635)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.5635</td>\n",
       "      <td>0.78175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1863</th>\n",
       "      <td>1863</td>\n",
       "      <td>christmas</td>\n",
       "      <td>('positive', 0.4902)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.4902</td>\n",
       "      <td>0.74510</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1879</th>\n",
       "      <td>1879</td>\n",
       "      <td>inspirational</td>\n",
       "      <td>('positive', 0.6262)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.6262</td>\n",
       "      <td>0.81310</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1902</th>\n",
       "      <td>1902</td>\n",
       "      <td>christmas</td>\n",
       "      <td>('positive', 0.4902)</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.4902</td>\n",
       "      <td>0.74510</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      Unnamed: 0            tag     sentiment_roberta sentiment_label  \\\n",
       "10            10           hero  ('positive', 0.6957)        positive   \n",
       "60            60       humorous  ('positive', 0.5959)        positive   \n",
       "64            64        classic  ('positive', 0.5635)        positive   \n",
       "68            68        classic  ('positive', 0.5635)        positive   \n",
       "93            93         nature  ('positive', 0.5058)        positive   \n",
       "...          ...            ...                   ...             ...   \n",
       "1855        1855        classic  ('positive', 0.5635)        positive   \n",
       "1856        1856        classic  ('positive', 0.5635)        positive   \n",
       "1863        1863      christmas  ('positive', 0.4902)        positive   \n",
       "1879        1879  inspirational  ('positive', 0.6262)        positive   \n",
       "1902        1902      christmas  ('positive', 0.4902)        positive   \n",
       "\n",
       "      sentiment_value  scaled_sentiment_value  \n",
       "10             0.6957                 0.84785  \n",
       "60             0.5959                 0.79795  \n",
       "64             0.5635                 0.78175  \n",
       "68             0.5635                 0.78175  \n",
       "93             0.5058                 0.75290  \n",
       "...               ...                     ...  \n",
       "1855           0.5635                 0.78175  \n",
       "1856           0.5635                 0.78175  \n",
       "1863           0.4902                 0.74510  \n",
       "1879           0.6262                 0.81310  \n",
       "1902           0.4902                 0.74510  \n",
       "\n",
       "[100 rows x 6 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_s = sentiment_roberta\n",
    "df_s[['sentiment_label', 'sentiment_value']] = df_s['sentiment_roberta'].str.extract(r\"\\('(.+)', ([\\d.]+)\\)\")\n",
    "\n",
    "# Convert sentiment_value to float\n",
    "df_s['sentiment_value'] = df_s['sentiment_value'].astype(float)\n",
    "\n",
    "# Initialize a new column for scaled sentiment values\n",
    "df_s['scaled_sentiment_value'] = np.nan\n",
    "\n",
    "# For 'neutral' labels, scale the value so that stronger neutral scores are closer to 0\n",
    "df_s.loc[df_s['sentiment_label'] == 'neutral', 'scaled_sentiment_value'] = (1 - df_s['sentiment_value']) * 0.5\n",
    "\n",
    "# For 'positive' labels, scale the value so that stronger positive scores are closer to 1\n",
    "df_s.loc[df_s['sentiment_label'] == 'positive', 'scaled_sentiment_value'] = (df_s['sentiment_value'] * 0.5) + 0.5\n",
    "\n",
    "# For 'negative' labels, scale the value to [-1 if sentiment_value is 1, otherwise scale to (-value)]\n",
    "df_s.loc[(df_s['sentiment_label'] == 'negative') & (df_s['sentiment_value'] == 1), 'scaled_sentiment_value'] = -1\n",
    "df_s.loc[(df_s['sentiment_label'] == 'negative') & (df_s['sentiment_value'] < 1), 'scaled_sentiment_value'] = -df_s['sentiment_value']\n",
    "\n",
    "# Display the DataFrame\n",
    "df_s[df_s['sentiment_label'] == 'positive'].head(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Output the df_s to a file that is read in by the CB model and also the sentimentMod.ipynb for clustering and CF prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "df_s.to_csv('../dataset/df_tag_sentiment.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>tag</th>\n",
       "      <th>sentiment_roberta</th>\n",
       "      <th>sentiment_label</th>\n",
       "      <th>sentiment_value</th>\n",
       "      <th>scaled_sentiment_value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>s</td>\n",
       "      <td>('neutral', 0.5123)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.5123</td>\n",
       "      <td>0.24385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>action</td>\n",
       "      <td>('neutral', 0.6874)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.6874</td>\n",
       "      <td>0.15630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>twist</td>\n",
       "      <td>('neutral', 0.7575)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.7575</td>\n",
       "      <td>0.12125</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>twist</td>\n",
       "      <td>('neutral', 0.7575)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.7575</td>\n",
       "      <td>0.12125</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>overrate</td>\n",
       "      <td>('neutral', 0.5669)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.5669</td>\n",
       "      <td>0.21655</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>violent</td>\n",
       "      <td>('neutral', 0.485)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.4850</td>\n",
       "      <td>0.25750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>violent</td>\n",
       "      <td>('neutral', 0.485)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.4850</td>\n",
       "      <td>0.25750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>end</td>\n",
       "      <td>('neutral', 0.5769)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.5769</td>\n",
       "      <td>0.21155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>music</td>\n",
       "      <td>('neutral', 0.4884)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.4884</td>\n",
       "      <td>0.25580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>fight</td>\n",
       "      <td>('neutral', 0.6423)</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.6423</td>\n",
       "      <td>0.17885</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0       tag    sentiment_roberta sentiment_label  sentiment_value  \\\n",
       "0           0         s  ('neutral', 0.5123)         neutral           0.5123   \n",
       "1           1    action  ('neutral', 0.6874)         neutral           0.6874   \n",
       "2           2     twist  ('neutral', 0.7575)         neutral           0.7575   \n",
       "3           3     twist  ('neutral', 0.7575)         neutral           0.7575   \n",
       "4           4  overrate  ('neutral', 0.5669)         neutral           0.5669   \n",
       "5           5   violent   ('neutral', 0.485)         neutral           0.4850   \n",
       "6           6   violent   ('neutral', 0.485)         neutral           0.4850   \n",
       "7           7       end  ('neutral', 0.5769)         neutral           0.5769   \n",
       "8           8     music  ('neutral', 0.4884)         neutral           0.4884   \n",
       "9           9     fight  ('neutral', 0.6423)         neutral           0.6423   \n",
       "\n",
       "   scaled_sentiment_value  \n",
       "0                 0.24385  \n",
       "1                 0.15630  \n",
       "2                 0.12125  \n",
       "3                 0.12125  \n",
       "4                 0.21655  \n",
       "5                 0.25750  \n",
       "6                 0.25750  \n",
       "7                 0.21155  \n",
       "8                 0.25580  \n",
       "9                 0.17885  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_s.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "\n",
    "df_labelled = df_labelled.rename(columns={'manual_label' : 'sentiment_label'})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>tag</th>\n",
       "      <th>sentiment_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>s</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>action</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>twist</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>overrate</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>violent</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50062</th>\n",
       "      <td>newspaper</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50110</th>\n",
       "      <td>static</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50139</th>\n",
       "      <td>repeat</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50142</th>\n",
       "      <td>seal</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50176</th>\n",
       "      <td>counterespionage</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2899 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                    tag sentiment_label\n",
       "0                     s         neutral\n",
       "1                action         neutral\n",
       "2                 twist         neutral\n",
       "4              overrate         neutral\n",
       "5               violent         neutral\n",
       "...                 ...             ...\n",
       "50062         newspaper         neutral\n",
       "50110            static        positive\n",
       "50139            repeat         neutral\n",
       "50142              seal         neutral\n",
       "50176  counterespionage         neutral\n",
       "\n",
       "[2899 rows x 2 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "### Evaluating sentiment model on manually labelled dataset\n",
    "df_s[[\"tag\", \"sentiment_label\"]].drop_duplicates()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>tag</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>un-lemmatised</th>\n",
       "      <th>glove_vec</th>\n",
       "      <th>has_glove_vec</th>\n",
       "      <th>sentiment_label</th>\n",
       "      <th>Unnamed: 11</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1339</td>\n",
       "      <td>4087.0</td>\n",
       "      <td>91128.0</td>\n",
       "      <td>caribbean</td>\n",
       "      <td>2012/11/3 20:42</td>\n",
       "      <td>caribbean</td>\n",
       "      <td>[ 0.010196   0.096712   0.011479  -0.72511    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>25963</td>\n",
       "      <td>58612.0</td>\n",
       "      <td>41997.0</td>\n",
       "      <td>athens</td>\n",
       "      <td>2007/3/18 13:52</td>\n",
       "      <td>athens</td>\n",
       "      <td>[ 3.5467e-01  1.6093e-01  1.2579e-01 -2.0789e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>6549</td>\n",
       "      <td>9815.0</td>\n",
       "      <td>5878.0</td>\n",
       "      <td>compassionate</td>\n",
       "      <td>2013/12/12 22:22</td>\n",
       "      <td>compassionate</td>\n",
       "      <td>[-0.69978   -0.031807  -0.14433    0.40873    ...</td>\n",
       "      <td>True</td>\n",
       "      <td>positive</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5383</td>\n",
       "      <td>8771.0</td>\n",
       "      <td>380.0</td>\n",
       "      <td>buddy</td>\n",
       "      <td>2006/1/21 05:59</td>\n",
       "      <td>buddy</td>\n",
       "      <td>[ 4.7541e-01 -2.7164e-01  8.2356e-02 -1.0867e-...</td>\n",
       "      <td>True</td>\n",
       "      <td>positive</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8213</td>\n",
       "      <td>14947.0</td>\n",
       "      <td>1968.0</td>\n",
       "      <td>clumsy</td>\n",
       "      <td>2010/11/2 22:27</td>\n",
       "      <td>clumsy</td>\n",
       "      <td>[ 0.47254   -0.29096   -0.27617   -0.26583   -...</td>\n",
       "      <td>True</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1113</th>\n",
       "      <td>95</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>plot</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1114</th>\n",
       "      <td>96</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>sensational</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>negative</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1115</th>\n",
       "      <td>97</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>pretentious</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1116</th>\n",
       "      <td>98</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>simpson</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1117</th>\n",
       "      <td>99</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>whodunit</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1118 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      Unnamed: 0   userId  movieId            tag         timestamp  \\\n",
       "0           1339   4087.0  91128.0      caribbean   2012/11/3 20:42   \n",
       "1          25963  58612.0  41997.0         athens   2007/3/18 13:52   \n",
       "2           6549   9815.0   5878.0  compassionate  2013/12/12 22:22   \n",
       "3           5383   8771.0    380.0          buddy   2006/1/21 05:59   \n",
       "4           8213  14947.0   1968.0         clumsy   2010/11/2 22:27   \n",
       "...          ...      ...      ...            ...               ...   \n",
       "1113          95      NaN      NaN           plot               NaN   \n",
       "1114          96      NaN      NaN    sensational               NaN   \n",
       "1115          97      NaN      NaN    pretentious               NaN   \n",
       "1116          98      NaN      NaN        simpson               NaN   \n",
       "1117          99      NaN      NaN       whodunit               NaN   \n",
       "\n",
       "      un-lemmatised                                          glove_vec  \\\n",
       "0         caribbean  [ 0.010196   0.096712   0.011479  -0.72511    ...   \n",
       "1            athens  [ 3.5467e-01  1.6093e-01  1.2579e-01 -2.0789e-...   \n",
       "2     compassionate  [-0.69978   -0.031807  -0.14433    0.40873    ...   \n",
       "3             buddy  [ 4.7541e-01 -2.7164e-01  8.2356e-02 -1.0867e-...   \n",
       "4            clumsy  [ 0.47254   -0.29096   -0.27617   -0.26583   -...   \n",
       "...             ...                                                ...   \n",
       "1113            NaN                                                NaN   \n",
       "1114            NaN                                                NaN   \n",
       "1115            NaN                                                NaN   \n",
       "1116            NaN                                                NaN   \n",
       "1117            NaN                                                NaN   \n",
       "\n",
       "     has_glove_vec sentiment_label  Unnamed: 11  \n",
       "0             True         neutral          NaN  \n",
       "1             True         neutral          NaN  \n",
       "2             True        positive          NaN  \n",
       "3             True        positive          NaN  \n",
       "4             True         neutral          NaN  \n",
       "...            ...             ...          ...  \n",
       "1113           NaN         neutral          NaN  \n",
       "1114           NaN        negative          NaN  \n",
       "1115           NaN         neutral          NaN  \n",
       "1116           NaN         neutral          NaN  \n",
       "1117           NaN         neutral          NaN  \n",
       "\n",
       "[1118 rows x 10 columns]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_labelled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "is_executing": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 81.32%\n"
     ]
    }
   ],
   "source": [
    "merged_df = pd.merge(df_labelled, df_s, on='tag', how='inner', suffixes=('_test', '_df_s'))\n",
    "\n",
    "# Then, calculate the accuracy by comparing 'sentiment_label' from both dataframes\n",
    "correct_predictions = (merged_df['sentiment_label_test'] == merged_df['sentiment_label_df_s']).sum()\n",
    "total_predictions = len(merged_df)\n",
    "\n",
    "accuracy = correct_predictions / total_predictions if total_predictions != 0 else 0\n",
    "\n",
    "print(f'Accuracy: {accuracy * 100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
