{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695324378707,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import T5Config, T5Tokenizer, T5EncoderModel\n",
    "\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "import textwrap\n",
    "import math\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695323483676,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/train.csv')\n",
    "df_val = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/val.csv')\n",
    "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "YWiFI0a9QqEa",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_senti = pd.read_excel('/home/m_nsu/ICLR/Datasets/GoEmotion/sentiwords.xlsx')\n",
    "conditions = [\n",
    "    (df_senti['PosScore'] > df_senti['NegScore']),\n",
    "    (df_senti['PosScore'] < df_senti['NegScore']),\n",
    "    (df_senti['PosScore'] == df_senti['NegScore'])\n",
    "    ]\n",
    "\n",
    "values = ['Positive','Negative','Neutral']\n",
    "\n",
    "df_senti = df_senti[['PosScore','NegScore','Word','Definition']]\n",
    "df_senti['Sentiment'] = np.select(conditions, values)\n",
    "df_senti = df_senti.dropna(axis=0)\n",
    "df_senti.drop(columns=['PosScore', 'NegScore'], inplace=True)\n",
    "df_senti = df_senti[['Word', 'Sentiment', 'Definition']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754433,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "vX1DfZpN1W01",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.dropna(inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754435,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "q_pTZIR8ltaJ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.rename(columns={'Text': 'review'}, inplace=True)\n",
    "df_val.rename(columns={'Text': 'review'}, inplace=True)\n",
    "df_test.rename(columns={'Text': 'review'}, inplace=True)\n",
    "\n",
    "\n",
    "df_train.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)\n",
    "df_val.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)\n",
    "df_test.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 49,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "f1bb3869-364e-4c7a-8c3f-a62ce4b182c7",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>review</th>\n",
       "      <th>Emotions</th>\n",
       "      <th>sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>my favorite food is anything i did not have to...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>now if he does off himself everyone will think...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>why the fuck is bayless isoing</td>\n",
       "      <td>anger</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>to make her feel threatened</td>\n",
       "      <td>fear</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>dirty southern wankers</td>\n",
       "      <td>annoyance</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              review   Emotions  sentiment\n",
       "0  my favorite food is anything i did not have to...    neutral          3\n",
       "1  now if he does off himself everyone will think...    neutral          3\n",
       "2                     why the fuck is bayless isoing      anger          1\n",
       "3                        to make her feel threatened       fear          1\n",
       "4                             dirty southern wankers  annoyance          1"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Preview the renamed columns: review text, emotion label and sentiment code\n",
     "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 45,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "RN48HyYaC2VD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def filter_rows_by_values(df, col, values):\n",
    "    return df[~df[col].isin(values)]\n",
    "\n",
    "df_train = filter_rows_by_values(df_train,'sentiment',[2,3])\n",
    "df_test = filter_rows_by_values(df_test,'sentiment',[2,3])\n",
    "df_val = filter_rows_by_values(df_val,'sentiment',[2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 45,
     "status": "ok",
     "timestamp": 1695323754438,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Qm9XHUHuuSXq",
    "outputId": "6e7765c0-f186-4b8c-db5a-d954fbb3dc9a",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train['emotion'], map = pd.factorize(df_train['Emotions'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754438,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "TCTYc5frxZfx",
    "tags": []
   },
   "outputs": [],
   "source": [
    "emotion_map = dict(zip(map, range(len(map))))\n",
    "map_emotion = {v: k for k, v in emotion_map.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754439,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "_xxZFLJwxfpy",
    "outputId": "914e4d70-43c6-462e-dc72-954759c555eb",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'gratitude'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emotion_map['gratitude']  # Get Encoded Label from Categorical Label\n",
    "map_emotion[3]  # Get Categorical Label from Encoded Label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754440,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "k-a5Ww9MhpdO",
    "outputId": "ecf8deb0-f932-4d16-a103-9acc16f252ec",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_val['emotion'] = df_val[\"Emotions\"].apply(lambda x: emotion_map[x])\n",
    "df_test['emotion'] = df_test[\"Emotions\"].apply(lambda x: emotion_map[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756839,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "device = torch.device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756840,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "cb12350b-753a-4c61-c7e8-3c9427f000e9",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:1\n"
     ]
    }
   ],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 't5-large'\n",
    "config = T5Config.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "tokenizer = T5Tokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770760,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, reviews, sentiments, tokenizer, max_len):\n",
    "    self.reviews = reviews\n",
    "    self.sentiments = sentiments\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.reviews)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    review = str(self.reviews[item])\n",
    "    sentiment = self.sentiments[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      review,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'review': review,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'sentiments': torch.tensor(sentiment, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325189268,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    reviews=df.review.to_numpy(),\n",
    "    sentiments=df['sentiment'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770761,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325187523,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "4cc31257-8fe1-4cce-883e-8aff4ebf2185",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['review', 'input_ids', 'attention_mask', 'sentiments'])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Pull one batch to inspect the per-item dict keys produced by the dataset\n",
     "data = next(iter(train_data_loader))\n",
     "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695325688712,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "  def __init__(self):\n",
    "    super(Classifier, self).__init__()\n",
    "    self.bert = T5EncoderModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "    self.FC = nn.Linear(config.hidden_size,2, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "    \n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "HWZ37gsztWzL",
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = Classifier()\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Optional debug check: uncomment to list each parameter and whether it is trainable\n",
     "#for name, param in model.named_parameters():\n",
     "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c0ba4f97-cad5-46b8-e1ff-e818ff5b4035",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "sentiments = data['sentiments'].to(device)\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Release cached, currently unused GPU memory held by PyTorch's allocator\n",
     "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 603,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs = model(input_ids, attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NHFM3QUhg3n0",
    "outputId": "602bd955-db78-44a6-931e-e5e901bb1856",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2176, -0.1994],\n",
       "        [-0.0607, -0.0321],\n",
       "        [-0.0013, -0.0131],\n",
       "        [-0.2191, -0.1430],\n",
       "        [ 0.0081, -0.0156],\n",
       "        [-0.0739, -0.1514],\n",
       "        [-0.0201, -0.0204],\n",
       "        [ 0.0068, -0.0097],\n",
       "        [ 0.0090, -0.0087],\n",
       "        [-0.2202, -0.0538],\n",
       "        [-0.0678,  0.0084],\n",
       "        [ 0.0064, -0.0092],\n",
       "        [-0.0938, -0.2230],\n",
       "        [-0.1755,  0.1041],\n",
       "        [ 0.0094, -0.0105],\n",
       "        [-0.0855, -0.1259],\n",
       "        [ 0.0062, -0.0102],\n",
       "        [-0.0630, -0.0834],\n",
       "        [ 0.0078, -0.0104],\n",
       "        [-0.1489, -0.1265],\n",
       "        [-0.2179, -0.0606],\n",
       "        [-0.0300, -0.1825],\n",
       "        [ 0.0056, -0.0063],\n",
       "        [-0.0005, -0.0228],\n",
       "        [-0.1795,  0.0004],\n",
       "        [-0.1922, -0.0534],\n",
       "        [ 0.0051, -0.0086],\n",
       "        [ 0.0050, -0.0096],\n",
       "        [-0.0905, -0.0576],\n",
       "        [-0.0350, -0.0403],\n",
       "        [-0.1166, -0.1345],\n",
       "        [-0.0276,  0.0198]], device='cuda:1', grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Raw logits for the sample batch: one row per example, one column per class\n",
     "outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1022,
     "status": "ok",
     "timestamp": 1695325845248,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf",
    "outputId": "bae8764e-0ce8-4118-e676-267b2c3ac06b",
    "tags": []
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695325845969,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == sentiments)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695325845971,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  f1s = []\n",
    "\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == sentiments)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "      f1s.append(f1_score(sentiments.cpu(), preds.cpu(), average='macro'))\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses), np.mean(f1s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 3203929,
     "status": "ok",
     "timestamp": 1695329051836,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "ee79ec9a-a033-47e0-bc8b-a7b71698465c",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 0.621671630808094 accuracy 0.6639903653151346\n",
      "Val   loss 0.5532590256461615 accuracy 0.7219357397857993\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 0.5246661497253093 accuracy 0.7377559213167403\n",
      "Val   loss 0.49840183273146427 accuracy 0.7552558508528362\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.48308891904679385 accuracy 0.7631473303894019\n",
      "Val   loss 0.47718201973770236 accuracy 0.7742959143197145\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.4631938681269534 accuracy 0.779205138498595\n",
      "Val   loss 0.46864220386818994 accuracy 0.7719159063863547\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.4530742472668521 accuracy 0.7824668807707749\n",
      "Val   loss 0.466130669735655 accuracy 0.7806426021420071\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.44729332762200796 accuracy 0.7889401846647933\n",
      "Val   loss 0.46425050572503973 accuracy 0.7806426021420071\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.4421374551461558 accuracy 0.7874347651545565\n",
      "Val   loss 0.46236331930643393 accuracy 0.7814359381197937\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.4396394712680033 accuracy 0.791298675230831\n",
      "Val   loss 0.46205451707296735 accuracy 0.7810392701309005\n",
      "\n",
      "CPU times: user 2h 3min 18s, sys: 14.1 s, total: 2h 3min 32s\n",
      "Wall time: 2h 3min 34s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss, val_f1 = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_t5_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "executionInfo": {
     "elapsed": 51,
     "status": "ok",
     "timestamp": 1695329137842,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "FowMSU5U7SDQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 480
    },
    "executionInfo": {
     "elapsed": 2888,
     "status": "ok",
     "timestamp": 1695329143450,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "b103ca88-1886-4f16-ea21-5b966a08fe7c",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHFCAYAAAAOmtghAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOzdeXhU1f3H8c/MZF+BJISwJZFNAoJKWAJFBAUEQbFVQFQIqC1VCxSVSv1VEdGoVURA0LaG1SpS0WKLIhUVKihCCVRZi4GwBENAsgFZZu7vj0mGDFlIQsIkl/freebJ3DPn3vud0TqfnnPmXothGIYAAABMwurpAgAAAGoT4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYwEYvFUqXHF198USvnO3funCwWi1544YUa7d+rVy/dcssttVJLdTVr1kx33nnnRft98sknslgs+vrrr6t1/Llz52r58uU1LQ/AJfDydAEAas/mzZvdtp999ll9/vnnWr9+vVt7XFxcrZzP19dXmzdvVuvWrWu0/1tvvSWbzVYrtdSVhIQEbd68WZ07d67WfnPnzlXbtm1177331lFlACpCuAFMpFevXm7bERERslqtZdorUlBQIJvNVuXAYbFYqnzs8nTq1KnG+14uoaGhl/Qea5PdbpfdbpePj4+nSwHqNaalgCtUyXTLihUrNGnSJEVFRcnPz0+HDx9Wenq6Jk6cqI4dOyowMFCRkZG6+eaby4wMlTct9cYbb8hiseirr77Sgw8+qLCwMIWHh+uuu+7Sjz/+6Lb/hdNSe/bskcVi0bx58/Tiiy8qOjpaQUFB6tOnj7Zt21bmPSxYsEBt27aVr6+vrrnmGq1cuVKjR4/W1VdfXeXP4aOPPtK1114rf39/xcXFlZlKKm9aat++fbrrrrsUFRUlX19fNWvWTAMHDtT3338vyTnldeDAAa1du9Y1FVi6ptTUVN19992KiIiQr6+v4uLiNHfuXJW+j3HJZzFnzhzNmDFDMTEx8vHx0Zo1axQUFKTJkyeXeS979+6V1WrVvHnzqvz+ATNi5Aa4wj366KO64YYb9Je//EUOh0ONGzdWWlqavL299cwzzygyMlI5OTlauXKl+vbtq40bNyohIeGixx03bpxuu+02vfPOO0pNTdW0adM0fvx4rVmz5qL7zp49W9dcc43mzZsnu92uJ598UkOGDFFqaqoCAwMlOad9Jk+erNGjR2vu3Lk6deqUpk+frsLCQvn7+1fpvX/77bfau3evnnjiCYWHh2vhwoW677771L59e/Xo0aPcfQzD0C233CJfX1+9/PLLatWqlTIzM7Vx40adPn1akrRmzRrdfvvtatmypV599VVJctWUnp6uhIQEWSwWJSUlqWXLlvrwww81efJkHTx4ULNnz3Y738svv6y4uDjNnj1bQUFBiouL09ixY7VkyRI9//zzrs9Dkl5//XUFBgZq3LhxVXr/gGkZAExr3LhxRmBgYLmvffzxx4YkY9CgQRc9TlFRkVFYWGj06dPHuPvuu13tZ8+eNSQZSUlJrraFCxcakoypU6e6HWPmzJmGJOPUqVOutp49exqDBw92be/evduQZMTHxxsOh8PVvmHDBkOS8cEHHxiGYRgFBQVGWFiY0a9fP7dz/O9//zNsNpvRoUOHi76nyMhIIzAw0Dh27JirLTc31wgODjYmT57saiv5nDZv3mwYhmEcOXLEkGS88cYblR6/TZs2bu+txJQpUwyLxWKkpKS4tY8fP96wWq1Gamqq22fRsWNHo6ioyK3v7t27DYvFYixcuNDVlpOTY4SEhBi//vWvL/reAbNjWgq4wv3iF78o02YYhubNm6frrrtOfn5+8vLykre3t7766ivt3r27Sse97bbb3La7dOkiSUpLS7vovsOGDZPFYimz76FDhyRJ3333nU6e
PKmRI0e67demTRt17969SvVJUvfu3RUVFeXaDgwMVJs2bVznKU+zZs3UunVrPf/883rttde0Y8cOORyOKp9z/fr1uu6669S1a1e39sTERDkcjjK/ZBsxYkSZNVBXX321Bg4cqNdff93VtnTpUmVnZ+vhhx+uci2AWRFugCtc6S/3EklJSZo0aZL69u2rVatW6ZtvvtG3336rAQMG6OzZs1U6blhYmNu2r6+vJFVp/4vte/LkSUlSZGRkmX3La6vqeUrOVVmNNptNn3/+ufr376/nnntO1157rSIjIzV16lTl5eVd9JwnT54s9zNv3ry56/XSyusrSZMnT9Z3332nDRs2SHJOSfXv379BLNIG6hprboArXOkRkhLLly/XLbfcorlz57q1Z2VlXa6yKlUSSi5coCxJx48fr/PzX3XVVVq8eLEk58Lfd999V88++6wcDofmzJlT6b5hYWFKT08v037s2DFJUnh4uFt7ef98JGnIkCFq166d5s+fr6KiIu3atUszZ86swbsBzIeRGwBlWCwW12hJia1bt+o///mPhypy17lzZzVp0kQrVqxwaz9w4IC2bt16WWu5+uqrNWPGDLVv397t86loBOimm25SSkqK65dVJZYuXSqr1aobb7yxSue1WCz6zW9+ow8++EBPP/20WrZsqREjRlzSewHMgnADoIxhw4bpo48+0qxZs7R+/XrNnz9ft956q2JiYjxdmiTJ29tbTz/9tDZs2KC7775bH3/8sZYvX67BgwerefPmslrr7j9tW7ZsUf/+/fX6669r7dq1Wr9+vZ544gnt3btXAwcOdPW75pprtHXrVv3tb3/T1q1bXWHm8ccfV0REhAYPHqzk5GStXbtWDz/8sN566y1NmTJF0dHRVa4lMTFRAQEB+ve//62JEyfW+wsiApcL01IAypgxY4YKCgq0YMECPffcc+rcubMWLVqkpUuXKiUlxdPlSZImTZokm82m2bNna9WqVbrqqqv0zDPPaNmyZcrOzq6z87Zs2VKtW7fWvHnzdOTIEVmtVrVp00Zz587VQw895Or33HPPKTMzU+PHj1dubq46dOigPXv2KCoqSps3b9b06dP1+OOPKycnR23atNGcOXM0adKkatUSHBysoUOHatWqVXrwwQdr+60CDZbFMEpdNQoAGrCTJ0+qXbt2uvfee8usFzKjs2fPqnXr1hoyZIiWLl3q6XKAeoORGwANUlpammbPnq1+/fqpSZMmSk1N1SuvvKL8/Hz95je/8XR5dSojI0P79u3Tm2++qZ9++knTpk3zdElAvUK4AdAg+fn5af/+/XrnnXd06tQpBQUFqXfv3lq8eLHatWvn6fLq1KpVq/TrX/9aLVq00J///Odq39QTMDumpQAAgKl49NdSGzZs0PDhw9W8eXNZLBZ9+OGHF93nyy+/VLdu3eTn56errrpKb7zxxmWoFAAANBQeDTd5eXnq2rWr5s+fX6X+qampGjp0qPr27avt27fr97//vSZNmqT333+/jisFAAANRb2ZlrJYLPrggw8qvQjV7373O61evdrt3jYTJ07Ujh07tHnz5stRJgAAqOca1ILizZs3a9CgQW5tgwcP1ltvvaXCwkJ5e3uX2Sc/P1/5+fmubYfDoVOnTiksLKzCy5oDAID6xTAM5eTkVOlCnQ0q3Bw/frzMTfEiIyNVVFSkzMzMCm8A+Mwzz1yuEgEAQB06fPiwWrZsWWmfBhVupLI3kSuZVatoFGb69OmaOnWqazsrK0utW7fW4cOHFRISUneFAgCAWpOdna1WrVopODj4on0bVLhp1qxZmTv+ZmRkyMvLy3WX4Av5+vqWuQGgJIWEhBBuAABoYKqypKRB3TgzISFB69atc2v79NNPFR8fX+56GwAAcOXxaLjJzc1VSkqK60Z8qampSklJUVpamiTnlNLYsWNd/SdOnKhDhw5p6tSp2r17t5KTk/XWW2/pscce80j9AACg/vHotNTWrVvVv39/13bJ2phx
48Zp8eLFSk9PdwUdSYqNjdWaNWv029/+Vq+//rqaN2+uuXPn6he/+MVlrx0AANRP9eY6N5dLdna2QkNDlZWVxZobAA2S3W5XYWGhp8sAap2Pj0+FP/Ouzvd3g1pQDABXMsMwdPz4cZ0+fdrTpQB1wmq1KjY2Vj4+Ppd0HMINADQQJcGmadOmCggI4EKkMBWHw6Fjx44pPT1drVu3vqR/vwk3ANAA2O12V7Cp6NIXQEMXERGhY8eOqaio6JJ+Bd2gfgoOAFeqkjU2AQEBHq4EqDsl01F2u/2SjkO4AYAGhKkomFlt/ftNuAEAAKZCuAEANCgxMTGaM2eOp8tAPcaCYgBAnbrxxht17bXX1log+fbbbxUYGFgrx4I5EW4AAB5nGIbsdru8vC7+tRQREXEZKrq8qvP+cXFMSwEA6kxiYqK+/PJLvfbaa7JYLLJYLDp48KC++OILWSwWrV27VvHx8fL19dXGjRt14MAB3X777YqMjFRQUJC6d++uf/3rX27HvHBaymKx6C9/+YvuuOMOBQQEqF27dlq9enWldS1fvlzx8fEKDg5Ws2bNNGbMGGVkZLj1+f7773XrrbcqJCREwcHB6tu3rw4cOOB6PTk5WZ06dZKvr6+ioqL0yCOPSJIOHjwoi8Xium+iJJ0+fVoWi0VffPGFJF3S+8/Pz9e0adPUqlUr+fr6ql27dnrrrbdkGIbatm2rl19+2a3/d999J6vV6la72RFuAKCBMgxDZwqKPPKo6p17XnvtNSUkJOjBBx9Uenq60tPT1apVK9fr06ZNU1JSknbv3q0uXbooNzdXQ4cO1b/+9S9t375dgwcP1vDhw93uM1ieZ555RiNHjtTOnTs1dOhQ3XPPPTp16lSF/QsKCvTss89qx44d+vDDD5WamqrExETX60ePHtUNN9wgPz8/rV+/Xtu2bdOECRNUVFQkSVq4cKEefvhh/fKXv9R///tfrV69Wm3btq3SZ1JaTd7/2LFj9e6772ru3LnavXu33njjDQUFBclisWjChAlatGiR2zmSk5PVt29ftWnTptr1NVSMfwFAA3W20K64p9Z65Ny7Zg5WgM/Fv0JCQ0Pl4+OjgIAANWvWrMzrM2fO1MCBA13bYWFh6tq1q2t71qxZ+uCDD7R69WrXyEh5EhMTdffdd0uSnn/+ec2bN09btmzRLbfcUm7/CRMmuJ5fddVVmjt3rnr06KHc3FwFBQXp9ddfV2hoqN59913XxeTat2/vVtejjz6qyZMnu9q6d+9+sY+jjOq+/3379um9997TunXrdPPNN7vqLzF+/Hg99dRT2rJli3r06KHCwkItX75cf/zjH6tdW0PGyA0AwGPi4+PdtvPy8jRt2jTFxcWpUaNGCgoK0p49ey46ctOlSxfX88DAQAUHB5eZZipt+/btuv322xUdHa3g4GDdeOONkuQ6T0pKivr27VvuVXIzMjJ07Ngx3XTTTVV9mxWq7vtPSUmRzWZTv379yj1eVFSUbr31ViUnJ0uS/vGPf+jcuXO66667LrnWhoSRGwBooPy9bdo1c7DHzl0bLvzV0+OPP661a9fq5ZdfVtu2beXv768777xTBQUFlR7nwhBisVjkcDjK7ZuXl6dBgwZp0KBBWr58uSIiIpSWlqbBgwe7zuPv71/huSp7TZLrrtalp+4quot7dd//xc4tSQ888IDuu+8+vfrqq1q0aJFGjRp1xV3ZmnADAA2UxWKp0tSQp/n4+FT5cvobN25UYmKi7rjjDklSbm6uDh48WKv17NmzR5mZmXrhhRdc63+2bt3q1qdLly5asmSJCgsLywSn4OBgxcTE6LPPPlP//v3LHL/k11zp6em67rrrJMltcXFlLvb+r7nmGjkcDn355ZeuaakLDR06VIGBgVq4cKE+/vhjbdiwoUrnNhOmpQAAdSomJkbffPONDh48qMzMzApHVCSpbdu2WrVqlVJSUrRjxw6NGTOm0v410bp1a/n4+GjevHn64YcftHr1aj37
7LNufR555BFlZ2dr9OjR2rp1q/bv369ly5Zp7969kqQZM2bolVde0dy5c7V//3795z//0bx58yQ5R1d69eqlF154Qbt27dKGDRv0f//3f1Wq7WLvPyYmRuPGjdOECRNcC6G/+OILvffee64+NptNiYmJmj59utq2bauEhIRL/cgaHMINAKBOPfbYY7LZbIqLi3NNAVXk1VdfVePGjdW7d28NHz5cgwcP1vXXX1+r9URERGjx4sVauXKl4uLi9MILL5T5+XRYWJjWr1+v3Nxc9evXT926ddOf//xn1yjOuHHjNGfOHC1YsECdOnXSsGHDtH//ftf+ycnJKiwsVHx8vCZPnqxZs2ZVqbaqvP+FCxfqzjvv1EMPPaSrr75aDz74oPLy8tz63H///SooKHBbOH0lsRhV/T2fSWRnZys0NFRZWVkKCQnxdDkAUCXnzp1TamqqYmNj5efn5+lyUM999dVXuvHGG3XkyBFFRkZ6upwqq+zf8+p8f9f/yVoAAFAl+fn5Onz4sP7whz9o5MiRDSrY1CampQAAMIl33nlHHTp0UFZWll566SVPl+MxhBsAAEwiMTFRdrtd27ZtU4sWLTxdjscQbgAAgKkQbgAAgKkQbgAAgKkQbgAAgKkQbgAAgKkQbgAAgKkQbgAA9V5MTIzmzJnj2rZYLPrwww8r7H/w4EFZLJYq37Cyro+Dy4srFAMAGpz09HQ1bty4Vo+ZmJio06dPu4WmVq1aKT09XeHh4bV6LtQtwg0AoMFp1qzZZTmPzWa7bOeqbwoLC103Cm1omJYCANSZN998Uy1atJDD4XBrv+222zRu3DhJ0oEDB3T77bcrMjJSQUFB6t69u/71r39VetwLp6W2bNmi6667Tn5+foqPj9f27dvd+tvtdt1///2KjY2Vv7+/OnTooNdee831+owZM7RkyRL9/e9/l8VikcVi0RdffFHutNSXX36pHj16yNfXV1FRUXriiSdUVFTkev3GG2/UpEmTNG3aNDVp0kTNmjXTjBkzKn0/3377rQYOHKjw8HCFhoaqX79++s9//uPW5/Tp0/rlL3+pyMhI+fn5qXPnzvrHP/7hev2rr75Sv379FBAQoMaNG2vw4MH66aefJJWd1pOka6+91q0ui8WiN954Q7fffrsCAwM1a9asi35uJZKTk9WpUyfXZ/LII49IkiZMmKBhw4a59S0qKlKzZs2UnJxc6WdyKRi5AYCGyjCkwjOeObd3gGSxXLTbXXfdpUmTJunzzz/XTTfdJEn66aeftHbtWn300UeSpNzcXA0dOlSzZs2Sn5+flixZouHDh2vv3r1q3br1Rc+Rl5enYcOGacCAAVq+fLlSU1M1efJktz4Oh0MtW7bUe++9p/DwcG3atEm//OUvFRUVpZEjR+qxxx7T7t27lZ2drUWLFkmSmjRpomPHjrkd5+jRoxo6dKgSExO1dOlS7dmzRw8++KD8/PzcgsKSJUs0depUffPNN9q8ebMSExPVp08fDRw4sNz3kJOTo3Hjxmnu3LmSpFdeeUVDhw7V/v37FRwcLIfDoSFDhignJ0fLly9XmzZttGvXLtlsNklSSkqKbrrpJk2YMEFz586Vl5eXPv/8c9nt9ot+fqU9/fTTSkpK0quvviqbzXbRz02SFi5cqKlTp+qFF17QkCFDlJWVpa+++kqS9MADD+iGG25Qenq6oqKiJElr1qxRbm6ua/+6QLgBgIaq8Iz0fHPPnPv3xySfwIt2a9KkiW655Rb99a9/dYWblStXqkmTJq7trl27qmvXrq59Zs2apQ8++ECrV692jQBU5u2335bdbldycrICAgLUqVMnHTlyRL/+9a9dfby9vfXMM8+4tmNjY7Vp0ya99957GjlypIKCguTv76/8/PxKp6EWLFigVq1aaf78+bJYLLr66qt17Ngx/e53v9NTTz0lq9U5IdKlSxc9/fTTkqR27dpp/vz5+uyzzyoMNwMGDHDbfvPNN9W4cWN9+eWXGjZsmP71
r39py5Yt2r17t9q3by9Juuqqq1z9X3rpJcXHx2vBggWutk6dOl30s7vQmDFjNGHCBLe2yj43yfnP69FHH3ULlN27d5ck9e7dWx06dNCyZcs0bdo0SdKiRYt01113KSgoqNr1VRXTUgCAOnXPPffo/fffV35+viRnGBk9erRr1CEvL0/Tpk1TXFycGjVqpKCgIO3Zs0dpaWlVOv7u3bvVtWtXBQQEuNoSEhLK9HvjjTcUHx+viIgIBQUF6c9//nOVz1H6XAkJCbKUGrXq06ePcnNzdeTIEVdbly5d3PaLiopSRkZGhcfNyMjQxIkT1b59e4WGhio0NFS5ubmu+lJSUtSyZUtXsLlQycjNpYqPjy/TVtnnlpGRoWPHjlV67gceeMA1GpaRkaF//vOfZQJUbWPkBgAaKu8A5wiKp85dRcOHD5fD4dA///lPde/eXRs3btTs2bNdrz/++ONau3atXn75ZbVt21b+/v668847VVBQUKXjG4Zx0T7vvfeefvvb3+qVV15RQkKCgoOD9cc//lHffPNNld9HybksF0zHlZy/dPuFC3EtFkuZdUelJSYm6sSJE5ozZ46io6Pl6+urhIQE12fg7+9faV0Xe91qtZb5nAoLC8v0Cwx0H4272Od2sfNK0tixY/XEE09o8+bN2rx5s2JiYtS3b9+L7ncpCDcA0FBZLFWaGvI0f39//fznP9fbb7+t//3vf2rfvr26devmen3jxo1KTEzUHXfcIcm5BufgwYNVPn5cXJyWLVums2fPur5sv/76a7c+GzduVO/evfXQQw+52g4cOODWx8fH56JrVOLi4vT++++7hZxNmzYpODhYLVq0qHLNF9q4caMWLFigoUOHSpIOHz6szMxM1+tdunTRkSNHtG/fvnJHb7p06aLPPvvMbQqptIiICKWnp7u2s7OzlZqaWqW6KvvcgoODFRMTo88++0z9+/cv9xhhYWEaMWKEFi1apM2bN2v8+PEXPe+lYloKAFDn7rnnHv3zn/9UcnKy7r33XrfX2rZtq1WrViklJUU7duzQmDFjKh3luNCYMWNktVp1//33a9euXVqzZo1efvnlMufYunWr1q5dq3379ukPf/iDvv32W7c+MTEx2rlzp/bu3avMzMxyRzYeeughHT58WL/5zW+0Z88e/f3vf9fTTz+tqVOnutbb1ETbtm21bNky7d69W998843uuecet1GRfv366YYbbtAvfvELrVu3Tqmpqfr444/1ySefSJKmT5+ub7/9Vg899JB27typPXv2aOHCha6ANGDAAC1btkwbN27Ud999p3HjxrmmBS9W18U+txkzZuiVV17R3LlztX//fv3nP//RvHnz3Po88MADWrJkiXbv3u36lVxdItwAAOrcgAED1KRJE+3du1djxoxxe+3VV19V48aN1bt3bw0fPlyDBw/W9ddfX+VjBwUF6aOPPtKuXbt03XXX6cknn9SLL77o1mfixIn6+c9/rlGjRqlnz546efKk22iEJD344IPq0KGDa31JyS9+SmvRooXWrFmjLVu2qGvXrpo4caLuv/9+/d///V81Po2ykpOT9dNPP+m6667Tfffdp0mTJqlp06Zufd5//311795dd999t+Li4jRt2jTXSFP79u316aefaseOHerRo4cSEhL097//XV5ezgma6dOn64YbbtCwYcM0dOhQjRgxQm3atLloXVX53MaNG6c5c+ZowYIF6tSpk4YNG6b9+/e79bn55psVFRWlwYMHq3nzul8EbzGqMllpItnZ2QoNDVVWVpZCQkI8XQ4AVMm5c+eUmpqq2NhY+fn5ebocoFrOnDmj5s2bKzk5WT//+c8r7FfZv+fV+f5mzQ0AAKgTDodDx48f1yuvvKLQ0FDddtttl+W8hBsAAFAn0tLSFBsbq5YtW2rx4sWuabK6RrgBAAB1IiYmpko/1a9tLCgGAACmQrgBgAbkCvsNCK4wtfXvN+EGABqAkivenjnjoRtlApdByRWZq3IN
nsqw5gYAGgCbzaZGjRq57k8UEBBQ5jYAQEPmcDh04sQJBQQEXPLCY8INADQQJXerruwGjEBDZrVa1bp160sO7oQbAGggLBaLoqKi1LRp03JvDQA0dD4+Ppd0G4sShBsAaGBsNtslr0kAzIwFxQAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQINwAAwFQ8Hm4WLFig2NhY+fn5qVu3btq4cWOl/d9++2117dpVAQEBioqK0vjx43Xy5MnLVC0AAKjvPBpuVqxYoSlTpujJJ5/U9u3b1bdvXw0ZMkRpaWnl9v/3v/+tsWPH6v7779f333+vlStX6ttvv9UDDzxwmSsHAAD1lUfDzezZs3X//ffrgQceUMeOHTVnzhy1atVKCxcuLLf/119/rZiYGE2aNEmxsbH62c9+pl/96lfaunXrZa4cAADUVx4LNwUFBdq2bZsGDRrk1j5o0CBt2rSp3H169+6tI0eOaM2aNTIMQz/++KP+9re/6dZbb63wPPn5+crOznZ7AAAA8/JYuMnMzJTdbldkZKRbe2RkpI4fP17uPr1799bbb7+tUaNGycfHR82aNVOjRo00b968Cs+TlJSk0NBQ16NVq1a1+j4AAED94vEFxRaLxW3bMIwybSV27dqlSZMm6amnntK2bdv0ySefKDU1VRMnTqzw+NOnT1dWVpbrcfjw4VqtHwAA1C9enjpxeHi4bDZbmVGajIyMMqM5JZKSktSnTx89/vjjkqQuXbooMDBQffv21axZsxQVFVVmH19fX/n6+tb+GwAAoJ4wDEOFdkOFdoeK7IYK7I6LPi+0O1RwwfPKXit57nxU/ryRv7cWje/hsc/DY+HGx8dH3bp107p163THHXe42tetW6fbb7+93H3OnDkjLy/3km02myTnP1gAQOUcDkP5RQ6dLbTrXKHd9df5cOhsgV3niuzFfx06V1C6n3O//FL7lbSfK7VtdxiyWizFD+cIvdUq2YrbLBbJarHIZrU4Xyvetlokq/WC/SySzVqyX+m+pZ5bnX1tpdpK71vZOUr62qwqcw7nvnI/V6ljlD3n+b9FdkOFDq2+LVoAACAASURBVIcKi4q/8B0OFRYVh4BSz4scDhVU8LywyBlGiirYt3SYKHLUr+/AiGDPDip4LNxI0tSpU3XfffcpPj5eCQkJ+tOf/qS0tDTXNNP06dN19OhRLV26VJI0fPhwPfjgg1q4cKEGDx6s9PR0TZkyRT169FDz5s09+VYAoMZKAse5i4QGt3DhCiEOnSuyO0NISSgpdJQNLcXP84scnn67uEx8bFZ52yzy9rLKy2qVj+u5Rd42q3wqeO7tZZV3FZ57WS3y8bLK21b2eYCPR+OFZ8PNqFGjdPLkSc2cOVPp6enq3Lmz1qxZo+joaElSenq62zVvEhMTlZOTo/nz5+vRRx9Vo0aNNGDAAL344oueegsA6jnDMGR3GCpyOP8frt3hHL4vcpwfZi/dVmg3VFTS5nA+L3nN1c/ucL1WXlvpkZHyRkTyi8pue4KPl1V+Xlb5+9jk522Tv7dNvt42+XtbXdt+rofVte1fvO3ntm2Tv49Vvl42edkscjgkh2HIMJx/7YYhwzDkMJxhzlHc7jDOP3f+szr/3GFIdofhdpzz+1+4X+nt4nM6HDIMQ4bDIcNhyGE4ZBiO4r+GHA5DMhxyOBxyGIZU3NdhOPsbxf2ctTicxynuI0OuYzkckgyH81iGIUOGvKxWedksxX9t8rJJXjabvK1W2WzOgGCz2eRttcir+LmP7fzzkhBhszkDQ+nnXjarvG3Oz/n8tlVeXs7je9usrhErudawFv8tvV3ua+X0q8prFayV9RSLcYXN52RnZys0NFRZWVkKCQnxdDlAg2cYhk6fKVRmbr5O5OQr+1zhBUHhfJAo03ZBeHC2uwcFV1up4OHsUyqgOByy2w23
fQpLhZFqvBv5K19BOqdAy1kFKl+BOqtAyzkFFf8NVPHDcrbcfgHKl1U1DyuW4mkNi85Py1gslrLbFucvQiwXTIdYVNLn/H6u10r21/nX6oZR/GVvFD8v+Stnu1ubUUn/8vpdrH/x8eFZwVHSo3tq9ZDV+f727LgRgHrJMAzl5BfpRE6+MnPydSL3/N8TOfnKzC0o/ut8VC9A1GqlrjASYDmnxjrnFkYCLPkKsp11CyMBlnMK0jkFWc4qqPh5QHFo8dc52S4hmNS64u9s1BMWq86PeFzw98LXJJ0PXsXPJfftcl8rp9+FrzUEHh43IdwAl6IgT8pOl7KPStnHpJxjkr1Islolq5dksTn/Wm3Oh9u2l/M/iG7btvN93fa/xOMVDxnn5Re5RlhK/p7IydeJUmGl5G91p0pC/b0VEeyrUH9v5zx/8dy7zeqc9/eyWeVtkfyt+Qow8hVoOasAw/nw1xn5Os7J3zgrP8cZ+TnOyNdxRt6Os/K1n5GP/Yy87WfkXZQnL/sZeRXlyasoT7aiM7IYdRFGLJJPkOQb5PzrEyj5Bl+8zTdY8g5wfuZXuvKCQHmhQMXB4KKBQdXsX4VzV9jfWs6+9ZBRk/BU2WtVCFYXhpYK9/PsZ0a4ASpyLtsZWEqCS3nPz532dJVVYpdVdsMqq6xqIpsayaJY2eSQVXZZVSSbHIbF+VdWFVltcvhYZVhsstpssti8ZbPZZPPylpeXl7y8veXj5S1vbx/5+HjLx8dbNquXM1AZDqkgV8rPcf4tyJPyc4uf5xZPG9QBn6CyQcMVQsprKyeY+AQ6n3sHOAMlUJ/V5+DlYYQbXHkMQzr7U6mgcmF4SXc+L8ip2vF8gqSQ5s5HcHPJy1dyFDm/xB1FksNevG0vfl7ZdpHkcMjhKJTDbpe9qEgOe6EcDrsMe+ljOWQ1imQxHLLJLpscslkqHgZ2vn6RUFHRfyMdxY/Cqn0cVeYWRgIln+BSQSOoCmGlVH/CCIBSCDcwF4dDOpNZ+WhL9jGp6FzVjufXSAppcT68lPfcr2oL04vsDp3KK3CtW7lw7cqJnHxl5jnXtZw+U70k4WOzKiLIR02DvRUZaFPTYC9FBHgpItBLEQE2hQV6KczfqiYBXgrwMpxTORcNXs6g5b5d3Mdt+4IgZ7GUE0xKh5FAyTuQMAKgzhBu0HDYi6TcH8+HlJz08kddHFUMBgHhFQeWkBZSSJTzi7gKss4W6mBmng6ezNOP2efKhJYTOfk6daagWmvsvKwWhQf5KjzYRxFBvgoP8lVEsPNR8rzkb4ifV4W3LQGAKw3hBvVDUUFxWKlkjUvu8Squ17BIwc2Kp4miLggszc+3e/tVq8QzBUVKzczTwcwzOngyTz+ccIaZg5l5OplXUKVjWC1Sk8CSYOLjCisRQWVDSyN/b1mtBBYAqC7CDepe4dlSYaWC8JKXUbVjWb2KA0sl00RBkZLNu0alniu0K+3UmeIQk6fU4odzRCa/0n0jgn0VGxaoqEZ+5YaViGBfNQn0kY3AAgB1inCDS2cY0ulDUvpO6cSesuHl7E9VO47Np5JpouLngRGX/DPbQrtDh085R19SM88oNTNXBzOdgeZY1tlKp44aB3grJjxQseGBig0LdD2PCQ9UkC//cwKA+oD/GqN67EVS5l5nkDn+X+n4TufjXFbl+3kHlD89VDq4BITV2s8a7Q5Dx06fdY26lJ5COvzTWdkruclcsK+XYsLPB5fY8ADFhDmfNwrwqZX6AAB1h3CDihWckX78Xjq+ozjM7JR+3CXZy5mesXpLTTtKkZ2lRq3Lhhi/0Fq/HoNhGDqefc61DiY1M1epxeth0k6eUYG94vU5/t42RYcFFIeX0kEmUGGBPizOBYAGjHADpzOnnOGlJMSk75RO7i9/Aa9PkNTsGqlZFymqi/NvxNWSV+2PahiGoczcguIppDy3tTCHTp7R2UJ7hfv62KxqHeYcdbkqIlAxYYGKCQ/QVeFBigzxJcAAgEkRbq40
huFcB1M6xBzfKWUdLr9/YNPzAabkb+PYWr9GyekzBa4ppNQTeUo9ecb50+rMPOXkF1W4n81qUavG/m4jLyVTSM0b+bN4FwCuQIQbM3PYpZP/Kw4wO86vkzl7qvz+jWNKhZiuzr/BzWqtnNz8Iteoi+uXSMXrYH6q5KJ1FovUPNT/gimkAMWGB6llY39527gYHADgPMKNWRSekzJ2uY/G/Pi9VHimbF+LzTmN5DYic41zXcwlOldody3cdfsl0sk8ncip/KfUkSG+F0whOcNM6yYB8vPmRoQAgKoh3DRE57KcIzClp5Yy9zovf38h7wDnIt+SANOsi9Q0rtoXsKuM3WHoi70ZWrL5kDbuP1HpT6nDAn3KnUKKCQ9QgA//OgIALh3fJvWZYUg5x91HY47vlH46WH5//yalRmO6Ov+Gtbnk68JU5PSZAq3cekTLvj6ktFPnR4hC/LzK/AqpZCQm1L9mF9cDAKCqCDf1hcMh/ZQqpe9wDzN5J8rvH9rKfZFvVBfnT68vwy+Adh3L1tLNB/VhylGdK3T+mirU31ujurfSmB6tFR0WwC+RAAAeQ7jxhKIC55V83UZkvpMKcsr2tVilsHZlf7EU0OSyllxod2jt98e1dNMhbTl4fkFyx6gQjUuI1u3XtpC/D+tiAACeR7ipa/k5zoW9pX+xlLG7/DtX23ylyE7uU0tN4ySfgMtfd7GMnHN6d8thvf3NIde9lbysFt3SuZnG9Y5RfHRjRmkAAPUK4aY25Z5wv5pv+k7p1A+Syllh6xtadjQmvL1k8/w/EsMwtP3waS3ddFD//G+6Cu3O+sODfDWmZ2uN6dFazUJrb0EyAAC1yfPfpGZx/DvpjT7lvxYcVXZ9TKPoy7I+pjrOFdr1j53pWrr5oHYeOX+vqOtaN1Ji7xgN6RwlHy+uKQMAqN8IN7UlvL1zWqlRK/drxzTrKgVFeLq6Sh09fVZvf31I7357WKfyCiRJPl5W3da1ucYmRKtLy0YerhAAgKoj3NQWLx/piUOSt7+nK6kSwzC0+YeTWrLpoNbt+lElN8luHuqnexOiNSq+lcKCfD1bJAAANUC4qU0NINjk5Rdp1fajWrrpoPZn5Lrae7cJ07jeMbrp6qby4nYGAIAGjHBzhfjhRK6WfX1If9t6xHUjygAfm35xfUuNTYhWu8hgD1cIAEDtINyYmMNh6It9GVqy6ZC+3Hf+YoCx4YEamxCtX3RrqRA/rhgMADAXwo0JZZ0p1Mpth7V08/nbIlgs0oAOTTW2d4z6tg2X1Vq/fqkFAEBtIdyYyO50520RPth+/rYIIX5eGtW9le7rFaPWYZ67GCAAAJcL4aaBK7Q79On3P2rJ5oPaknr+tghXNwvWuN4xGsFtEQAAVxjCTQOVmZuvd75J09vfpOl49jlJks1q0S2dnLdF6B7DbREAAFcmwk0Dsz3tJy3dfEj/3JmuArtz6ik8yEdjerTWmJ7R3BYBAHDFI9w0AOcK7fpn8W0RdlxwW4RxCTEack0z+Xox9QQAgES4qdeOnT6r5eXcFmF4l+Ya15vbIgAAUB7CTT1TcluEpZsO6dNdx91ui3BPr2iN7s5tEQAAqAzhpp7Iyy/SB9uPaunmg9r34/nbIiRc5bwtws0duS0CAABVQbjxsNTMPC3bfEgrtx1Wzrnzt0X4+fUtNDYhRu25LQIAANVCuPEAh8PQl/tOaPGmg263RYgJC9DYhBj9oltLhfpzWwQAAGqCcHMZZZ0t1Mqth7Xs60M6dPL8bRH6d2iqsQnRuqFdBLdFAADgEhFuLoM9x7O1ZNMhfbj9qM4W2iU5b4swMr6V7u0VrZjwQA9XCACAeRBu6kiR3aFPd/2oJZsO6psLboswNiFGI65rrgAfPn4AAGob3661LDM3X+9ucd4WIT3r/G0RBneK1LiEGPWIbcJtEQAAqEOEm1py+NQZvbpun/5xwW0R7u7RWmN6tlZUqL+HKwQA4MpAuKklFov0YcpROQzp2laNNK53
tIZeE8VtEQAAuMwIN7WkZeMA/WFYnK5v3VhdW3FbBAAAPIVwU4vG94n1dAkAAFzxuJ4/AAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFY+HmwULFig2NlZ+fn7q1q2bNm7cWGn//Px8Pfnkk4qOjpavr6/atGmj5OTky1QtAACo77w8efIVK1ZoypQpWrBggfr06aM333xTQ4YM0a5du9S6dety9xk5cqR+/PFHvfXWW2rbtq0yMjJUVFR0mSsHAAD1lcUwDMNTJ+/Zs6euv/56LVy40NXWsWNHjRgxQklJSWX6f/LJJxo9erR++OEHNWnSpEbnzM7OVmhoqLKyshQSElLj2gEAwOVTne9vj01LFRQUaNu2bRo0aJBb+6BBg7Rp06Zy91m9erXi4+P10ksvqUWLFmrfvr0ee+wxnT17tsLz5OfnKzs72+0BAADMy2PTUpmZmbLb7YqMjHRrj4yM1PHjx8vd54cfftC///1v+fn56YMPPlBmZqYeeughnTp1qsJ1N0lJSXrmmWdqvX4AAFA/eXxBscVicds2DKNMWwmHwyGLxaK3335bPXr00NChQzV79mwtXry4wtGb6dOnKysry/U4fPhwrb8HAABQf3hs5CY8PFw2m63MKE1GRkaZ0ZwSUVFRatGihUJDQ11tHTt2lGEYOnLkiNq1a1dmH19fX/n6+tZu8QAAoN7y2MiNj4+PunXrpnXr1rm1r1u3Tr179y53nz59+ujYsWPKzc11te3bt09Wq1UtW7as03oBAEDD4NFpqalTp+ovf/mLkpOTtXv3bv32t79VWlqaJk6cKMk5pTR27FhX/zFjxigsLEzjx4/Xrl27tGHDBj3++OOaMGGC/P39PfU2AABAPeLR69yMGjVKJ0+e1MyZM5Wenq7OnTtrzZo1io6OliSlp6crLS3N1T8oKEjr1q3Tb37zG8XHxyssLEwjR47UrFmzPPUWAABAPePR69x4Ate5AQCg4anT69ykpqbWuDAAAIC6Vu1w07ZtW/Xv31/Lly/XuXPn6qImAACAGqt2uNmxY4euu+46Pfroo2rWrJl+9atfacuWLXVRGwAAQLVVO9x07txZs2fP1tGjR7Vo0SIdP35cP/vZz9SpUyfNnj1bJ06cqIs6AQAAqqTGPwX38vLSHXfcoffee08vvviiDhw4oMcee0wtW7bU2LFjlZ6eXpt1AgAAVEmNw83WrVv10EMPKSoqSrNnz9Zjjz2mAwcOaP369Tp69Khuv/322qwTAACgSqp9nZvZs2dr0aJF2rt3r4YOHaqlS5dq6NChslqdOSk2NlZvvvmmrr766lovFgAA4GKqHW4WLlyoCRMmaPz48WrWrFm5fVq3bq233nrrkosDAACoLi7iBwAA6r06vYjfokWLtHLlyjLtK1eu1JIlS6p7OAAAgFpV7XDzwgsvKDw8vEx706ZN9fzzz9dKUQAAADVV7XBz6NAhxcbGlmmPjo52u8klAACAJ1Q73DRt2lQ7d+4s075jxw6FhYXVSlEAAAA1Ve1wM3r0aE2aNEmff/657Ha77Ha71q9fr8mTJ2v06NF1USMAAECVVfun4LNmzdKhQ4d00003ycvLubvD4dDYsWNZcwMAADyuxj8F37dvn3bs2CF/f39dc801io6Oru3a6gQ/BQcAoOGpzvd3tUduSrRv317t27ev6e4AAAB1okbh5siRI1q9erXS0tJUUFDg9trs2bNrpTAAAICaqHa4+eyzz3TbbbcpNjZWe/fuVefOnXXw4EEZhqHrr7++LmoEAACosmr/Wmr69Ol69NFH9d1338nPz0/vv/++Dh8+rH79+umuu+6qixoBAACqrNrhZvfu3Ro3bpwkycvLS2fPnlVQUJBmzpypF198sdYLBAAAqI5qh5vAwEDl5+dL
kpo3b64DBw64XsvMzKy9ygAAAGqg2mtuevXqpa+++kpxcXG69dZb9eijj+q///2vVq1apV69etVFjQAAAFVW7XAze/Zs5ebmSpJmzJih3NxcrVixQm3bttWrr75a6wUCAABUR7XCjd1u1+HDh9WlSxdJUkBAgBYsWFAnhQEAANREtdbc2Gw2DR48WKdPn66regAAAC5JtRcUX3PNNfrhhx/qohYAAIBLVu1w89xzz+mxxx7TP/7xD6Wnpys7O9vtAQAA4EnVvnGm1Xo+D1ksFtdzwzBksVhkt9trr7o6wI0zAQBoeOr0xpmff/55jQsDAACoa9UON/369auLOgAAAGpFtcPNhg0bKn39hhtuqHExAAAAl6ra4ebGG28s01Z67U19X3MDAADMrdq/lvrpp5/cHhkZGfrkk0/UvXt3ffrpp3VRIwAAQJVVe+QmNDS0TNvAgQPl6+ur3/72t9q2bVutFAYAAFAT1R65qUhERIT27t1bW4cDAACokWqP3OzcudNt2zAMpaen64UXXlDXrl1rrTAAAICaqHa4ufbaa2WxWHThtf969eql5OTkWisMAACgJqodblJTU922rVarIiIi5OfnV2tFAQAA1FS1w010dHRd1AEAAFArqr2geNKkSZo7d26Z9vnz52vKlCm1UhQAAEBNVTvcvP/+++rTp0+Z9t69e+tvf/tbrRQFAABQU9UONydPniz3WjchISHKzMyslaIAAABqqtrhpm3btvrkk0/KtH/88ce66qqraqUoAACAmqr2guKpU6fqkUce0YkTJzRgwABJ0meffaZXXnlFc+bMqfUCAQAAqqPa4WbChAnKz8/Xc889p2effVaSFBMTo4ULF2rs2LG1XiAAAEB1WIwLr8ZXDSdOnJC/v7+CgoJqs6Y6lZ2drdDQUGVlZSkkJMTT5QAAgCqozvd3jS7iV1RUpHbt2ikiIsLVvn//fnl7eysmJqbaBQMAANSWai8oTkxM1KZNm8q0f/PNN0pMTKyNmgAAAGqs2uFm+/bt5V7nplevXkpJSamVogAAAGqq2uHGYrEoJyenTHtWVpbsdnutFAUAAFBT1Q43ffv2VVJSkluQsdvtSkpK0s9+9rNaLQ4AAKC6qr2g+KWXXtINN9ygDh06qG/fvpKkjRs3KisrS59//nmtFwgAAFAd1R65iYuL086dOzVy5EhlZGQoJydHY8eO1b59+1RUVFQXNQIAAFTZJV3nRpJOnz6tt99+W8nJyUpJSan36264zg0AAA1Pdb6/qz1yU2L9+vW699571bx5c82fP19DhgzR1q1ba3o4AACAWlGtNTdHjhzR4sWLlZycrLy8PI0cOVKFhYV6//33FRcXV1c1AgAAVFmVR26GDh2quLg47dq1S/PmzdOxY8c0b968uqwNAACg2qo8cvPpp59q0qRJ+vWvf6127drVZU0AAAA1VuWRm40bNyonJ0fx8fHq2bOn5s+frxMnTtRlbQAAANVW5XCTkJCgP//5z0pPT9evfvUrvfvuu2rRooUcDofWrVtX7lWLAQAALrdL+in43r179dZbb2nZsmU6ffq0Bg4cqNWrV9dmfbWOn4IDANDwXJafgktShw4d9NJLL+nIkSN65513LuVQAAAAteKSwk0Jm82mESNG1GjUZsGCBYqNjZWfn5+6deumjRs3Vmm/r776Sl5eXrr22murfU4AAGBetRJuamrFihWaMmWKnnzySW3fvl19+/bVkCFDlJaWVul+WVlZGjt2rG666abLVCkAAGgoLvn2C5eiZ8+euv7667Vw4UJXW8eOHTVixAglJSVVuN/o0aPVrl072Ww2ffjhh0pJSanyOVlzAwBAw3PZ1txcioKCAm3btk2DBg1yax80aJA2bdpU4X6LFi3SgQMH9PTTT1fpPPn5+crOznZ7AAAA8/JYuMnMzJTdbldkZKRbe2RkpI4fP17uPvv379cTTzyht99+W15eVbv+YFJSkkJDQ12PVq1aXXLtAACg/vLo
mhtJslgsbtuGYZRpkyS73a4xY8bomWeeUfv27at8/OnTpysrK8v1OHz48CXXDAAA6q9q3TizNoWHh8tms5UZpcnIyCgzmiNJOTk52rp1q7Zv365HHnlEkuRwOGQYhry8vPTpp59qwIABZfbz9fWVr69v3bwJAABQ73hs5MbHx0fdunXTunXr3NrXrVun3r17l+kfEhKi//73v0pJSXE9Jk6cqA4dOiglJUU9e/a8XKUDAIB6zGMjN5I0depU3XfffYqPj1dCQoL+9Kc/KS0tTRMnTpTknFI6evSoli5dKqvVqs6dO7vt37RpU/n5+ZVpBwAAVy6PhptRo0bp5MmTmjlzptLT09W5c2etWbNG0dHRkqT09PSLXvMGAACgNI9e58YTuM4NAAANT4O4zg0AAEBdINwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABT8Xi4WbBggWJjY+Xn56du3bpp48aNFfZdtWqVBg4cqIiICIWEhCghIUFr1669jNUCAID6zqPhZsWKFZoyZYqefPJJbd++XX379tWQIUOUlpZWbv8NGzZo4MCBWrNmjbZt26b+/ftr+PDh2r59+2WuHAAA1FcWwzAMT528Z8+euv7667Vw4UJXW8eOHTVixAglJSVV6RidOnXSqFGj9NRTT1Wpf3Z2tkJDQ5WVlaWQkJAa1Q0AAC6v6nx/e2zkpqCgQNu2bdOgQYPc2gcNGqRNmzZV6RgOh0M5OTlq0qRJhX3y8/OVnZ3t9gAAAOblsXCTmZkpu92uyMhIt/bIyEgdP368Ssd45ZVXlJeXp5EjR1bYJykpSaGhoa5Hq1atLqluAABQv3l8QbHFYnHbNgyjTFt53nnnHc2YMUMrVqxQ06ZNK+w3ffp0ZWVluR6HDx++5JoBAED95eWpE4eHh8tms5UZpcnIyCgzmnOhFStW6P7779fKlSt18803V9rX19dXvr6+l1wvAABoGDw2cuPj46Nu3bpp3bp1bu3r1q1T7969K9zvnXfeUWJiov7617/q1ltvresyAQBAA+OxkRtJmjp1qu677z7Fx8crISFBf/rTn5SWlqaJEydKck4pHT16VEuXLpXkDDZjx47Va6+9pl69erlGffz9/RUaGuqx9wEAAOoPj4abUaNG6eTJk5o5c6bS09PVuXNnrVmzRtHR0ZKk9PR0t2vevPnmmyoqKtLDDz+shx9+2NU+btw4LV68+HKXDwAA6iGPXufGE7jODQAADU+DuM4NAABAXSDcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHc
AAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAnJF+cgAADSlJREFUUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAU/F4uFmwYIFiY2Pl5+enbt26aePGjZX2//LLL9WtWzf5+fnpqquu0htvvHGZKgUAAA2BR8PNihUrNGXKFD355JPavn27+vbtqyFDhigtLa3c/qmpqRo6dKj69u2r7du36/e//70mTZqk999//zJXDgAA6iuLYRiGp07es2dPXX/99Vq4cKGrrWPHjhoxYoSSkpLK9P/d736n1atXa/fu3a62iRMnaseOHdq8eXOVzpmdna3Q0FBlZWUpJCTk0t8EAACoc9X5/vbYyE1BQYG2bdumQYMGubUPGjRImzZtKnefzZs3l+k/ePBgbd26VYWFhXVWKwAAaDi8PHXizMxM2e12RUZGurVHRkbq+PHj5e5z/PjxcvsXFRUpMzNTUVFRZfbJz89Xfn6+azsrK0uSMwECAICGoeR7uyoTTh4LNyUsFovbtmEYZdou1r+89hJJSUl65plnyrS3atWquqUCAAAPy8nJUWhoaKV9PBZuwsPDZbPZyozSZGRklBmdKdGsWbNy+3t5eSksLKzcfaZPn/7/7d1vbFNlGwbw67BuZatV98exNsBoYDLH2MQVSbcZIjPLKiGAE9AMLCGG1HVzrFmiEc0mMcwvaiRik8I2JWJmFt0ckbEVxakYwhhWmlknBgNEt1RAZJtxxPV5PxCbt2958Q/beUZ7/ZKTtOf0rNd91g93nvOcc+B0OkPvg8EgLl26hNTU1Bs2Uf/GlStXMGfOHJw/fz4m5/PEev0Aj0Gs1w/wGLD+2K4fmLpjIITAyMgIjEbjX35WWnOTkJCAgoICeDwerF27NrTe4/Fg9erV193HYrHgwIEDYet6enpgNpsRHx9/3X20Wi20Wm3YujvvvPMm09/Y7bffHrM/aoD1AzwGsV4/wGPA+mO7fmBqjsFfjdj8Seql4E6nE3v37kVzczP8fj9qa2tx7tw52O12ANdGXZ544onQ5+12O86ePQun0wm/34/m5mY0NTWhrq5OVglEREQ0zUidc7NhwwZcvHgRO3bswNDQEHJzc3Hw4EFkZmYCAIaGhsLueWMymXDw4EHU1tZi9+7dMBqN2LVrF8rLy2WVQERERNOM9AnFlZWVqKysvO62t956K2Ld8uXLcfLkySlO9e9otVrU19dHnAaLFbFeP8BjEOv1AzwGrD+26wemxzGQehM/IiIioskm/dlSRERERJOJzQ0RERFFFTY3REREFFXY3BAREVFUYXMzSd58802YTCbMnDkTBQUF+Pzzz2VHUs1nn32GVatWwWg0QlEUdHR0yI6kqsbGRixduhR6vR7p6elYs2YNBgcHZcdSlcvlQl5eXuimXRaLBV1dXbJjSdPY2AhFUbBt2zbZUVTT0NAARVHCloyMDNmxVPXjjz9i48aNSE1NRVJSEu6991709/fLjqWaefPmRfwGFEWBw+FQPQubm0nw3nvvYdu2bdi+fTu++uorPPDAA7BarWH36IlmY2NjyM/PxxtvvCE7ihS9vb1wOBw4duwYPB4P/vjjD5SWlmJsbEx2NNXMnj0bL7/8Mk6cOIETJ05gxYoVWL16NQYGBmRHU11fXx/cbjfy8vJkR1HdokWLMDQ0FFp8Pp/sSKr55ZdfUFRUhPj4eHR1deGbb77BK6+8MuV3xJ9O+vr6wv7/Ho8HALBu3Tr1wwi6affff7+w2+1h67Kzs8Wzzz4rKZE8AER7e7vsGFIFAgEBQPT29sqOIlVycrLYu3ev7BiqGhkZEVlZWcLj8Yjly5eLmpoa2ZFUU19fL/Lz82XHkOaZZ54RxcXFsmNMKzU1NWL+
/PkiGAyq/t0cublJV69eRX9/P0pLS8PWl5aW4ssvv5SUimT69ddfAQApKSmSk8gxMTGB1tZWjI2NwWKxyI6jKofDgZUrV+Khhx6SHUWK06dPw2g0wmQy4bHHHsOZM2dkR1JNZ2cnzGYz1q1bh/T0dCxZsgR79uyRHUuaq1ev4p133sGWLVsm/SHVfwebm5t04cIFTExMRDzJfNasWRFPMKfoJ4SA0+lEcXExcnNzZcdRlc/nw2233QatVgu73Y729nbk5OTIjqWa1tZWnDx5Eo2NjbKjSLFs2TLs27cP3d3d2LNnD4aHh1FYWIiLFy/KjqaKM2fOwOVyISsrC93d3bDb7Xj66aexb98+2dGk6OjowOXLl7F582Yp3y/98QvR4n87UyGElG6V5KqqqsKpU6fwxRdfyI6iuoULF8Lr9eLy5ct4//33YbPZ0NvbGxMNzvnz51FTU4Oenh7MnDlTdhwprFZr6PXixYthsVgwf/58vP3223A6nRKTqSMYDMJsNmPnzp0AgCVLlmBgYAAulyvsAdCxoqmpCVarFUajUcr3c+TmJqWlpSEuLi5ilCYQCESM5lB0q66uRmdnJ44cOYLZs2fLjqO6hIQELFiwAGazGY2NjcjPz8frr78uO5Yq+vv7EQgEUFBQAI1GA41Gg97eXuzatQsajQYTExOyI6pOp9Nh8eLFOH36tOwoqjAYDBGN/D333BMzF5b8t7Nnz+Lw4cN48sknpWVgc3OTEhISUFBQEJoV/iePx4PCwkJJqUhNQghUVVXhgw8+wCeffAKTySQ70rQghMD4+LjsGKooKSmBz+eD1+sNLWazGRUVFfB6vYiLi5MdUXXj4+Pw+/0wGAyyo6iiqKgo4hYQ3333HTIzMyUlkqelpQXp6elYuXKltAw8LTUJnE4nNm3aBLPZDIvFArfbjXPnzsFut8uOporR0VF8//33ofc//PADvF4vUlJSMHfuXInJ1OFwOPDuu+/iww8/hF6vD43i3XHHHUhMTJScTh3PPfccrFYr5syZg5GREbS2tuLTTz/FoUOHZEdThV6vj5hjpdPpkJqaGjNzr+rq6rBq1SrMnTsXgUAAL730Eq5cuQKbzSY7mipqa2tRWFiInTt3Yv369Th+/DjcbjfcbrfsaKoKBoNoaWmBzWaDRiOxxVD9+qwotXv3bpGZmSkSEhLEfffdF1OXAR85ckQAiFhsNpvsaKq4Xu0AREtLi+xoqtmyZUvo93/XXXeJkpIS0dPTIzuWVLF2KfiGDRuEwWAQ8fHxwmg0ikceeUQMDAzIjqWqAwcOiNzcXKHVakV2drZwu92yI6muu7tbABCDg4NScyhCCCGnrSIiIiKafJxzQ0RERFGFzQ0RERFFFTY3REREFFXY3BAREVFUYXNDREREUYXNDREREUUVNjdEREQUVdjcEBHh2sNvOzo6ZMcgoknA5oaIpNu8eTMURYlYysrKZEcjolsQny1FRNNCWVkZWlpawtZptVpJaYjoVsaRGyKaFrRaLTIyMsKW5ORkANdOGblcLlitViQmJsJkMqGtrS1sf5/PhxUrViAxMRGpqanYunUrRkdHwz7T3NyMRYsWQavVwmAwoKqqKmz7hQsXsHbtWiQlJSErKwudnZ1TWzQRTQk2N0R0S3jhhRdQXl6Or7/+Ghs3bsTjjz8Ov98PAPjtt99QVlaG5ORk9PX1oa2tDYcPHw5rXlwuFxwOB7Zu3Qqfz4fOzk4sWLAg7DtefPFFrF+/HqdOncLDDz+MiooKXLp0SdU6iWgSSH1sJxGREMJms4m4uDih0+nClh07dgghrj153W63h+2zbNky8dRTTwkhhHC73SI5OVmMjo6Gtn/00UdixowZYnh4WAghhNFoFNu3b/+/GQCI559/PvR+dHRUKIoiurq6Jq1OIlIH59wQ0bTw4IMPwuVyha1LSUkJvbZYLGHbLBYLvF4vAMDv9yM/Px86nS60vaioCMFgEIODg1AUBT/99BNKSkpumCEv
Ly/0WqfTQa/XIxAI/OuaiEgONjdENC3odLqI00R/RVEUAIAQIvT6ep9JTEz8W38vPj4+Yt9gMPiPMhGRfJxzQ0S3hGPHjkW8z87OBgDk5OTA6/VibGwstP3o0aOYMWMG7r77buj1esybNw8ff/yxqpmJSA6O3BDRtDA+Po7h4eGwdRqNBmlpaQCAtrY2mM1mFBcXY//+/Th+/DiampoAABUVFaivr4fNZkNDQwN+/vlnVFdXY9OmTZg1axYAoKGhAXa7Henp6bBarRgZGcHRo0dRXV2tbqFENOXY3BDRtHDo0CEYDIawdQsXLsS3334L4NqVTK2traisrERGRgb279+PnJwcAEBSUhK6u7tRU1ODpUuXIikpCeXl5Xj11VdDf8tms+H333/Ha6+9hrq6OqSlpeHRRx9Vr0AiUo0ihBCyQxAR3YiiKGhvb8eaNWtkRyGiWwDn3BAREVFUYXNDREREUYVzboho2uPZcyL6JzhyQ0RERFGFzQ0RERFFFTY3REREFFXY3BAREVFUYXNDREREUYXNDREREUUVNjdEREQUVdjcEBERUVRhc0NERERR5T/AEzkGMZECCQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Per-epoch train vs. validation accuracy, drawn on a fixed [0, 1] scale.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(train_a, label='train accuracy')\n",
    "ax.plot(val_a, label='validation accuracy')\n",
    "ax.set_title('Training history')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.set_ylim([0, 1])\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "Accuracy and macro-F1 of Pos/Neg classification on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 28664,
     "status": "ok",
     "timestamp": 1695329081374,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "56411273-fc8c-476f-f660-21239bd321ad",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.768836291913215\n",
      "F1-Macro:  0.728066740501927\n"
     ]
    }
   ],
   "source": [
    "# Score the trained model on the held-out test split.\n",
    "# eval_model's middle return value (presumably the loss) is not needed here.\n",
    "test_acc, _, test_f1 = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "# float() accepts 0-d tensors, NumPy scalars and plain Python floats alike,\n",
    "# whereas .item() fails on a plain Python float.\n",
    "print(\"Accuracy: \", float(test_acc))\n",
    "print(\"F1-Macro: \", float(test_f1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695329212064,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "d7DaFLMiAWps",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  \"\"\"Run `model` over every batch of `data_loader` and collect the results.\n",
    "\n",
    "  Each batch is read as a dict with keys \"review\", \"input_ids\",\n",
    "  \"attention_mask\" and \"sentiments\". Model inputs are moved to the\n",
    "  notebook-level `device` global before the forward pass.\n",
    "\n",
    "  Returns:\n",
    "    review: list of the raw \"review\" entries, in loader order.\n",
    "    predictions: CPU tensor of argmax class indices (over dim 1) per example.\n",
    "    prediction_probs: CPU tensor of softmax probabilities (over dim 1) per example.\n",
    "    real_values: CPU tensor of the \"sentiments\" labels per example.\n",
    "\n",
    "  Assumes the model returns (batch, num_classes) logits -- TODO confirm.\n",
    "  Fails inside torch.cat if the loader yields no batches.\n",
    "  \"\"\"\n",
    "  model = model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      # Keep whole batch tensors and move them to the CPU immediately: this\n",
    "      # avoids splitting each batch into per-row tensors and stops the\n",
    "      # accumulated results from occupying device memory.\n",
    "      review.extend(d[\"review\"])\n",
    "      predictions.append(preds.cpu())\n",
    "      prediction_probs.append(probs.cpu())\n",
    "      # Labels are only returned, so they never need to visit the device.\n",
    "      real_values.append(d[\"sentiments\"].cpu())\n",
    "\n",
    "  # Concatenating per-batch tensors yields the same shapes as stacking rows.\n",
    "  predictions = torch.cat(predictions)\n",
    "  prediction_probs = torch.cat(prediction_probs)\n",
    "  real_values = torch.cat(real_values)\n",
    "\n",
    "  return review, predictions, prediction_probs, real_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "executionInfo": {
     "elapsed": 27254,
     "status": "ok",
     "timestamp": 1695329239311,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "kZhYtki1AYx8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Texts, predicted labels, class probabilities and true labels for the test split.\n",
    "y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(model, test_data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "mO1FHn0GAbCU",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "\n",
    "# Display names for the two classes, in label order\n",
    "# (presumably 0 = negative, 1 = positive -- confirm against the label encoding).\n",
    "class_names = ['negative', 'positive']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
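    "Pos/Neg classification report over all test-set predictions (its macro-avg F1 may differ slightly from the F1-Macro printed by `eval_model` above)"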
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "4d6f7fff-9a14-49bf-c847-d2e8ba79f5f7",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.78      0.89      0.83      1603\n",
      "    positive       0.75      0.56      0.64       932\n",
      "\n",
      "    accuracy                           0.77      2535\n",
      "   macro avg       0.76      0.73      0.74      2535\n",
      "weighted avg       0.77      0.77      0.76      2535\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# classification_report returns a fixed-width text table; print() keeps its layout.\n",
    "report = classification_report(y_test, y_pred, target_names=class_names)\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
