{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695324378707,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import RobertaModel, RobertaConfig, RobertaTokenizer\n",
    "\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "import textwrap\n",
    "import math\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695323483676,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/train.csv')\n",
    "df_val = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/val.csv')\n",
    "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "YWiFI0a9QqEa",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_senti = pd.read_excel('/home/m_nsu/ICLR/Datasets/GoEmotion/sentiwords.xlsx')\n",
    "conditions = [\n",
    "    (df_senti['PosScore'] > df_senti['NegScore']),\n",
    "    (df_senti['PosScore'] < df_senti['NegScore']),\n",
    "    (df_senti['PosScore'] == df_senti['NegScore'])\n",
    "    ]\n",
    "\n",
    "values = ['Positive','Negative','Neutral']\n",
    "\n",
    "df_senti = df_senti[['PosScore','NegScore','Word','Definition']]\n",
    "df_senti['Sentiment'] = np.select(conditions, values)\n",
    "df_senti = df_senti.dropna(axis=0)\n",
    "df_senti.drop(columns=['PosScore', 'NegScore'], inplace=True)\n",
    "df_senti = df_senti[['Word', 'Sentiment', 'Definition']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754433,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "vX1DfZpN1W01",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.dropna(inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754435,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "q_pTZIR8ltaJ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.rename(columns={'Text': 'review'}, inplace=True)\n",
    "df_val.rename(columns={'Text': 'review'}, inplace=True)\n",
    "df_test.rename(columns={'Text': 'review'}, inplace=True)\n",
    "\n",
    "\n",
    "df_train.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)\n",
    "df_val.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)\n",
    "df_test.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 49,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "f1bb3869-364e-4c7a-8c3f-a62ce4b182c7",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>review</th>\n",
       "      <th>Emotions</th>\n",
       "      <th>sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>my favorite food is anything i did not have to...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>now if he does off himself everyone will think...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>why the fuck is bayless isoing</td>\n",
       "      <td>anger</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>to make her feel threatened</td>\n",
       "      <td>fear</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>dirty southern wankers</td>\n",
       "      <td>annoyance</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              review   Emotions  sentiment\n",
       "0  my favorite food is anything i did not have to...    neutral          3\n",
       "1  now if he does off himself everyone will think...    neutral          3\n",
       "2                     why the fuck is bayless isoing      anger          1\n",
       "3                        to make her feel threatened       fear          1\n",
       "4                             dirty southern wankers  annoyance          1"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 45,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "RN48HyYaC2VD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def filter_rows_by_values(df, col, values):\n",
    "    return df[~df[col].isin(values)]\n",
    "\n",
    "df_train = filter_rows_by_values(df_train,'sentiment',[2,3])\n",
    "df_test = filter_rows_by_values(df_test,'sentiment',[2,3])\n",
    "df_val = filter_rows_by_values(df_val,'sentiment',[2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 45,
     "status": "ok",
     "timestamp": 1695323754438,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Qm9XHUHuuSXq",
    "outputId": "6e7765c0-f186-4b8c-db5a-d954fbb3dc9a",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train['emotion'], map = pd.factorize(df_train['Emotions'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754438,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "TCTYc5frxZfx",
    "tags": []
   },
   "outputs": [],
   "source": [
    "emotion_map = dict(zip(map, range(len(map))))\n",
    "map_emotion = {v: k for k, v in emotion_map.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754439,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "_xxZFLJwxfpy",
    "outputId": "914e4d70-43c6-462e-dc72-954759c555eb",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'gratitude'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emotion_map['gratitude']  # Get Encoded Label from Categorical Label\n",
    "map_emotion[3]  # Get Categorical Label from Encoded Label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754440,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "k-a5Ww9MhpdO",
    "outputId": "ecf8deb0-f932-4d16-a103-9acc16f252ec",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_val['emotion'] = df_val[\"Emotions\"].apply(lambda x: emotion_map[x])\n",
    "df_test['emotion'] = df_test[\"Emotions\"].apply(lambda x: emotion_map[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756839,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "device = torch.device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756840,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "cb12350b-753a-4c61-c7e8-3c9427f000e9",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 'roberta-large'\n",
    "config = RobertaConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "tokenizer = RobertaTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770760,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, reviews, sentiments, tokenizer, max_len):\n",
    "    self.reviews = reviews\n",
    "    self.sentiments = sentiments\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.reviews)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    review = str(self.reviews[item])\n",
    "    sentiment = self.sentiments[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      review,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'review': review,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'sentiments': torch.tensor(sentiment, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325189268,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    reviews=df.review.to_numpy(),\n",
    "    sentiments=df['sentiment'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770761,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325187523,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "4cc31257-8fe1-4cce-883e-8aff4ebf2185",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['review', 'input_ids', 'attention_mask', 'sentiments'])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = next(iter(train_data_loader))\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695325688712,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "  def __init__(self):\n",
    "    super(Classifier, self).__init__()\n",
    "    self.bert = RobertaModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "    self.FC = nn.Linear(config.hidden_size,2, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "    \n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "HWZ37gsztWzL",
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = Classifier()\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c0ba4f97-cad5-46b8-e1ff-e818ff5b4035",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "sentiments = data['sentiments'].to(device)\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 603,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs = model(input_ids, attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NHFM3QUhg3n0",
    "outputId": "602bd955-db78-44a6-931e-e5e901bb1856",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0678,  0.3341],\n",
       "        [-0.0677,  0.2402],\n",
       "        [-0.1093,  0.2249],\n",
       "        [ 0.0137,  0.3333],\n",
       "        [ 0.0391,  0.2449],\n",
       "        [ 0.0517,  0.3534],\n",
       "        [-0.0368,  0.1094],\n",
       "        [ 0.0902,  0.3866],\n",
       "        [-0.1260,  0.2425],\n",
       "        [ 0.1266,  0.3623],\n",
       "        [-0.0791,  0.2829],\n",
       "        [ 0.1026,  0.3454],\n",
       "        [ 0.1434,  0.1547],\n",
       "        [ 0.0325,  0.2441],\n",
       "        [-0.0610,  0.1391],\n",
       "        [ 0.0164,  0.3234],\n",
       "        [ 0.1372,  0.3946],\n",
       "        [ 0.0945,  0.3396],\n",
       "        [ 0.0195,  0.2806],\n",
       "        [-0.0503,  0.2724],\n",
       "        [ 0.2340,  0.4261],\n",
       "        [-0.0060,  0.3402],\n",
       "        [ 0.0019,  0.3550],\n",
       "        [ 0.1309,  0.3965],\n",
       "        [ 0.0828,  0.2484],\n",
       "        [-0.0809,  0.2008],\n",
       "        [ 0.0539,  0.3457],\n",
       "        [ 0.0327,  0.4093],\n",
       "        [ 0.0664,  0.2318],\n",
       "        [-0.0216,  0.1343],\n",
       "        [-0.0825,  0.3160],\n",
       "        [ 0.0350,  0.3243]], device='cuda:0', grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1022,
     "status": "ok",
     "timestamp": 1695325845248,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf",
    "outputId": "bae8764e-0ce8-4118-e676-267b2c3ac06b",
    "tags": []
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695325845969,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == sentiments)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695325845971,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  f1s = []\n",
    "\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == sentiments)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "      f1s.append(f1_score(sentiments.cpu(), preds.cpu(), average='macro'))\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses), np.mean(f1s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 3203929,
     "status": "ok",
     "timestamp": 1695329051836,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "ee79ec9a-a033-47e0-bc8b-a7b71698465c",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 0.5958956672808523 accuracy 0.6533018867924528\n",
      "Val   loss 0.5034097489676898 accuracy 0.7516858389527965\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 0.4736947440220877 accuracy 0.7699217181854677\n",
      "Val   loss 0.43595584057554415 accuracy 0.7881792939309797\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.4278349260744275 accuracy 0.804245283018868\n",
      "Val   loss 0.41852075314219994 accuracy 0.7988893296310988\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.41174918166802743 accuracy 0.812826174227218\n",
      "Val   loss 0.40194029219542876 accuracy 0.8167393891312971\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.40224149146298155 accuracy 0.8183460457647531\n",
      "Val   loss 0.3998827128847943 accuracy 0.8199127330424435\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.39465676520743304 accuracy 0.8226615816940988\n",
      "Val   loss 0.3961127040129674 accuracy 0.82070606902023\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.39616249164455775 accuracy 0.8199518265756724\n",
      "Val   loss 0.39667220922965035 accuracy 0.8298294327647758\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.3944973769339091 accuracy 0.8205038137294259\n",
      "Val   loss 0.3959179220697548 accuracy 0.8302261007536692\n",
      "\n",
      "CPU times: user 1h 34min 35s, sys: 19 s, total: 1h 34min 54s\n",
      "Wall time: 1h 34min 59s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss, val_f1 = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_xlnet_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "executionInfo": {
     "elapsed": 51,
     "status": "ok",
     "timestamp": 1695329137842,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "FowMSU5U7SDQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 480
    },
    "executionInfo": {
     "elapsed": 2888,
     "status": "ok",
     "timestamp": 1695329143450,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "b103ca88-1886-4f16-ea21-5b966a08fe7c",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHFCAYAAAAOmtghAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOzdeVxU5f4H8M+ZGRj2UQHZXEBxA8QNF/SiaamJa4tLVoqW99Km5nb1172lZmHmFhpkXXG/ZaZ2rSyztOSqN9PEJXHJMBDGEBWGRYGZOb8/BkbGAZmBgYHj5/16zYs5z3nOOd8z+brzuc955hxBFEURRERERBIhs3cBRERERLbEcENERESSwnBDREREksJwQ0RERJLCcENERESSwnBDREREksJwQ0RERJLCcENERESSwnBDREREksJwQyQhgiBY9Prhhx9scrw7d+5AEAQsXbq0Rtv36dMHjz76qE1qsZavry+efPLJavt98803EAQB//vf/6zaf3x8PLZu3VrT8oioFhT2LoCIbOfo0aMmy2+++SYOHjyIAwcOmLSHhITY5HhKpRJHjx5Fq1atarT9+vXrIZfLbVJLXYmMjMTRo0cRFhZm1Xbx8fEIDg7GM888U0eVEVFVGG6IJKRPnz4my97e3pDJZGbtVSkpKYFcLrc4cAiCYPG+KxMaGlrjbeuLSqWq1Tnakk6ng06ng6Ojo71LIWrQeFmK6AFVfrll+/btmD59Ovz8/ODk5ISMjAyo1WrExsaiU6dOcHV1hY+PDx555BGzkaHKLkt98MEHEAQBhw8fxrRp0+Dp6QkvLy+MHTsWf/75p8n2916WOn/+PARBwJo1a/DOO++gdevWcHNzQ79+/XDixAmzc0hISEBwcDCUSiU6d+6MHTt2YMKECejYsaPFn8MXX3yBrl27wtnZGSEhIWaXkiq7LHXx4kWMHTsWfn5+UCqV8PX1xeDBg/Hrr78CMFzyunz5Mvbt22e8FFixprS0NDz11FPw9vaGUqlESEgI4uPjUfE5xuWfxerVq7Fw4UIEBgbC0dERe/fuhZubG2bMmGF2LhcuXIBMJsOaNWssPn8iKeLIDdEDbvbs2ejfvz/+9a9/Qa/Xo2nTpkhPT4eDgwMWLVoEHx8f5OfnY8eOHYiKikJycjIiIyOr3e/kyZMxatQofPzxx0hLS8O8efMwZcoU7N27t9ptV65cic6dO2PNmjXQ6XR47bXXMGzYMKSlpcHV1RWA4bLPjBkzMGHCBMTHx+PmzZtYsGABSktL4ezsbNG5//zzz7hw4QLmz58PLy8vJCYm4tlnn0X79u3Rq1evSrcRRRGPPvoolEolli9fjpYtWyInJwfJycnIzc0FAOzduxejR49GixYtsGrVKgAw1qRWqxEZGQlBEBAXF4cWLVrg888/x4wZM3DlyhWsXLnS5HjLly9HSEgIVq5cCTc3N4SEhGDSpEnYtGkT3n77bePnAQDvv/8+XF1dMXnyZIvOn0iyRCKSrMmTJ4uurq6Vrvv6669FAOKQIUOq3Y9WqxVLS0vFfv36iU899ZSx/fbt2yIAMS4uztiWmJgoAhBnzZplso/FixeLAMSbN28a23r37i0OHTrUuJyamioCECMiIkS9Xm9sP3TokAhA3L17tyiKolhSUiJ6enqKAwYMMDnGb7/9JsrlcrFDhw7VnpOPj4/o6uoqZmVlGdsKCgpEd3d3ccaMGca28s/p6NGjoiiK4tWrV0UA4gcffHDf/bdt29bk3MrNnDlTFARBTElJMWmfMmWKKJPJxLS0NJPPolOnTqJWqzXpm5qaKgqCICYmJhrb8vPzRQ8PD/GFF16o9tyJpI6XpYgecE888YRZmyiKWLNmDbp16wYnJycoFAo4ODjg8OHDSE1NtWi/o0aNMlkODw8HAKSnp1e77YgRIyAIgtm2f/zxBwDg7NmzuHHjBsaNG2eyXdu2bdGzZ0+L
6gOAnj17ws/Pz7js6uqKtm3bGo9TGV9fX7Rq1Qpvv/023nvvPZw6dQp6vd7iYx44cADdunVDly5dTNpjYmKg1+vNfsk2ZswYszlQHTt2xODBg/H+++8b2zZv3gyNRoOXXnrJ4lqIpIrhhugBV/HLvVxcXBymT5+OqKgo7Nq1Cz/99BN+/vlnDBo0CLdv37Zov56enibLSqUSACzavrptb9y4AQDw8fEx27ayNkuPU36s+9Uol8tx8OBBDBw4EG+99Ra6du0KHx8fzJo1C4WFhdUe88aNG5V+5v7+/sb1FVXWFwBmzJiBs2fP4tChQwAMl6QGDhzYKCZpE9U1zrkhesBVHCEpt3XrVjz66KOIj483ac/Ly6uvsu6rPJTcO0EZAK5du1bnx2/Tpg02btwIwDDx95NPPsGbb74JvV6P1atX33dbT09PqNVqs/asrCwAgJeXl0l7Zf99AGDYsGFo164d1q5dC61Wi3PnzmHx4sU1OBsi6eHIDRGZEQTBOFpS7vjx4/jll1/sVJGpsLAwNGvWDNu3bzdpv3z5Mo4fP16vtXTs2BELFy5E+/btTT6fqkaAHn74YaSkpBh/WVVu8+bNkMlkeOihhyw6riAIeOWVV7B792688cYbaNGiBcaMGVOrcyGSCoYbIjIzYsQIfPHFF1iyZAkOHDiAtWvXYvjw4QgMDLR3aQAABwcHvPHGGzh06BCeeuopfP3119i6dSuGDh0Kf39/yGR19z9tx44dw8CBA/H+++9j3759OHDgAObPn48LFy5g8ODBxn6dO3fG8ePH8dlnn+H48ePGMDN37lx4e3tj6NChSEpKwr59+/DSSy9h/fr1mDlzJlq3bm1xLTExMXBxccF///tfxMbGNvgbIhLVF16WIiIzCxcuRElJCRISEvDWW28hLCwMGzZswObNm5GSkmLv8gAA06dPh1wux8qVK7Fr1y60adMGixYtwpYtW6DRaOrsuC1atECrVq2wZs0aXL16FTKZDG3btkV8fDxefPFFY7+33noLOTk5mDJlCgoKCtChQwecP38efn5+OHr0KBYsWIC5c+ciPz8fbdu2xerVqzF9+nSranF3d0d0dDR27dqFadOm2fpUiRotQRQr3DWKiKgRu3HjBtq1a4dnnnnGbL6QFN2+fRutWrXCsGHDsHnzZnuXQ9RgcOSGiBql9PR0rFy5EgMGDECzZs2QlpaGFStWoLi4GK+88oq9y6tT2dnZuHjxItatW4dbt25h3rx59i6JqEFhuCGiRsnJyQmXLl3Cxx9/jJs3b8LNzQ19+/bFxo0b0a5dO3uXV6d27dqFF154AQEBAfjoo4+sfqgnkdTxshQRERFJil1/LXXo0CGMHDkS/v7+EAQBn3/+ebXb/Pjjj+jRowecnJzQpk0bfPDBB/VQKRERETUWdg03hYWF6NKlC9auXWtR/7S0NERHRyMqKgonT57E//3f/2H69OnYuXNnHVdKREREjUWDuSwlCAJ2795935tQ/f3vf8eePXtMnm0TGxuLU6dO4ejRo/VRJhERETVwjWpC8dGjRzFkyBCTtqFDh2L9+vUoLS2Fg4OD2TbFxcUoLi42Luv1ety8eROenp5V3taciIiIGhZRFJGfn2/RjTobVbi5du2a2UPxfHx8oNVqkZOTU+UDABctWlRfJRIREVEdysjIQIsWLe7bp1GFG8D8IXLlV9WqGoVZsGABZs2aZVzOy8tDq1atkJGRAQ8Pj7orlIiIiGxGo9GgZcuWcHd3r7Zvowo3vr6+Zk/8zc7OhkKhMD4l+F5KpdLsAYAA4OHhwXBDRETUyFgypaRRPTgzMjIS+/fvN2n79ttvERERUel8GyIiInrw2DXcFBQUICUlxfggvrS0NKSkpCA9PR2A4ZLSpEmTjP1jY2Pxxx9/YNasWUhNTUVSUhLWr1+POXPm2KV+IiIianjselnq+PHjGDhwoHG5fG7M5MmTsXHjRqjVamPQAYCgoCDs3bsXr776Kt5//334
+/sjPj4eTzzxRL3XTkRERA1Tg7nPTX3RaDRQqVTIy8vjnBsiapR0Oh1KS0vtXQaRzTk6Olb5M29rvr8b1YRiIqIHmSiKuHbtGnJzc+1dClGdkMlkCAoKgqOjY632w3BDRNRIlAeb5s2bw8XFhTciJUnR6/XIysqCWq1Gq1atavXvm+GGiKgR0Ol0xmBT1a0viBo7b29vZGVlQavV1upX0I3qp+BERA+q8jk2Li4udq6EqO6UX47S6XS12g/DDRFRI8JLUSRltvr3zXBDREREksJwQ0REjUpgYCBWr15t7zKoAeOEYiIiqlMPPfQQunbtarNA8vPPP8PV1dUm+yJpYrghIiK7E0UROp0OCkX1X0ve3t71UFH9sub8qXq8LEVERHUmJiYGP/74I9577z0IggBBEHDlyhX88MMPEAQB+/btQ0REBJRKJZKTk3H58mWMHj0aPj4+cHNzQ8+ePfHdd9+Z7PPey1KCIOBf//oXHnvsMbi4uKBdu3bYs2fPfevaunUrIiIi4O7uDl9fX0ycOBHZ2dkmfX799VcMHz4cHh4ecHd3R1RUFC5fvmxcn5SUhNDQUCiVSvj5+eHll18GAFy5cgWCIBifmwgAubm5EAQBP/zwAwDU6vyLi4sxb948tGzZEkqlEu3atcP69eshiiKCg4OxfPlyk/5nz56FTCYzqV3qGG6IiBopURRRVKK1y8vSJ/e89957iIyMxLRp06BWq6FWq9GyZUvj+nnz5iEuLg6pqakIDw9HQUEBoqOj8d133+HkyZMYOnQoRo4cafKcwcosWrQI48aNw+nTpxEdHY2nn34aN2/erLJ/SUkJ3nzzTZw6dQqff/450tLSEBMTY1yfmZmJ/v37w8nJCQcOHMCJEycwdepUaLVaAEBiYiJeeukl/PWvf8WZM2ewZ88eBAcHW/SZVFST8580aRI++eQTxMfHIzU1FR988AHc3NwgCAKmTp2KDRs2mBwjKSkJUVFRaNu2rdX1NVYc/yIiaqRul+oQ8vo+uxz73OKhcHGs/itEpVLB0dERLi4u8PX1NVu/ePFiDB482Ljs6emJLl26GJeXLFmC3bt3Y8+ePcaRkcrExMTgqaeeAgC8/fbbWLNmDY4dO4ZHH3200v5Tp041vm/Tpg3i4+PRq1cvFBQUwM3NDe+//z5UKhU++eQT483k2rdvb1LX7NmzMWPGDGNbz549q/s4zFh7/hcvXsSnn36K/fv345FHHjHWX27KlCl4/fXXcezYMfTq1QulpaXYunUr3n33Xatra8w4ckNERHYTERFhslxYWIh58+YhJCQETZo0gZubG86fP1/tyE14eLjxvaurK9zd3c0uM1V08uRJjB49Gq1bt4a7uzseeughADAeJyUlBVFRUZXeJTc7OxtZWVl4+OGHLT3NKll7/ikpKZDL5RgwYECl+/Pz88Pw4cORlJQEAPjyyy9x584djB07tta1NiYcuSEiaqScHeQ4t3io3Y5tC/f+6mnu3LnYt28fli9fjuDgYDg7O+PJJ59ESUnJffdzbwgRBAF6vb7SvoWFhRgyZAiGDBmCrVu3wtvbG+np6Rg6dKjxOM7OzlUe637rABifal3x0l1VT3G39vyrOzYAPP/883j22WexatUqbNiwAePHj3/g7mzNcENE1EgJgmDRpSF7c3R0tPh2+snJyYiJicFjjz0GACgoKMCVK1dsWs/58+eRk5ODpUuXGuf/HD9+3KRPeHg4Nm3ahNLSUrPg5O7ujsDAQHz//fcYOHCg2f7Lf82lVqvRrVs3ADCZXHw/1Z1/586dodfr8eOPPxovS90rOjoarq6uSExMxNdff41Dhw5ZdGwp4WUpIiKqU4GBgfjpp59w5coV5OTkVDmiAgDBwcHYtWsXUlJScOrUKUycOPG+/WuiVatWcHR0xJo1a/D7779jz549ePPNN036vPzyy9BoNJgwYQKOHz+OS5cuYcuWLbhw4QIAYOHChVixYgXi4+Nx6dIl/PLLL1iz
Zg0Aw+hKnz59sHTpUpw7dw6HDh3CP/7xD4tqq+78AwMDMXnyZEydOtU4EfqHH37Ap59+auwjl8sRExODBQsWIDg4GJGRkbX9yBodhhsiIqpTc+bMgVwuR0hIiPESUFVWrVqFpk2bom/fvhg5ciSGDh2K7t2727Qeb29vbNy4ETt27EBISAiWLl1q9vNpT09PHDhwAAUFBRgwYAB69OiBjz76yDiKM3nyZKxevRoJCQkIDQ3FiBEjcOnSJeP2SUlJKC0tRUREBGbMmIElS5ZYVJsl55+YmIgnn3wSL774Ijp27Ihp06ahsLDQpM9zzz2HkpISk4nTDxJBtPT3fBKh0WigUqmQl5cHDw8Pe5dDRGSRO3fuIC0tDUFBQXBycrJ3OdTAHT58GA899BCuXr0KHx8fe5djsfv9O7fm+7vhX6wlIiIiixQXFyMjIwP//Oc/MW7cuEYVbGyJl6WIiIgk4uOPP0aHDh2Ql5eHZcuW2bscu2G4ISIikoiYmBjodDqcOHECAQEB9i7HbhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiowQsMDMTq1auNy4Ig4PPPP6+y/5UrVyAIgsUPrKzr/VD94h2KiYio0VGr1WjatKlN9xkTE4Pc3FyT0NSyZUuo1Wp4eXnZ9FhUtxhuiIio0fH19a2X48jl8no7VkNTWlpqfFBoY8PLUkREVGfWrVuHgIAA6PV6k/ZRo0Zh8uTJAIDLly9j9OjR8PHxgZubG3r27Invvvvuvvu997LUsWPH0K1bNzg5OSEiIgInT5406a/T6fDcc88hKCgIzs7O6NChA9577z3j+oULF2LTpk34z3/+A0EQIAgCfvjhh0ovS/3444/o1asXlEol/Pz8MH/+fGi1WuP6hx56CNOnT8e8efPQrFkz+Pr6YuHChfc9n59//hmDBw+Gl5cXVCoVBgwYgF9++cWkT25uLv7617/Cx8cHTk5OCAsLw5dffmlcf/jwYQwYMAAuLi5o2rQphg4dilu3bgEwv6wHAF27djWpSxAEfPDBBxg9ejRcXV2xZMmSaj+3cklJSQgNDTV+Ji+//DIAYOrUqRgxYoRJX61WC19fXyQlJd33M6kNjtwQETVWogiUFtnn2A4ugCBU223s2LGYPn06Dh48iIcffhgAcOvWLezbtw9ffPEFAKCgoADR0dFYsmQJnJycsGnTJowcORIXLlxAq1atqj1GYWEhRowYgUGDBmHr1q1IS0vDjBkzTPro9Xq0aNECn376Kby8vHDkyBH89a9/hZ+fH8aNG4c5c+YgNTUVGo0GGzZsAAA0a9YMWVlZJvvJzMxEdHQ0YmJisHnzZpw/fx7Tpk2Dk5OTSVDYtGkTZs2ahZ9++glHjx5FTEwM+vXrh8GDB1d6Dvn5+Zg8eTLi4+MBACtWrEB0dDQuXboEd3d36PV6DBs2DPn5+di6dSvatm2Lc+fOQS6XAwBSUlLw8MMPY+rUqYiPj4dCocDBgweh0+mq/fwqeuONNxAXF4dVq1ZBLpdX+7kBQGJiImbNmoWlS5di2LBhyMvLw+HDhwEAzz//PPr37w+1Wg0/Pz8AwN69e1FQUGDcvi4w3BARNValRcDb/vY59v9lAY6u1XZr1qwZHn30Ufz73/82hpsdO3agWbNmxuUuXbqgS5cuxm2WLFmC3bt3Y8+ePcYRgPvZtm0bdDodkpKS4OLigtDQUFy9ehUvvPCCsY+DgwMWLVpkXA4KCsKRI0fw6aefYty4cXBzc4OzszOKi4vvexkqISEBLVu2xNq1ayEIAjp27IisrCz8/e9/x+uvvw6ZzHBBJDw8HG+88QYAoF27dli7di2+//77KsPNoEGDTJbXrVuHpk2b4scff8SIESPw3Xff4dixY0hNTUX79u0BAG3atDH2X7ZsGSIiIpCQkGBsCw0Nrfazu9fEiRMxdepUk7b7fW6A4b/X7NmzTQJlz549AQB9+/ZFhw4d
sGXLFsybNw8AsGHDBowdOxZubm5W12cpXpYiIqI69fTTT2Pnzp0oLi4GYAgjEyZMMI46FBYWYt68eQgJCUGTJk3g5uaG8+fPIz093aL9p6amokuXLnBxcTG2RUZGmvX74IMPEBERAW9vb7i5ueGjjz6y+BgVjxUZGQmhwqhVv379UFBQgKtXrxrbwsPDTbbz8/NDdnZ2lfvNzs5GbGws2rdvD5VKBZVKhYKCAmN9KSkpaNGihTHY3Kt85Ka2IiIizNru97llZ2cjKyvrvsd+/vnnjaNh2dnZ+Oqrr8wClK1x5IaIqLFycDGMoNjr2BYaOXIk9Ho9vvrqK/Ts2RPJyclYuXKlcf3cuXOxb98+LF++HMHBwXB2dsaTTz6JkpISi/YvimK1fT799FO8+uqrWLFiBSIjI+Hu7o53330XP/30k8XnUX4s4Z7LceXHr9h+70RcQRDM5h1VFBMTg+vXr2P16tVo3bo1lEolIiMjjZ+Bs7Pzfeuqbr1MJjP7nEpLS836ubqajsZV97lVd1wAmDRpEubPn4+jR4/i6NGjCAwMRFRUVLXb1QbDDRFRYyUIFl0asjdnZ2c8/vjj2LZtG3777Te0b98ePXr0MK5PTk5GTEwMHnvsMQCGOThXrlyxeP8hISHYsmULbt++bfyy/d///mfSJzk5GX379sWLL75obLt8+bJJH0dHx2rnqISEhGDnzp0mIefIkSNwd3dHQECAxTXfKzk5GQkJCYiOjgYAZGRkICcnx7g+PDwcV69excWLFysdvQkPD8f3339vcgmpIm9vb6jVauOyRqNBWlqaRXXd73Nzd3dHYGAgvv/+ewwcOLDSfXh6emLMmDHYsGEDjh49iilTplR73NriZSkiIqpzTz/9NL766iskJSXhmWeeMVkXHByMXbt2ISUlBadOncLEiRPvO8pxr4kTJ0Imk+G5557DuXPnsHfvXixfvtzsGMePH8e+fftw8eJF/POf/8TPP/9s0icwMBCnT5/GhQsXkJOTU+nIxosvvoiMjAy88sorOH/+PP7zn//gjTfewKxZs4zzbWoiODgYW7ZsQWpqKn766Sc8/fTTJqMiAwYMQP/+/fHEE09g//79SEtLw9dff41vvvkGALBgwQL8/PPPePHFF3H69GmcP38eiYmJxoA0aNAgbNmyBcnJyTh79iwmT55svCxYXV3VfW4LFy7EihUrEB8fj0uXLuGXX37BmjVrTPo8//zz2LRpE1JTU42/kqtLDDdERFTnBg0ahGbNmuHChQuYOHGiybpVq1ahadOm6Nu3L0aOHImhQ4eie/fuFu/bzc0NX3zxBc6dO4du3brhtddewzvvvGPSJzY2Fo8//jjGjx+P3r1748aNGyajEQAwbdo0dOjQwTi/pPwXPxUFBARg7969OHbsGLp06YLY2Fg899xz+Mc//mHFp2EuKSkJt27dQrdu3fDss89i+vTpaN68uUmfnTt3omfPnnjqqacQEhKCefPmGUea2rdvj2+//RanTp1Cr169EBkZif/85z9QKAwXaBYsWID+/ftjxIgRiI6OxpgxY9C2bdtq67Lkc5s8eTJWr16NhIQEhIaGYsSIEbh06ZJJn0ceeQR+fn4YOnQo/P3rfhK8IFpysVJCNBoNVCoV8vLy4OHhYe9yiIgscufOHaSlpSEoKAhOTk72LofIKkVFRfD390dSUhIef/zxKvvd79+5Nd/fnHNDREREdUKv1+PatWtYsWIFVCoVRo0aVS/HZbghIiKiOpGeno6goCC0aNECGzduNF4mq2sMN0RERFQnAgMDLfqpvq1xQjERERFJCsMNEVEj8oD9BoQeMLb6981wQ0TUCJTf8baoyE4PyiSqB+V3ZLbkHjz3wzk3RESNgFwuR5MmTYzPJ3JxcTF7DABRY6bX63H9+nW4uLjUeuIxww0RUSNR/rTq+z2Akagxk8lkaNWqVa2DO8MNEVEjIQgC/Pz80Lx580ofDUDU2Dk6OtbqMRblGG6IiBoZ
uVxe6zkJRFLGCcVEREQkKQw3REREJCkMN0RERCQpDDdEREQkKQw3REREJCn8tRQREZGtiCKgKwUgGt5DvNtueFNhuZJ1lfaDeb/77aOm66w+9n22kykA7/awF4YbIiJ6sIkioL0DFOfffZUUAMUFZe/L2wvK2jUV1lXST6+19xnZn5svMOeC3Q7PcENERI2PKAKlRfcEi4qBQ1MhjJSv09wTRir0E3X2PqNaEgDjXX3L/grCPe8tWFdpv/J1uM+6e7Zz8azFudQeww0REZkQRRFavYhSnR6lWhHFOh1KdSJKtXqU6PQo0epRavxr6Fdc1laq00MvAgqZALlMuPtXLkAuCHDU34ajrtDw0hbCQVcEhbYQDtoCKLSFUJQWQK4thLy0EPLSAshKCiArLYBQcveF4nzDX1Fv+5N3dAeUboDSHXB0K3vvUfa+bJ1jWZtJP/e77x1dAKHsJos2Dxj3tlFlGG6I6MEmioBeZ7iUoC81zJfQ6yq81xpelb4vBXRa821FPUzmXIj6Cu9F6EUROr0eOp0OOp3e8F6vv/tepytrM/TTl7Xr9Xfb9Tod9HoROtGwXq/XQy+W9xMNy2UvUSzbRq+HKIrGNn1ZX1HUQy+KEMvXi3oIogjDV69Y4VVhWQAE6O/pAyggQolSuOI23IQ7cMNtuJb/xR3IBPE+/zGspxcFFApOKIQLiuCMIsEZtwUX3BacUSS44I7MGXdkLiiWuZT9dUWx3AWlchcUy91QKndBicIVWoUrdHJnw92fZTJDKJML94Q0GRQ6AfJiAYpSAfLbFdqN/XSQyQoM/6nL5qCIomE2ilhJG8TylrK2smXjehjaUMX297aV9ze+r9Cn4j5hcpyq9wljW9X7vLvLu/t0d1Jg3qMda/XftjYYbojINrQlQME14Patsi/8qsJBFYHAomBx77bW7UesZD+Cvv6f0SQreznU+5GrUWEgoS7pIDMGkUKUvUQn5MMZBaIz8kUnFIhO0OidUFDWVgBnFMLJ+L5ANKy7DaUNitYB0JS9yBaauysZboioAdPrgMLrQL4ayL8GaLIMf8uX868B+VlA0Q17V1ota74CS0U5tJCjFHLoUOG9aPirLWsvf6+FHNqybe6OdQB6yAz/b9Y4xmF4rwKCZHcAACAASURBVK/QBxAgCAIEmQyCIDO8F2QQZIZ2mSArW1f+3vBXJivvZ3hvaJNBJggQZHLIykYVBMHwVyaTQSaXGfvJy7eTyaAoX1fWXt6//K8gCHcvoxj/yippEwC5493LNhUv2ZRdtpE7OMNdEOBuwX8Hvd5wiUynF6EtG7W6uyxCp6uiXa+HVidW3l6+rKui3WR9Je3lxzPb/906BBgedGq8oCSU/3eGSZtQ3lZ+talsbktl29/bhorb32efFY9bvv7efaKy7Su2lR+3Qp3326er0r7xguGG6EElioZRlnx1heBieK/XGF5CvhqyomwIFs5tKIUD8mUeKIXC8KUvygzvRRlKRTlKRRlKxAphoMKrPDhUDBTmwUEBLWTm2xu3U0AHWdl2Ze9NtlOU7dvwXpA5QK5QQKZwgFyhhNxBAblCCQeFAkoHueGlkJW95FA6GN47GdvL/jrcfe9Y1t9RLoND+V+5od3QJpitk8s4f6IqMpkAR+Pnw4eFkmUYbogaOVEUUazVo6BYi6JiHQqKtbhTmIvS3CyIGkNokReq4ViUDeXtbDgXX4dbyXV4aHPgIFZ+Sab8skk5rSjDdTTBn2ITZItN8Wf5CxXei02RCzdYMz4iE2AaFBxkJqHBuK5CeDAsy+GmqCRo3LN9xe3u7ecoN4xaEJH02D3cJCQk4N1334VarUZoaChWr16NqKioKvtv27YNy5Ytw6VLl6BSqfDoo49i+fLl8PS078/OiCyl14soKtWhsFhb9tKhsMTwvqBYi6ISnUm7IbRoUVCsQ2lxERxvZ8P5znW4l+bAXZuDZroceOMWfHALPsIttBRy4S7ctrieG6I7ssWmuFYhsJQv
35R5osDBC8VOzeCsVMJVKYerUgFXR4Xhr1KOUKUCvZUKuDga1rk4yuFkDBqVh5PyoKGQ8ybpRGR7dg0327dvx8yZM5GQkIB+/fph3bp1GDZsGM6dO4dWrVqZ9f/vf/+LSZMmYdWqVRg5ciQyMzMRGxuL559/Hrt377bDGRAZwsofN4twNjMP59Qa3CwoMYaVwpIKIabsfVGJ+f005NDBC3nwEW6ZvILLAkvzsuVmQoF5AVWM1BfCBbfkntA4eCLfwRtFyua44+yNUmcf6Nx8Ibr5QubuCydnF7gpFWiilCNAWRZaHOVwcVTAUcHwQUSNjyCKFX/IVb969+6N7t27IzEx0djWqVMnjBkzBnFxcWb9ly9fjsTERFy+fNnYtmbNGixbtgwZGRkWHVOj0UClUiEvLw8eHh61Pwl6oOj0ItJyCnA2U4MzmXmGQJOlQX5xVXckFdEM+caw0lwwjLD4li37yXLRXLiFZsiDHJbNa9HJHFHq4gudqw9Ed18IHv6Qq/zg2CQAMg8/wN0PcPc1TOYkIpIIa76/7TZyU1JSghMnTmD+/Pkm7UOGDMGRI0cq3aZv37547bXXsHfvXgwbNgzZ2dn47LPPMHz48CqPU1xcjOLiYuOyRsOf+pFltDo9frtuCDJnM/Nw/moOrl67BkVpPjxQBA+hCJ4oRLRwG00ditDGTYuWLqXwFnLhXpoDt5IcOBVfh9zSnxoLckMocfe9G1DcfQF3f5M2uXNTyHkDLyKiKtkt3OTk5ECn08HHx8ek3cfHB9euXat0m759+2Lbtm0YP3487ty5A61Wi1GjRmHNmjVVHicuLg6LFi2yae3UCIkiUHobuJN391WsKXufC21RHm7dvI68Wzko0txEaWEuZCUauIuF6C8UYTiK4CyUGGbZKqs4xu2yV2VcvICKoyrGvxWCi6sXIOOvQYiIasvuE4qFe/4fqCiKZm3lzp07h+nTp+P111/H0KFDoVarMXfuXMTGxmL9+vWVbrNgwQLMmjXLuKzRaNCyZUvbnQDVD73e8OyYO5oqAkolr3vX3edhdgoA3mUvoypuaCYqPSAoPQAnVdmrwnulB+DW3DS4uPkACkfbfh5ERFQlu4UbLy8vyOVys1Ga7Oxss9GccnFxcejXrx/mzp0LAAgPD4erqyuioqKwZMkS+Pn5mW2jVCqhVFb1f7Wp3ui0ZWEj14KAUmG5uEIbaj89TAcZNKILNKIL8mH4q4ErNKILihWucHb3hKqpFzy9vOHv44Pm3s0hc25yN8AoPSBwdIWIqEGzW7hxdHREjx49sH//fjz22GPG9v3792P06NGVblNUVASFwrRkudzwRWPHedF0OxfIOglkngCuXzAPMOVP4rUFuWOFEZMKoyVOKpQ4eCC7VImMQgUuFyhw/paAi7kCckVXY4gpKrtVe1MXB4QFqAwvfxX6BKjQsplzlaOGRETUeNj1stSsWbPw7LPPIiIiApGRkfjwww+Rnp6O2NhYAIZLSpmZmdi8eTMAYOTIkZg2bRoSExONl6VmzpyJXr16wd/f356n8uAovQNcO2MIMpkngKxfgBu/Wb69g2vVl3NMQkv5chPT9Q5OAADNnVL8WjbR92xWHs5cykNaTiEqy7hebo4IC1Chc4AKof4qdG6hgr/KiUGGiEii7Bpuxo8fjxs3bmDx4sVQq9UICwvD3r170bp1awCAWq1Genq6sX9MTAzy8/Oxdu1azJ49G02aNMGgQYPwzjvv2OsUpE2vM4zEVAwyf/5a+dyVpoFAQA/At7Nh8mxlIUXpDsitf1RgblGJ4RdLWZk4k5mHXzPzcOVGUaV9fT2cEBbgYRyR6dxChebuSgYZIqIHiF3vc2MPvM9NFUQRyE2/G2IyfwGyUoDSQvO+rt6GIOPfvexvN8DVNneIzikoxtnMPPyapcGZq4ZRmau3Kv8JUkATZ4QFeBhGZMrCjLc751cREUlRo7jPDdlZYU5ZgPnl7shMZU91dnQzhBf/boYgE9Ad
ULWs+LjZGsvW3Cm7EZ7hhni/ZuVBnXen0r6tPV0Q5l82RybAA6H+KjRz5S+QiIjIHMPNg6C4AFCfMg0yuenm/WQOgE9oWYgpCzJe7Wt97xVRFKHOu2O8pHQ2yxBmrucXm/UVBCDIy9VwSSlAhdCyIKNytv5yFhERPZgYbqRGV2qYF2MMMr8A188DYiW39vdsZxpkfMKME3ZrShRFXL112/hogrNZhkm/NwtLzPrKBCC4uRvC/A2XlToHqBDi7wE3Jf9ZEhFRzfFbpDHT64Gbv5sGmWunAW0ll3bc/Q0BpjzI+HczTPS1kWKtDmsP/IYt//sDuUXmjxtQyAS083FHmL+H8SfYnfzc4eLIf4JERGRb/GZpTDRq0yCT9YvhPjL3clLdnewb0N3w3sP8Boe2cuZqHubsOIULf+YDABzlMnTwdTf51VIHX3c4OfDmd0REVPcYbhqqO3l3b4yXWfbrpfws835yJeDX5W6QCegBNGtjkwm/1SnR6rHmwCUk/HAZOr0IT1dHLBodiiEhvnBUyOr8+ERERJVhuGkISu8Af56tEGROADcumfcTZIB3p7IQUxZkmofU6N4xtXU20zBac/6aYbRmRLgfFo8O4y+YiIjI7hhu6pteB+RcNA0yf/4K6M3nqaBJ6wrzZHoAvuGA0q3+a66gRKvH2oO/4f2Dv0GnF9HM1RFLxoQhunPdXfYiIiKyBsNNXRJFIC/jbojJ/AVQp1T+nCUXL9Mg498NcPWq/5rv49esPMzZcRqpag0AYHhnPyweHQpPN944j4iIGg6GG1sqvHH37r7ld/otvG7ez8EV8O96N8z4dweatKqXeTI1UarT4/2Dv2Htgd+gLRutWTw6FCPC+TwvIiJqeBhubOXaGeCDv5i3yxSmN8bz7w54d6j1jfHqy7ksDebsOIVzZaM1w8J88eaYMHhxtIaIiBoohhtb8epg+OVSk5amQca3c61vjGcPpTo9Eg5expoDl6DVi2jq4oBFo8MwMtyPD6EkIqIGjeHGVhSOwLzf7T7h1xbOX9Ng9qen8GuWYbRmaKgPlozpzIdSEhFRo8BwY0uNPNiU6vT44IfLiD9wCaU6EU1cHLBoVChGdfHnaA0RETUaDDcEALhwLR9zdpzCmUzDHY8Hh/jgrcfC0Ny98V1SIyKiBxvDzQNOq9Nj3aHf8d53l1Ci00PlbBitGd2VozVERNQ4Mdw8wC79mY/ZO07h9FXDaM0jnZrj7cc6o7kHR2uIiKjxYrh5AGl1enyY/DtW7zeM1ng4KbBwVCge6xbA0RoiImr0GG4eMJf+zMecz07jVEYuAGBQx+aIe7wzfDhaQ0REEsFw84DQ6UV8lPw7Vu6/iBKtHu5OCrwxMhRPdOdoDRERSQvDzQPgt+wCzNlxCillozUDO3gj7vFw+Ko4WkNERNLDcCNhOr2I9f/9Hcu/LRutUSrwz5EhGNujBUdriIhIshhuJOry9QLM3XEKv6QbRmsGtPfG0ic6w0/lbOfKiIiI6hbDjcTo9CI2HE7Du/suoLhstOYfIzphXERLjtYQEdEDgeFGQn6/XoB5n53G8T9uAQCi2nnhnSfC4d+EozVERPTgYLiRgHtHa9yUCrw2vBMm9ORoDRERPXgYbhq5KzmFmPvZKfx8xTBa85dgL7zzZDgCOFpDREQPKIabRkqvF7HxyBUs23ced0r1cHWU47XhIXiqF0driIjowcZw0wj9caMQcz87jWNpNwEA/YI9sfTxcLRs5mLnyoiIiOyP4aYR0etFbD56Be98cwG3S3VwcZTj/6I74enerThaQ0REVIbhppFIv1GEuZ+dwk9lozWRbTyx7EmO1hAREd2L4aaB0+tFbP3pDyz9+jyKSgyjNQuGdcTTvVtDJuNoDRER0b0YbhqwjJtFmPfZaRz9/QYAoHdQM7z7ZBe08uRoDRERUVUYbhogvV7EtmPpiNubiqISHZwd5Jg/rCOe7cPRGiIiouow3DQwV28V4e87T+Pwb4bRml6B
zfDu2HC09nS1c2VERESNA8NNAyGKIv59LB1vf5WKwhIdnBxk+PujHTE5MpCjNURERFZguGkAMnNvY/7O00i+lAMA6BnYFO8+2QWBXhytISIishbDjR2JoohPfs7AW1+loqBYCycHGeYO7YgpfTlaQ0REVFMMN3aSlXsb83edwaGL1wEAEa2b4t2xXRDE0RoiIqJaYbipZ6Io4tPjGVjyZSryi7VQKmSYO7QDpvQLgpyjNURERLXGcFOP1Hm3MX/nGfxYNlrTvVUTvDu2C9p6u9m5MiIiIulguKkHoihix4mrePPLc8i/o4WjQoY5Q9rjub+04WgNERGRjTHc1LFreXewYNdpHLxgGK3p2rIJlo/tguDmHK0hIiKqCww3dUQURez8JROLvvjVOFoza3B7TIviaA0REVFdYripA39q7mDBrjM4cD4bANClZROsGBuO4Obudq6MiIhI+hhubEgURew+mYmFe36F5o4WjnIZZg5uh79GtYFCLrN3eURERA8Ehhsbydbcwf/tPoPvUg2jNeEtVFg+tgva+3C0hoiIqD4x3NjIzaIS/HjxOhzkAmY+0h5/68/RGiIiIntguLGRjr4eiHs8HJ0DVOjgy9EaIiIie2G4saEne7SwdwlEREQPPF43ISIiIklhuCEiIiJJYbghIiIiSWG4ISIiIklhuCEiIiJJYbghIiIiSWG4ISIiIklhuCEiIiJJYbghIiIiSWG4ISIiIklhuCEiIiJJYbghIiIiSWG4ISIiIkmxe7hJSEhAUFAQnJyc0KNHDyQnJ9+3f3FxMV577TW0bt0aSqUSbdu2RVJSUj1VS0RERA2dwp4H3759O2bOnImEhAT069cP69atw7Bhw3Du3Dm0atWq0m3GjRuHP//8E+vXr0dwcDCys7Oh1WrruXIiIiJqqARRFEV7Hbx3797o3r07EhMTjW2dOnXCmDFjEBcXZ9b/m2++wYQJE/D777+jWbNmNTqmRqOBSqVCXl4ePDw8alw7ERER1R9rvr/tdlmqpKQEJ06cwJAhQ0zahwwZgiNHjlS6zZ49exAREYFly5YhICAA7du3x5w5c3D79u0qj1NcXAyNRmPyIiIiIumy22WpnJwc6HQ6+Pj4mLT7+Pjg2rVrlW7z+++/47///S+cnJywe/du5OTk4MUXX8TNmzernHcTFxeHRYsW2bx+IiIiapjsPqFYEASTZVEUzdrK6fV6CIKAbdu2oVevXoiOjsbKlSuxcePGKkdvFixYgLy8POMrIyPD5udAREREDYfdRm68vLwgl8vNRmmys7PNRnPK+fn5ISAgACqVytjWqVMniKKIq1evol27dmbbKJVKKJVK2xZPREREDZbdRm4cHR3Ro0cP7N+/36R9//796Nu3b6Xb9OvXD1lZWSgoKDC2Xbx4ETKZDC1atKjTeomIiKhxsOtlqVmzZuFf//oXkpKSkJqaildffRXp6emIjY0FYLikNGnSJGP/iRMnwtPTE1OmTMG5c+dw6NAhzJ07F1OnToWzs7O9ToOIiIgaELve52b8+PG4ceMGFi9eDLVajbCwMOzduxetW7cGAKjVaqSnpxv7u7m5Yf/+/XjllVcQEREBT09PjBs3DkuWLLHXKRAREVEDY9f73NgD73NDRETU+NTpfW7S0tJqXBgRERFRXbM63AQHB2PgwIHYunUr7ty5Uxc1EREREdWY1eHm1KlT6NatG2bPng1fX1/87W9/w7Fjx+qiNiIiIiKrWR1uwsLCsHLlSmRmZmLDhg24du0a/vKXvyA0NBQrV67E9evX66JOIiIiIovU+KfgCoUCjz32GD799FO88847uHz5MubMmYMWLVpg0qRJUKvVtqyTiIiIyCI1DjfHjx/Hiy++CD8/P6xcuRJz5szB5cuXceDAAWRmZmL06NG2rJOIiIjIIlbf52blypXYsGEDLly4gOjoaGzevBnR0dGQyQw5KSgoCOvWrUPHjh1tXiwRERFRdawON4mJiZg6dSqmTJkC
X1/fSvu0atUK69evr3VxRERERNbiTfyIiIiowavTm/ht2LABO3bsMGvfsWMHNm3aZO3uiIiIiGzK6nCzdOlSeHl5mbU3b94cb7/9tk2KIiIiIqopq8PNH3/8gaCgILP21q1bmzzkkoiIiMgerA43zZs3x+nTp83aT506BU9PT5sURURERFRTVoebCRMmYPr06Th48CB0Oh10Oh0OHDiAGTNmYMKECXVRIxEREZHFrP4p+JIlS/DHH3/g4YcfhkJh2Fyv12PSpEmcc0NERER2V+Ofgl+8eBGnTp2Cs7MzOnfujNatW9u6tjrBn4ITERE1PtZ8f1s9clOuffv2aN++fU03JyIiIqoTNQo3V69exZ49e5Ceno6SkhKTdStXrrRJYUREREQ1YXW4+f777zFq1CgEBQXhwoULCAsLw5UrVyCKIrp3714XNRIRERFZzOpfSy1YsACzZ8/G2bNn4eTkhJ07dyIjIwMDBgzA2LFj66JGIiIiIotZHW5SU1MxefJkAIBCocDt27fh5uaGxYsX45133rF5gURERETWsDrcuLq6ori4GADg7++Py5cvG9fl5OTYrjIiIiKiGrB6zk2fPn1w+PBhhISEYPjw4Zg9ezbOnDmDXbt2oU+fPnVRIxEREZHFrA43K1euREFBAQBg4cKFKCgowPbt2xEcHIxVq1bZvEAiIiIia1gVbnQ6HTIyMhAeHg4AcHFxQUJCQp0URkRERFQTVs25kcvlGDp0KHJzc+uqHiIiIqJasXpCcefOnfH777/XRS1EREREtWZ1uHnrrbcwZ84cfPnll1Cr1dBoNCYvIiIiInuy+sGZMtndPCQIgvG9KIoQBAE6nc521dUBPjiTiIio8anTB2cePHiwxoURERER1TWrw82AAQPqog4iIiIim7A63Bw6dOi+6/v371/jYoiIiIhqy+pw89BDD5m1VZx709Dn3BAREZG0Wf1rqVu3bpm8srOz8c0336Bnz5749ttv66JGIiIiIotZPXKjUqnM2gYPHgylUolXX30VJ06csElhRERERDVh9chNVby9vXHhwgVb7Y6IiIioRqweuTl9+rTJsiiKUKvVWLp0Kbp06WKzwoiIiIhqwupw07VrVwiCgHvv/denTx8kJSXZrDAiIiKimrA63KSlpZksy2QyeHt7w8nJyWZFEREREdWU1eGmdevWdVEHERERkU1YPaF4+vTpiI+PN2tfu3YtZs6caZOiiIiIiGrK6nCzc+dO9OvXz6y9b9+++Oyzz2xSFBEREVFNWR1ubty4Uem9bjw8PJCTk2OTooiIiIhqyupwExwcjG+++cas/euvv0abNm1sUhQRERFRTVk9oXjWrFl4+eWXcf36dQwaNAgA8P3332PFihVYvXq1zQskIiIisobV4Wbq1KkoLi7GW2+9hTfffBMAEBgYiMTEREyaNMnmBRIRERFZQxDvvRufFa5fvw5nZ2e4ubnZsqY6pdFooFKpkJeXBw8PD3uXQ0RERBaw5vu7Rjfx02q1aNeuHby9vY3tly5dgoODAwIDA60umIiIiMhWrJ5QHBMTgyNHjpi1//TTT4iJibFFTUREREQ1ZnW4OXnyZKX3uenTpw9SUlJsUhQRERFRTVkdbgRBQH5+vll7Xl4edDqdTYoiIiIiqimrw01UVBTi4uJMgoxOp0NcXBz+8pe/2LQ4IiIiImtZPaF42bJl6N+/Pzp06ICoqCgAQHJyMvLy8nDw4EGbF0hERERkDatHbkJCQnD69GmMGzcO2dnZyM/Px6RJk3Dx4kVotdq6qJGIiIjIYrW6zw0A5ObmYtu2bUhKSkJKSkqDn3fD+9wQERE1PtZ8f1s9clPuwIEDeOaZZ+Dv74+1a9di2LBhOH78eE13R0RERGQTVs25uXr1KjZu3IikpCQUFhZi3LhxKC0txc6dOxESElJXNRIRERFZzOKRm+joaISEhODcuXNYs2YNsrKysGbNmrqsjYiIiMhqFo/cfPvtt5g+fTpeeOEF
tGvXri5rIiIiIqoxi0dukpOTkZ+fj4iICPTu3Rtr167F9evX67I2IiIiIqtZHG4iIyPx0UcfQa1W429/+xs++eQTBAQEQK/XY//+/ZXetZiIiIiovtXqp+AXLlzA+vXrsWXLFuTm5mLw4MHYs2ePLeuzOf4UnIiIqPGpl5+CA0CHDh2wbNkyXL16FR9//HFtdkVERERkE7UKN+XkcjnGjBlTo1GbhIQEBAUFwcnJCT169EBycrJF2x0+fBgKhQJdu3a1+phEREQkXTYJNzW1fft2zJw5E6+99hpOnjyJqKgoDBs2DOnp6ffdLi8vD5MmTcLDDz9cT5USERFRY1Hrxy/URu/evdG9e3ckJiYa2zp16oQxY8YgLi6uyu0mTJiAdu3aQS6X4/PPP0dKSorFx+ScGyIiosan3ubc1EZJSQlOnDiBIUOGmLQPGTIER44cqXK7DRs24PLly3jjjTcsOk5xcTE0Go3Ji4iIiKTLbuEmJycHOp0OPj4+Ju0+Pj64du1apdtcunQJ8+fPx7Zt26BQWHb/wbi4OKhUKuOrZcuWta6diIiIGi67zrkBAEEQTJZFUTRrAwCdToeJEydi0aJFaN++vcX7X7BgAfLy8oyvjIyMWtdMREREDZdVD860JS8vL8jlcrNRmuzsbLPRHADIz8/H8ePHcfLkSbz88ssAAL1eD1EUoVAo8O2332LQoEFm2ymVSiiVyro5CSIiImpw7DZy4+joiB49emD//v0m7fv370ffvn3N+nt4eODMmTNISUkxvmJjY9GhQwekpKSgd+/e9VU6ERERNWB2G7kBgFmzZuHZZ59FREQEIiMj8eGHHyI9PR2xsbEADJeUMjMzsXnzZshkMoSFhZls37x5czg5OZm1ExER0YPLruFm/PjxuHHjBhYvXgy1Wo2wsDDs3bsXrVu3BgCo1epq73lDREREVJFd73NjD7zPDRERUePTKO5zQ0RERFQXGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFIYboiIiEhSGG6IiIhIUhhuiIiISFLsHm4SEhIQFBQEJycn9OjRA8nJyVX23bVrFwYPHgxvb294eHggMjIS+/btq8dqiYiIqKGza7jZvn07Zs6ciddeew0nT55EVFQUhg0bhvT09Er7Hzp0CIMHD8bevXtx4sQJDBw4ECNHjsTJkyfruXIiIiJqqARRFEV7Hbx3797o3r07EhMTjW2dOnXCmDFjEBcXZ9E+QkNDMX78eLz++usW9ddoNFCpVMjLy4OHh0eN6iYiIqL6Zc33t91GbkpKSnDixAkMGTLEpH3IkCE4cuSIRfvQ6/XIz89Hs2bNquxTXFwMjUZj8iIiIiLpslu4ycnJgU6ng4+Pj0m7j48Prl27ZtE+VqxYgcLCQowbN67KPnFxcVCpVMZXy5Yta1U3ERERNWx2n1AsCILJsiiKZm2V+fjjj7Fw4UJs374dzZs3r7LfggULkJeXZ3xlZGTU
umYiIiJquBT2OrCXlxfkcrnZKE12drbZaM69tm/fjueeew47duzAI488ct++SqUSSqWy1vUSERFRmM+VPwAAD1NJREFU42C3kRtHR0f06NED+/fvN2nfv38/+vbtW+V2H3/8MWJiYvDvf/8bw4cPr+syiYiIqJGx28gNAMyaNQvPPvssIiIiEBkZiQ8//BDp6emIjY0FYLiklJmZic2bNwMwBJtJkybhvffeQ58+fYyjPs7OzlCpVHY7DyIiImo47Bpuxo8fjxs3bmDx4sVQq9UICwvD3r170bp1awCAWq02uefNunXroNVq8dJLL+Gll14ytk+ePBkbN26s7/KJiIioAbLrfW7sgfe5ISIianwaxX1uiIiIiOoCww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDDREREUkKww0RERFJCsMNERERSQrDzf+3d+cxUZwNGMCflWM5ipSjHBsRqVIREaqsJcsRU2koqzFiqdoG7RrTmC2HHCGxlTagaaT/tE1NdRsUaE1tMMRCMRUBW0utRwQEJZRSG42YFkLRytUUK7zfH8bNt9/y2UOYF4fnl0yy++4M87wDCU9mZ3eIiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFWkl5v9+/cjJCQELi4uiI6OxunTpx+4fmNjI6Kjo+Hi4oInn3wSH330kUJJiYiI6FEgtdwcOXIEOTk5KCgoQGtrKxISEmA0GtHd3T3h+teuXcOqVauQkJCA1tZW7Ny5E9u3b8fRo0cVTk5ERETTlUYIIWTtPCYmBsuWLYPFYrGOLVq0CCkpKSguLrZbf8eOHaipqUFnZ6d1zGw249KlSzh37tzf2ufg4CA8PT0xMDCA2bNnP/wkiIiIaMr9k//f0s7c3LlzBy0tLUhKSrIZT0pKwtmzZyfc5ty5c3brP//882hubsaff/45ZVmJiIjo0eEoa8f9/f0YGxuDv7+/zbi/vz96e3sn3Ka3t3fC9e/evYv+/n4EBgbabTM6OorR0VHr84GBAQD3GiARERE9Gu7/3/47bzhJKzf3aTQam+dCCLuxv1p/ovH7iouLsWvXLrvxoKCgfxqViIiIJBsaGoKnp+cD15FWbnx9feHg4GB3lqavr8/u7Mx9AQEBE67v6OgIHx+fCbd54403kJeXZ30+Pj6OW7duwcfH54El6t8YHBxEUFAQbty4MSOv55np8wd4DGb6/AEeA85/Zs8fmLpjIITA0NAQdDrdX64rrdw4OzsjOjoaDQ0NWLdunXW8oaEBa9eunXAbg8GAY8eO2YzV19dDr9fDyclpwm20Wi20Wq3N2OOPP/6Q6R9s9uzZM/aPGuD8AR6DmT5/gMeA85/Z8wem5hj81Rmb+6R+FDwvLw8HDx5EWVkZOjs7kZubi+7ubpjNZgD3zrq88sor1vXNZjOuX7+OvLw8dHZ2oqysDKWlpcjPz5c1BSIiIppmpF5zs3HjRty8eRO7d+9GT08PIiIicPz4cQQHBwMAenp6bL7zJiQkBMePH0dubi727dsHnU6H
vXv3IjU1VdYUiIiIaJqRfkFxeno60tPTJ3zt448/thtbsWIFLl68OMWp/h2tVovCwkK7t8Fmipk+f4DHYKbPH+Ax4Pxn9vyB6XEMpH6JHxEREdFkk35vKSIiIqLJxHJDREREqsJyQ0RERKrCckNERESqwnIzSfbv34+QkBC4uLggOjoap0+flh1JMd9++y3WrFkDnU4HjUaD6upq2ZEUVVxcjOXLl8PDwwN+fn5ISUlBV1eX7FiKslgsiIyMtH5pl8FgQG1trexY0hQXF0Oj0SAnJ0d2FMUUFRVBo9HYLAEBAbJjKernn3/Gpk2b4OPjAzc3Nzz99NNoaWmRHUsx8+bNs/sb0Gg0yMjIUDwLy80kOHLkCHJyclBQUIDW1lYkJCTAaDTafEePmo2MjCAqKgoffvih7ChSNDY2IiMjA+fPn0dDQwPu3r2LpKQkjIyMyI6mmDlz5uCdd95Bc3MzmpubsXLlSqxduxYdHR2yoymuqakJJSUliIyMlB1FcYsXL0ZPT491aW9vlx1JMb/99hvi4uLg5OSE2tpafP/993j33Xen/Bvxp5Ompiab339DQwMAYP369cqHEfTQnnnmGWE2m23GwsLCxOuvvy4pkTwARFVVlewYUvX19QkAorGxUXYUqby8vMTBgwdlx1DU0NCQCA0NFQ0NDWLFihUiOztbdiTFFBYWiqioKNkxpNmxY4eIj4+XHWNayc7OFvPnzxfj4+OK75tnbh7SnTt30NLSgqSkJJvxpKQknD17VlIqkmlgYAAA4O3tLTmJHGNjY6ioqMDIyAgMBoPsOIrKyMjA6tWr8dxzz8mOIsWVK1eg0+kQEhKCl156CVevXpUdSTE1NTXQ6/VYv349/Pz8sHTpUhw4cEB2LGnu3LmDTz/9FFu3bp30m1T/HSw3D6m/vx9jY2N2dzL39/e3u4M5qZ8QAnl5eYiPj0dERITsOIpqb2/HY489Bq1WC7PZjKqqKoSHh8uOpZiKigpcvHgRxcXFsqNIERMTg0OHDqGurg4HDhxAb28vYmNjcfPmTdnRFHH16lVYLBaEhoairq4OZrMZ27dvx6FDh2RHk6K6uhq3b9/Gli1bpOxf+u0X1OJ/m6kQQkpbJbkyMzNx+fJlfPfdd7KjKG7hwoVoa2vD7du3cfToUZhMJjQ2Ns6IgnPjxg1kZ2ejvr4eLi4usuNIYTQarY+XLFkCg8GA+fPn45NPPkFeXp7EZMoYHx+HXq/Hnj17AABLly5FR0cHLBaLzQ2gZ4rS0lIYjUbodDop++eZm4fk6+sLBwcHu7M0fX19dmdzSN2ysrJQU1ODU6dOYc6cObLjKM7Z2RkLFiyAXq9HcXExoqKi8MEHH8iOpYiWlhb09fUhOjoajo6OcHR0RGNjI/bu3QtHR0eMjY3Jjqg4d3d3LFmyBFeuXJEdRRGBgYF2RX7RokUz5oMl/+369es4efIkXn31VWkZWG4ekrOzM6Kjo61Xhd/X0NCA2NhYSalISUIIZGZm4vPPP8fXX3+NkJAQ2ZGmBSEERkdHZcdQRGJiItrb29HW1mZd9Ho90tLS0NbWBgcHB9kRFTc6OorOzk4EBgbKjqKIuLg4u6+A+PHHHxEcHCwpkTzl5eXw8/PD6tWrpWXg21KTIC8vD5s3b4Zer4fBYEBJSQm6u7thNptlR1PE8PAwfvrpJ+vza9euoa2tDd7e3pg7d67EZMrIyMjAZ599hi+++AIeHh7Ws3ienp5wdXWVnE4ZO3fuhNFoRFBQEIaGhlBRUYFvvvkGJ06ckB1NER4eHnbXWLm7u8PHx2fGXHuVn5+PNWvWYO7cuejr68Pbb7+NwcFBmEwm2dEUkZubi9jYWOzZswcbNmzAhQsXUFJSgpKSEtnRFDU+Po7y8nKYTCY4OkqsGIp/Pkul9u3bJ4KDg4Wzs7NYtmzZjPoY8KlTpwQAu8VkMsmOpoiJ5g5AlJeXy46mmK1bt1r//p944gmRmJgo6uvrZceSaqZ9FHzjxo0i
MDBQODk5CZ1OJ1544QXR0dEhO5aijh07JiIiIoRWqxVhYWGipKREdiTF1dXVCQCiq6tLag6NEELIqVVEREREk4/X3BAREZGqsNwQERGRqrDcEBERkaqw3BAREZGqsNwQERGRqrDcEBERkaqw3BAREZGqsNwQEeHezW+rq6tlxyCiScByQ0TSbdmyBRqNxm5JTk6WHY2IHkG8txQRTQvJyckoLy+3GdNqtZLSENGjjGduiGha0Gq1CAgIsFm8vLwA3HvLyGKxwGg0wtXVFSEhIaisrLTZvr29HStXroSrqyt8fHywbds2DA8P26xTVlaGxYsXQ6vVIjAwEJmZmTav9/f3Y926dXBzc0NoaChqamqmdtJENCVYbojokfDWW28hNTUVly5dwqZNm/Dyyy+js7MTAPD7778jOTkZXl5eaGpqQmVlJU6ePGlTXiwWCzIyMrBt2za0t7ejpqYGCxYssNnHrl27sGHDBly+fBmrVq1CWloabt26peg8iWgSSL1tJxGREMJkMgkHBwfh7u5us+zevVsIce/O62az2WabmJgY8dprrwkhhCgpKRFeXl5ieHjY+vqXX34pZs2aJXp7e4UQQuh0OlFQUPB/MwAQb775pvX58PCw0Gg0ora2dtLmSUTK4DU3RDQtPPvss7BYLDZj3t7e1scGg8HmNYPBgLa2NgBAZ2cnoqKi4O7ubn09Li4O4+Pj6OrqgkajwS+//ILExMQHZoiMjLQ+dnd3h4eHB/r6+v71nIhIDpYbIpoW3N3d7d4m+isajQYAIISwPp5oHVdX17/185ycnOy2HR8f/0eZiEg+XnNDRI+E8+fP2z0PCwsDAISHh6OtrQ0jIyPW18+cOYNZs2bhqaeegoeHB+bNm4evvvpK0cxEJAfP3BDRtDA6Oore3l6bMUdHR/j6+gIAKisrodfrER8fj8OHD+PChQsoLS0FAKSlpaGwsBAmkwlFRUX49ddfkZWVhc2bN8Pf3x8AUFRUBLPZDD8/PxiNRgwNDeHMmTPIyspSdqJENOVYbohoWjhx4gQCAwNtxhYuXIgffvgBwL1PMlVUVCA9PR0BAQE4fPgwwsPDAQBubm6oq6tDdnY2li9fDjc3N6SmpuK9996z/iyTyYQ//vgD77//PvLz8+Hr64sXX3xRuQkSkWI0QgghOwQR0YNoNBpUVVUhJSVFdhQiegTwmhsiIiJSFZYbIiIiUhVec0NE0x7fPSeif4JnboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFVYboiIiEhVWG6IiIhIVVhuiIiISFX+A5xHJugH8C3DAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Per-epoch train vs. validation accuracy (assumes train_a / val_a hold one number per epoch -- confirm).\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(train_a, label='train accuracy')\n",
    "ax.plot(val_a, label='validation accuracy')\n",
    "ax.set(title='Training history', xlabel='Epoch', ylabel='Accuracy', ylim=(0, 1))\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "## Test-set accuracy and macro-F1 (positive vs. negative)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 28664,
     "status": "ok",
     "timestamp": 1695329081374,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "56411273-fc8c-476f-f660-21239bd321ad",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.8173570019723866\n",
      "F1-Macro:  0.7958848634866261\n"
     ]
    }
   ],
   "source": [
    "# Score the model once on the held-out test split; eval_model's second return\n",
    "# value is not used here.\n",
    "test_acc, _, test_f1 = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "print(\"Accuracy: \",test_acc.item())\n",
    "print(\"F1-Macro: \",test_f1.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695329212064,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "d7DaFLMiAWps",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  model = model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "\n",
    "      reviews = d[\"review\"]\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(reviews)\n",
    "      predictions.extend(preds)\n",
    "\n",
    "      prediction_probs.extend(probs)\n",
    "      real_values.extend(sentiments)\n",
    "\n",
    "\n",
    "\n",
    "  predictions = torch.stack(predictions).cpu()\n",
    "  prediction_probs = torch.stack(prediction_probs).cpu()\n",
    "  real_values = torch.stack(real_values).cpu()\n",
    "\n",
    "  return review, predictions, prediction_probs, real_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "executionInfo": {
     "elapsed": 27254,
     "status": "ok",
     "timestamp": 1695329239311,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "kZhYtki1AYx8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Collect texts, predicted labels, class probabilities and true labels for the test split.\n",
    "y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(\n",
    "  model=model,\n",
    "  data_loader=test_data_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "mO1FHn0GAbCU",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix\n",
    "class_names = ['negative', 'positive']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
    "## Test-set classification report (positive vs. negative)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "4d6f7fff-9a14-49bf-c847-d2e8ba79f5f7",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.84      0.87      0.86      1603\n",
      "    positive       0.77      0.72      0.74       932\n",
      "\n",
      "    accuracy                           0.82      2535\n",
      "   macro avg       0.81      0.80      0.80      2535\n",
      "weighted avg       0.82      0.82      0.82      2535\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-class precision / recall / F1 on the test split.\n",
    "report = classification_report(y_test, y_pred, target_names=class_names)\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
