{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695324378707,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import BertModel, BertConfig\n",
    "\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "import textwrap\n",
    "import math\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695323483676,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/train.csv')\n",
    "df_val = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/val.csv')\n",
    "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/GoEmotion/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "YWiFI0a9QqEa",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_senti = pd.read_excel('/home/m_nsu/ICLR/Datasets/GoEmotion/sentiwords.xlsx')\n",
    "conditions = [\n",
    "    (df_senti['PosScore'] > df_senti['NegScore']),\n",
    "    (df_senti['PosScore'] < df_senti['NegScore']),\n",
    "    (df_senti['PosScore'] == df_senti['NegScore'])\n",
    "    ]\n",
    "\n",
    "values = ['Positive','Negative','Neutral']\n",
    "\n",
    "df_senti = df_senti[['PosScore','NegScore','Word','Definition']]\n",
    "df_senti['Sentiment'] = np.select(conditions, values)\n",
    "df_senti = df_senti.dropna(axis=0)\n",
    "df_senti.drop(columns=['PosScore', 'NegScore'], inplace=True)\n",
    "df_senti = df_senti[['Word', 'Sentiment', 'Definition']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754433,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "vX1DfZpN1W01",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.dropna(inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754435,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "q_pTZIR8ltaJ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.rename(columns={'Text': 'review'}, inplace=True)\n",
    "df_val.rename(columns={'Text': 'review'}, inplace=True)\n",
    "df_test.rename(columns={'Text': 'review'}, inplace=True)\n",
    "\n",
    "\n",
    "df_train.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)\n",
    "df_val.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)\n",
    "df_test.rename(columns={'Mapped Sentiment': 'sentiment'}, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 49,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "f1bb3869-364e-4c7a-8c3f-a62ce4b182c7",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>review</th>\n",
       "      <th>Emotions</th>\n",
       "      <th>sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>my favorite food is anything i did not have to...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>now if he does off himself everyone will think...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>why the fuck is bayless isoing</td>\n",
       "      <td>anger</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>to make her feel threatened</td>\n",
       "      <td>fear</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>dirty southern wankers</td>\n",
       "      <td>annoyance</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              review   Emotions  sentiment\n",
       "0  my favorite food is anything i did not have to...    neutral          3\n",
       "1  now if he does off himself everyone will think...    neutral          3\n",
       "2                     why the fuck is bayless isoing      anger          1\n",
       "3                        to make her feel threatened       fear          1\n",
       "4                             dirty southern wankers  annoyance          1"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 45,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "RN48HyYaC2VD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def filter_rows_by_values(df, col, values):\n",
    "    return df[~df[col].isin(values)]\n",
    "\n",
    "df_train = filter_rows_by_values(df_train,'sentiment',[2,3])\n",
    "df_test = filter_rows_by_values(df_test,'sentiment',[2,3])\n",
    "df_val = filter_rows_by_values(df_val,'sentiment',[2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 45,
     "status": "ok",
     "timestamp": 1695323754438,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Qm9XHUHuuSXq",
    "outputId": "6e7765c0-f186-4b8c-db5a-d954fbb3dc9a",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train['emotion'], map = pd.factorize(df_train['Emotions'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754438,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "TCTYc5frxZfx",
    "tags": []
   },
   "outputs": [],
   "source": [
    "emotion_map = dict(zip(map, range(len(map))))\n",
    "map_emotion = {v: k for k, v in emotion_map.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754439,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "_xxZFLJwxfpy",
    "outputId": "914e4d70-43c6-462e-dc72-954759c555eb",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'gratitude'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emotion_map['gratitude']  # Get Encoded Label from Categorical Label\n",
    "map_emotion[3]  # Get Categorical Label from Encoded Label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 33,
     "status": "ok",
     "timestamp": 1695323754440,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "k-a5Ww9MhpdO",
    "outputId": "ecf8deb0-f932-4d16-a103-9acc16f252ec",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_val['emotion'] = df_val[\"Emotions\"].apply(lambda x: emotion_map[x])\n",
    "df_test['emotion'] = df_test[\"Emotions\"].apply(lambda x: emotion_map[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756839,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "device = torch.device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756840,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "cb12350b-753a-4c61-c7e8-3c9427f000e9",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:1\n"
     ]
    }
   ],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 'bert-base-uncased'\n",
    "config = BertConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770760,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, reviews, sentiments, tokenizer, max_len):\n",
    "    self.reviews = reviews\n",
    "    self.sentiments = sentiments\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.reviews)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    review = str(self.reviews[item])\n",
    "    sentiment = self.sentiments[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      review,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'review': review,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'sentiments': torch.tensor(sentiment, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325189268,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    reviews=df.review.to_numpy(),\n",
    "    sentiments=df['sentiment'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770761,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325187523,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "4cc31257-8fe1-4cce-883e-8aff4ebf2185",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['review', 'input_ids', 'attention_mask', 'sentiments'])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = next(iter(train_data_loader))\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695325688712,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "  def __init__(self):\n",
    "    super(Classifier, self).__init__()\n",
    "    self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "    self.FC = nn.Linear(config.hidden_size,2, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "    \n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "id": "HWZ37gsztWzL",
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = Classifier()\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c0ba4f97-cad5-46b8-e1ff-e818ff5b4035",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "sentiments = data['sentiments'].to(device)\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "executionInfo": {
     "elapsed": 603,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs = model(input_ids, attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NHFM3QUhg3n0",
    "outputId": "602bd955-db78-44a6-931e-e5e901bb1856",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2859, -0.1522],\n",
       "        [-0.0713, -0.1591],\n",
       "        [-0.3120, -0.1357],\n",
       "        [-0.0890, -0.0393],\n",
       "        [-0.0982, -0.0973],\n",
       "        [-0.0556, -0.1899],\n",
       "        [-0.0084, -0.1194],\n",
       "        [-0.0559, -0.0935],\n",
       "        [-0.0267, -0.1628],\n",
       "        [-0.0491, -0.1007],\n",
       "        [ 0.0543, -0.0414],\n",
       "        [-0.0963,  0.0608],\n",
       "        [-0.0007, -0.0968],\n",
       "        [-0.0452, -0.3013],\n",
       "        [-0.1103, -0.1800],\n",
       "        [-0.0856, -0.0525],\n",
       "        [ 0.0634, -0.0155],\n",
       "        [-0.1064, -0.2316],\n",
       "        [-0.2929,  0.0091],\n",
       "        [-0.0547, -0.1934],\n",
       "        [-0.0333, -0.1818],\n",
       "        [ 0.0230, -0.0695],\n",
       "        [-0.2055, -0.1096],\n",
       "        [-0.2010, -0.1768],\n",
       "        [-0.1672, -0.2497],\n",
       "        [-0.0983, -0.0677],\n",
       "        [-0.2429, -0.3640],\n",
       "        [-0.1151, -0.0951],\n",
       "        [-0.2005, -0.1738],\n",
       "        [-0.0606, -0.1840],\n",
       "        [-0.0560,  0.0174],\n",
       "        [-0.0420, -0.1475]], device='cuda:1', grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1022,
     "status": "ok",
     "timestamp": 1695325845248,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf",
    "outputId": "bae8764e-0ce8-4118-e676-267b2c3ac06b",
    "tags": []
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695325845969,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == sentiments)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695325845971,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  f1s = []\n",
    "\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == sentiments)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "      f1s.append(f1_score(sentiments.cpu(), preds.cpu(), average='macro'))\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses), np.mean(f1s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 3203929,
     "status": "ok",
     "timestamp": 1695329051836,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "ee79ec9a-a033-47e0-bc8b-a7b71698465c",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 0.5377342011151689 accuracy 0.7161782416700121\n",
      "Val   loss 0.41229210209242906 accuracy 0.8143593811979374\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 0.40848803847884674 accuracy 0.8138297872340425\n",
      "Val   loss 0.37009768697279916 accuracy 0.8369694565648552\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.38917122250288294 accuracy 0.8215576073865918\n",
      "Val   loss 0.36110645268536823 accuracy 0.8397461324871083\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.37802422955560455 accuracy 0.8277800080289041\n",
      "Val   loss 0.3558001193819167 accuracy 0.8397461324871083\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.37259876247298085 accuracy 0.8277800080289041\n",
      "Val   loss 0.3533075508437579 accuracy 0.8437128123760412\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.36840712811265675 accuracy 0.8313930148534725\n",
      "Val   loss 0.35199704294717765 accuracy 0.8452994843316144\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.36801098012044 accuracy 0.8336511441188278\n",
      "Val   loss 0.35112780709809893 accuracy 0.8441094803649346\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.3652148567510837 accuracy 0.8340024086712164\n",
      "Val   loss 0.35079684514033643 accuracy 0.8445061483538279\n",
      "\n",
      "CPU times: user 43min 30s, sys: 8.58 s, total: 43min 38s\n",
      "Wall time: 43min 43s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss, val_f1 = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_bert_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "executionInfo": {
     "elapsed": 51,
     "status": "ok",
     "timestamp": 1695329137842,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "FowMSU5U7SDQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 480
    },
    "executionInfo": {
     "elapsed": 2888,
     "status": "ok",
     "timestamp": 1695329143450,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "b103ca88-1886-4f16-ea21-5b966a08fe7c",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHFCAYAAAAOmtghAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOzde1xUdf7H8feZAQa5lqiIeSNvhaaVmKJrlqWmZdrupnZTstplrdSoLLf9lZpFtWmmJt1EM9uy1mxt1zLKSjfdShOr1cwMwwsuYsnFCzAz5/cHMDECchEYOL6ej8c8Zs73fM85nzNOzZvvOXOOYZqmKQAAAIuw+boAAACAukS4AQAAlkK4AQAAlkK4AQAAlkK4AQAAlkK4AQAAlkK4AQAAlkK4AQAAlkK4AQAAlkK4ASzEMIxqPT755JM62d6JEydkGIaeeOKJWi3fr18/XXXVVXVSS021bt1av//976vs9/7778swDP3nP/+p0frnz5+v5cuX17Y8AKfBz9cFAKg7mzZt8pp+9NFH9fHHH2vdunVe7TExMXWyPYfDoU2bNql9+/a1Wn7x4sWy2+11Ukt9iYuL06ZNm9SjR48aLTd//nx17txZN998cz1VBqAyhBvAQvr16+c13bJlS9lstnLtlSksLJTdbq924DAMo9rrrkj37t1rvWxDCQ8PP619rEsul0sul0sBAQG+LgVo1DgsBZyhSg+3rFixQpMnT1ZUVJQCAwO1d+9eZWZmKiEhQeeff76Cg4MVGRmpK6+8stzIUEWHpZ5//nkZhqHPPvtMd9xxhyIiItSiRQtdf/31+t///ue1/MmHpb777jsZhqEFCxboySefVIcOHRQSEqIBAwZoy5Yt5fZh0aJF6ty5sxwOhy644AK99dZbGjdunM4777xqvw/vvvuuLrzwQjVr1kwxMTHlDiVVdFjq+++/1/XXX6+oqCg5HA61bt1aQ4YM0X//+19JxYe8du/erbVr13oOBZatKT09XTfccINatmwph8OhmJgYzZ8/X2XvY1z6XsybN08zZsxQx44dFRAQoDVr1igkJERTpkwpty87d+6UzWbTggULqr3/gBUxcgOc4e69915deumlevnll+V2u3X22WcrIyND/v7+mjlzpiIjI5WXl6e33npLAwcO1IYNGxQXF1fleidMmKBrr71Wr7/+utLT0zVt2jTdeuutWrNmTZXLzp07VxdccIEWLFggl8ulhx56SMOHD1d6erqCg4MlFR/2mTJlisaNG6f58+fr559/1vTp01VUVKRmzZpVa9+//PJL7dy5Uw8++KBatGih5ORk3XLLLeratasuueSSCpcxTVNXXXWVHA6Hnn76abVr107Z2dnasGGDjhw5Iklas2aNRo0apbZt2+qZZ56RJE9NmZmZiouLk2EYSkpKUtu2bfXOO+9oypQp2rNnj+bOneu1vaeffloxMTGaO3euQkJCFBMTo/Hjx+uVV17R448/7nk/JOm5555TcHCwJkyYUK39ByzLBGBZEyZMMIODgyuc995775mSzKFDh1a5HqfTaRYVFZkDBgwwb7jhBk/78ePHTUlmUlKSpy05OdmUZCYmJnqtY9asWaYk8+eff/a09e3b1xw2bJhneseOHaYkMzY21nS73Z729evXm5LMVatWmaZpmoWFhWZERIQ5aNAgr2388MMPpt1uN7t161blPkVGRprBwcHmgQMHPG35+flmaGioOWXKFE9b6fu0adMm0zRNc9++faYk8/nnnz/l+jt16uS1b6WmTp1qGoZhpqWlebXfeuutps1mM9PT073ei/PPP990Op1efXfs2GEahmEmJyd72vLy8sywsDDzT3/6U5X7Dlgdh6WAM9zvfve7cm2maWrBggW66KKLFBgYKD8/P/n7++uzzz7Tjh07qrXea6+91mu6Z8+ekqSMjIwql73mmmtkGEa5ZX/66SdJ0rfffqvD
hw9rzJgxXst16tRJffr0qVZ9ktSnTx9FRUV5poODg9WpUyfPdirSunVrtW/fXo8//rieffZZbdu2TW63u9rbXLdunS666CL16tXLqz0+Pl5ut7vcL9lGjx5d7hyo8847T0OGDNFzzz3naVu2bJlyc3N15513VrsWwKoIN8AZruyXe6mkpCRNnjxZAwcO1Ntvv63PP/9cX375pQYPHqzjx49Xa70RERFe0w6HQ5KqtXxVyx4+fFiSFBkZWW7Zitqqu53SbZ2qRrvdro8//liXX365HnvsMV144YWKjIxUYmKijh49WuU2Dx8+XOF73qZNG8/8sirqK0lTpkzRt99+q/Xr10sqPiR1+eWXN4mTtIH6xjk3wBmu7AhJqeXLl+uqq67S/PnzvdpzcnIaqqxTKg0lJ5+gLEkHDx6s9+2fe+65Wrp0qaTiE3/feOMNPfroo3K73Zo3b94pl42IiFBmZma59gMHDkiSWrRo4dVe0b+PJA0fPlxdunTRwoUL5XQ6tX37ds2aNasWewNYDyM3AMoxDMMzWlJq8+bN+uqrr3xUkbcePXqoefPmWrFihVf77t27tXnz5gat5bzzztOMGTPUtWtXr/enshGgK664QmlpaZ5fVpVatmyZbDabLrvssmpt1zAM3X333Vq1apUeeeQRtW3bVqNHjz6tfQGsgnADoJxrrrlG7777rmbPnq1169Zp4cKFuvrqq9WxY0dflyZJ8vf31yOPPKL169frhhtu0Hvvvafly5dr2LBhatOmjWy2+vtf2xdffKHLL79czz33nNauXat169bpwQcf1M6dOzVkyBBPvwsuuECbN2/W3//+d23evNkTZu6//361bNlSw4YNU0pKitauXas777xTixcv1tSpU9WhQ4dq1xIfH6+goCD9+9//VkJCQqO/ICLQUDgsBaCcGTNmqLCwUIsWLdJjjz2mHj16aMmSJVq2bJnS0tJ8XZ4kafLkybLb7Zo7d67efvttnXvuuZo5c6ZeffVV5ebm1tt227Ztq/bt22vBggXat2+fbDabOnXqpPnz52vSpEmefo899piys7N16623Kj8/X926ddN3332nqKgobdq0SdOnT9f999+vvLw8derUSfPmzdPkyZNrVEtoaKhGjBiht99+W3fccUdd7yrQZBmmWeaqUQDQhB0+fFhdunTRzTffXO58ISs6fvy42rdvr+HDh2vZsmW+LgdoNBi5AdAkZWRkaO7cuRo0aJCaN2+u9PR0zZkzRwUFBbr77rt9XV69ysrK0vfff68XXnhBv/zyi6ZNm+brkoBGhXADoEkKDAzUrl279Prrr+vnn39WSEiI+vfvr6VLl6pLly6+Lq9evf322/rTn/6kc845Ry+99FKNb+oJWB2HpQAAgKX49NdS69ev18iRI9WmTRsZhqF33nmnymU+/fRT9e7dW4GBgTr33HP1/PPPN0ClAACgqfBpuDl69Kh69eqlhQsXVqt/enq6RowYoYEDB2rr1q3685//rMmTJ2vlypX1XCkAAGgqGs1hKcMwtGrVqlNehOqBBx7Q6tWrve5tk5CQoG3btmnTpk0NUSYAAGjkmtQJxZs2bdLQoUO92oYNG6bFixerqKhI/v7+5ZYpKChQQUGBZ9rtduvnn39WREREpZc1BwAAjYtpmsrLy6vWhTqbVLg5ePBguZviRUZGyul0Kjs7u9IbAM6cObOhSgQAAPVo7969atu27Sn7NKlwI5W/iVzpUbXKRmGmT5+uxMREz3ROTo7at2+vvXv3KiwsrP4KBQAAdSY3N1ft2rVTaGholX2bVLhp3bp1uTv+ZmVlyc/Pz3OX4JM5HI5yNwCUpLCwMMINAABNTHVOKWlSN86Mi4tTamqqV9sHH3yg2NjYCs+3AQAAZx6fhpv8/HylpaV5bsSXnp6utLQ0ZWRkSCo+pDR+/HhP/4SEBP30009KTEzUjh07lJKSosWLF+u+++7zSf0AAKDx8elhqc2bN+vyyy/3TJeeGzNh
wgQtXbpUmZmZnqAjSdHR0VqzZo3uuecePffcc2rTpo3mz5+v3/3udw1eOwAAaJwazXVuGkpubq7Cw8OVk5PDOTcAmiSXy6WioiJflwHUuYCAgEp/5l2T7+8mdUIxAJzJTNPUwYMHdeTIEV+XAtQLm82m6OhoBQQEnNZ6CDcA0ESUBptWrVopKCiIC5HCUtxutw4cOKDMzEy1b9/+tD7fhBsAaAJcLpcn2FR26QugqWvZsqUOHDggp9N5Wr+CblI/BQeAM1XpOTZBQUE+rgSoP6WHo1wu12mth3ADAE0Ih6JgZXX1+SbcAAAASyHcAACalI4dO2revHm+LgONGCcUAwDq1WWXXaYLL7ywzgLJl19+qeDg4DpZF6yJcAMA8DnTNOVyueTnV/XXUsuWLRugooZVk/1H1TgsBQCoN/Hx8fr000/17LPPyjAMGYahPXv26JNPPpFhGFq7dq1iY2PlcDi0YcMG7d69W6NGjVJkZKRCQkLUp08fffjhh17rPPmwlGEYevnll3XdddcpKChIXbp00erVq09Z1/LlyxUbG6vQ0FC1bt1aN954o7Kysrz6/Pe//9XVV1+tsLAwhYaGauDAgdq9e7dnfkpKirp37y6Hw6GoqCjdddddkqQ9e/bIMAzPfRMl6ciRIzIMQ5988okkndb+FxQUaNq0aWrXrp0cDoe6dOmixYsXyzRNde7cWU8//bRX/2+//VY2m82rdqsj3ABAE2Wapo4VOn3yqO6de5599lnFxcXpjjvuUGZmpjIzM9WuXTvP/GnTpikpKUk7duxQz549lZ+frxEjRujDDz/U1q1bNWzYMI0cOdLrPoMVmTlzpsaMGaOvv/5aI0aM0E033aSff/650v6FhYV69NFHtW3bNr3zzjtKT09XfHy8Z/7+/ft16aWXKjAwUOvWrdOWLVs0ceJEOZ1OSVJycrLuvPNO/eEPf9A333yj1atXq3PnztV6T8qqzf6PHz9eb7zxhubPn68dO3bo+eefV0hIiAzD0MSJE7VkyRKvbaSkpGjgwIHq1KlTjetrqhj/AoAm6niRSzEPr/XJtrfPGqaggKq/QsLDwxUQEKCgoCC1bt263PxZs2ZpyJAhnumIiAj16tXLMz179mytWrVKq1ev9oyMVCQ+Pl433HCDJOnxxx/XggUL9MUXX+iqq66qsP/EiRM9r88991zNnz9fl1xyifLz8xUSEqLnnntO4eHheuONNzwXk+vatatXXffee6+mTJniaevTp09Vb0c5Nd3/77//Xm+++aZSU1N15ZVXeuovdeutt+rhhx/WF198oUsuuURFRUVavny5/vrXv9a4tqaMkRsAgM/ExsZ6TR89elTTpk1TTEyMzjrrLIWEhOi7776rcuSmZ8+entfBwcEKDQ0td5iprK1bt2rUqFHq0KGDQkNDddlll0mSZztpaWkaOHBghVfJzcrK0oEDB3TFFVdUdzcrVdP9T0tLk91u16BBgypcX1RUlK6++mqlpKRIkv75z3/qxIkTuv7660+71qaEkRsAaKKa+du1fdYwn227Lpz8q6f7779fa9eu1dNPP63OnTurWbNm+v3vf6/CwsJTrufkEGIYhtxud4V9jx49qqFDh2ro0KFavny5WrZsqYyMDA0bNsyznWbNmlW6rVPNk+S5q3XZQ3eV3cW9pvtf1bYl6fbbb9ctt9yiZ555RkuWLNHYsWPPuCtbE24AoIkyDKNah4Z8LSAgoNqX09+wYYPi4+N13XXXSZLy8/O1Z8+eOq3nu+++U3Z2tp544gnP+T+bN2/26tOzZ0+98sorKioqKhecQkND1bFjR3300Ue6/PLLy62/9NdcmZmZuuiiiyTJ6+TiU6lq/y+44AK53W59+umnnsNSJxsxYoSCg4OVnJys9957T+vXr6/Wtq2Ew1IAgHrVsWNHff7559qzZ4+ys7MrHVGRpM6dO+vtt99WWlqatm3bphtvvPGU/Wujffv2CggI0IIFC/Tjjz9q9erVevTR
R7363HXXXcrNzdW4ceO0efNm7dq1S6+++qp27twpSZoxY4bmzJmj+fPna9euXfrqq6+0YMECScWjK/369dMTTzyh7du3a/369frLX/5Srdqq2v+OHTtqwoQJmjhxoudE6E8++URvvvmmp4/dbld8fLymT5+uzp07Ky4u7nTfsiaHcAMAqFf33Xef7Ha7YmJiPIeAKvPMM8/o7LPPVv/+/TVy5EgNGzZMF198cZ3W07JlSy1dulRvvfWWYmJi9MQTT5T7+XRERITWrVun/Px8DRo0SL1799ZLL73kGcWZMGGC5s2bp0WLFql79+665pprtGvXLs/yKSkpKioqUmxsrKZMmaLZs2dXq7bq7H9ycrJ+//vfa9KkSTrvvPN0xx136OjRo159brvtNhUWFnqdOH0mMczq/p7PInJzcxUeHq6cnByFhYX5uhwAqJYTJ04oPT1d0dHRCgwM9HU5aOQ+++wzXXbZZdq3b58iIyN9XU61nepzXpPv78Z/sBYAAFRLQUGB9u7dq//7v//TmDFjmlSwqUsclgIAwCJef/11devWTTk5OXrqqad8XY7PEG4AALCI+Ph4uVwubdmyReecc46vy/EZwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AoNHr2LGj5s2b55k2DEPvvPNOpf337NkjwzCqfcPK+l4PGhZXKAYANDmZmZk6++yz63Sd8fHxOnLkiFdoateunTIzM9WiRYs63RbqF+EGANDktG7dukG2Y7fbG2xbjU1RUZHnRqFNDYelAAD15oUXXtA555wjt9vt1X7ttddqwoQJkqTdu3dr1KhRioyMVEhIiPr06aMPP/zwlOs9+bDUF198oYsuukiBgYGKjY3V1q1bvfq7XC7ddtttio6OVrNmzdStWzc9++yznvkzZszQK6+8on/84x8yDEOGYeiTTz6p8LDUp59+qksuuUQOh0NRUVF68MEH5XQ6PfMvu+wyTZ48WdOmTVPz5s3VunVrzZgx45T78+WXX2rIkCFq0aKFwsPDNWjQIH311VdefY4cOaI//OEPioyMVGBgoHr06KF//vOfnvmfffaZBg0apKCgIJ199tkaNmyYfvnlF0nlD+tJ0oUXXuhVl2EYev755zVq1CgFBwdr9uzZVb5vpVJSUtS9e3fPe3LXXXdJkiZOnKhrrrnGq6/T6VTr1q2VkpJyyvfkdDByAwBNlWlKRcd8s23/IMkwqux2/fXXa/Lkyfr44491xRVXSJJ++eUXrV27Vu+++64kKT8/XyNGjNDs2bMVGBioV155RSNHjtTOnTvVvn37Krdx9OhRXXPNNRo8eLCWL1+u9PR0TZkyxauP2+1W27Zt9eabb6pFixbauHGj/vCHPygqKkpjxozRfffdpx07dig3N1dLliyRJDVv3lwHDhzwWs/+/fs1YsQIxcfHa9myZfruu+90xx13KDAw0CsovPLKK0pMTNTnn3+uTZs2KT4+XgMGDNCQIUMq3Ie8vDxNmDBB8+fPlyTNmTNHI0aM0K5duxQaGiq3263hw4crLy9Py5cvV6dOnbR9+3bZ7XZJUlpamq644gpNnDhR8+fPl5+fnz7++GO5XK4q37+yHnnkESUlJemZZ56R3W6v8n2TpOTkZCUmJuqJJ57Q8OHDlZOTo88++0ySdPvtt+vSSy9VZmamoqKiJElr1qxRfn6+Z/n6QLgBgKaq6Jj0eBvfbPvPB6SA4Cq7NW/eXFdddZX+9re/ecLNW2+9pebNm3ume/XqpV69enmWmT17tlatWqXVq1d7RgBO5bXXXpPL5VJKSoqCgoLUvXt37du3T3/60588ffz9/TVz5kzPdHR0tDZu3Kg333xTY8aMUUhIiJo1a6aCgoJTHoZatGiR2rVrp4ULF8owDJ133nk6cOCAHnjgAT388MOy2YoPiPTs2VOPPPKIJKlLly5auHChPvroo0rDzeDBg72mX3jhBZ199tn69NNPdc011+jD
Dz/UF198oR07dqhr166SpHPPPdfT/6mnnlJsbKwWLVrkaevevXuV793JbrzxRk2cONGr7VTvm1T873Xvvfd6Bco+ffpIkvr3769u3brp1Vdf1bRp0yRJS5Ys0fXXX6+QkJAa11ddHJYCANSrm266SStXrlRBQYGk4jAybtw4z6jD0aNHNW3aNMXExOiss85SSEiIvvvuO2VkZFRr/Tt27FCvXr0UFBTkaYuLiyvX7/nnn1dsbKxatmypkJAQvfTSS9XeRtltxcXFySgzajVgwADl5+dr3759nraePXt6LRcVFaWsrKxK15uVlaWEhAR17dpV4eHhCg8PV35+vqe+tLQ0tW3b1hNsTlY6cnO6YmNjy7Wd6n3LysrSgQMHTrnt22+/3TMalpWVpX/961/lAlRdY+QGAJoq/6DiERRfbbuaRo4cKbfbrX/961/q06ePNmzYoLlz53rm33///Vq7dq2efvppde7cWc2aNdPvf/97FRYWVmv9pmlW2efNN9/UPffcozlz5iguLk6hoaH661//qs8//7za+1G6LeOkw3Gl2y/bfvKJuIZhlDvvqKz4+HgdOnRI8+bNU4cOHeRwOBQXF+d5D5o1a3bKuqqab7PZyr1PRUVF5foFB3uPxlX1vlW1XUkaP368HnzwQW3atEmbNm1Sx44dNXDgwCqXOx2EGwBoqgyjWoeGfK1Zs2b67W9/q9dee00//PCDunbtqt69e3vmb9iwQfHx8bruuuskFZ+Ds2fPnmqvPyYmRq+++qqOHz/u+bL9z3/+49Vnw4YN6t+/vyZNmuRp2717t1efgICAKs9RiYmJ0cqVK71CzsaNGxUaGqpzzjmn2jWfbMOGDVq0aJFGjBghSdq7d6+ys7M983v27Kl9+/bp+++/r3D0pmfPnvroo4+8DiGV1bJlS2VmZnqmc3NzlZ6eXq26TvW+hYaGqmPHjvroo490+eWXV7iOiIgIjR49WkuWLNGmTZt06623Vrnd08VhKQBAvbvpppv0r3/9SykpKbr55pu95nXu3Flvv/220tLStG3bNt14442nHOU42Y033iibzabbbrtN27dv15o1a/T000+X28bmzZu1du1aff/99/q///s/ffnll159OnbsqK+//lo7d+5UdnZ2hSMbkyZN0t69e3X33Xfru+++0z/+8Q898sgjSkxM9JxvUxudO3fWq6++qh07dujzzz/XTTfd5DUqMmjQIF166aX63e9+p9TUVKWnp+u9997T+++/L0maPn26vvzyS02aNElff/21vvvuOyUnJ3sC0uDBg/Xqq69qw4YN+vbbbzVhwgTPYcGq6qrqfZsxY4bmzJmj+fPna9euXfrqq6+0YMECrz633367XnnlFe3YscPzK7n6RLgBANS7wYMHq3nz5tq5c6duvPFGr3nPPPOMzj77bPXv318jR47UsGHDdPHFF1d73SEhIXr33Xe1fft2XXTRRXrooYf05JNPevVJSEjQb3/7W40dO1Z9+/bV4cOHvUYjJOmOO+5Qt27dPOeXlP7ip6xzzjlHa9as0RdffKFevXopISFBt912m/7yl7/U4N0oLyUlRb/88osuuugi3XLLLZo8ebJatWrl1WflypXq06ePbrjhBsXExGjatGmekaauXbvqgw8+0LZt23TJJZcoLi5O//jHP+TnV3yAZvr06br00kt1zTXXaMSIERo9erQ6depUZV3Ved8mTJigefPmadGiRerevbuuueYa7dq1y6vPlVdeqaioKA0bNkxt2tT/SfCGWZ2DlRaSm5ur8PBw5eTkKCwszNflAEC1nDhxQunp6YqOjlZgYKCvywFq5NixY2rTpo1SUlL029/+ttJ+p/qc1+T7m3NuAABAvXC73Tp48KDmzJmj8PBwXXvttQ2yXcINAACoFxkZGYqOjlbbtm21dOlSz2Gy+ka4AQAA9aJjx47V+ql+XeOEYgAAYCmEGwBoQs6w34DgDFNXn2/CDQA0AaVXvD12zEc3ygQa
QOkVmatzDZ5T4ZwbAGgC7Ha7zjrrLM/9iYKCgsrdBgBoytxutw4dOqSgoKDTPvGYcAMATUTp3apPdQNGoCmz2Wxq3779aQd3wg0ANBGGYSgqKkqtWrWq8NYAQFMXEBBwWrexKEW4AYAmxm63n/Y5CUCNmKbkKpScBcUPV8Gvr50nyrfZ/KTzRvisXMINAACNlWlKbqd3kPAKFqVthb+GDK+gcUJyFlbcx6utsOKQ4pk+UbO6Q9sQbgAAaFBut+QuKg4OrpJnz+siye0q89opuZyVvK5gWZezJBCUBoYahoiT28zq3yG9wdgDJL9Ayc8h2R3Fz6UPu0MKbuHT8gg3QG2UDtEWHpWKjkmFx6Sio8XPbqdkuor/h2Saxf+TNN0ljzKv3e5K2sssW2H7SY9y6zcraa9ouzWpyaxerVLxkLTNXvLsV8n0yY8q+ttruox/Jds8xTL2SpZpzL9KMs0y/zYlD5007ZlvVjHf/eu/4SnnmxWvt+xnx1Xyhe8uKvnSd5V57TwpGBSd1L8my55iPacKKI0xMFSHzb9MkAgsEzICvKZNP4fctgC57A65bP5yGcXPTiNARbaA4mf5q8jwV6ECVGj4q8D0V4H8VWD66YTprxOmnwpMfx03/XXCtOuY21/HXHYdd9lV4JYKnS4VOt0qdLmLn51uFRxzq8jlVvOiAL3lw7eJcAPrMs3iv3yKjpUJIfklQeRYxcHEq/3k+Sf1MV2+3kM0FMNWSSDyryAwlZwLU9tg4BU8qtEHdcY07L8GacNPpt3f8+9qlvx7m57p4iBslnwGyrabhr0kVATIafjLaXOUBIkAFRn+KpK/Co0AFZp+OlHyXCD/4kDh9tMJ+euE266jbn8dd/vpmNtPx1x2HfOEipIw4TKLA0bBr+GiNGwUueriYnhuSQUlj5rJOe6sg+3XHuEGvmWaUtHxU4SNikJHSUg5ZTApaW+I//nb/KWAICkgRPJvVvyXk2HzftjsZabtxSMBFbbbiudV2G6TbLZK2ssse8p5FbVXVWsNa1LJaJXb9etf1p6Hq8xf1c4K+lQy7aqsf2XLFFW9zrLTpeuvLLCa7uKROldh/X+eGoApQ6ZsMg2j5LUh07D9+lo2uVVmngy5jZI2s2S65OF5bQSo59EAACAASURBVBpyyya3JLdscsquIvnJJZuc8pNTdjllk9Msfu0y7Coy7SXt9pL+drnMkmfZVWSWLmtTkcr0N+2/tpf0L31dKD+5TO/tF1WwfWdJXaXbcsoms9Fd17ao5HH6/O2GAuw2BfiVedhtCvCzK8BunNRW2l782lFuXpnXFbX52RQU4NsT3gk3daXohLTrA0kladlzCWnT+7XXPFUyr4J1nGpepev3wbbdruqFjsIybSpTU32xO4oDiH9wyXOQFBBc8lymPSC46j4nz7f713/9qBHTNOV0myp0Fg+Rew2du9wqcpoqdLlU6DRLpkvmF7lU5CxSUVGRXM4iOYuK5HIVyVlUKKfTKZfTKZerUGZJu8vllNtZJLe75NnllOlyyumWnKbkcpsqMg053UZJm6kid/F0kduU0y0VmaXB4NegUPxs824rCQ+mVPJ8UrgoDSwl88uGE3eZ9UmN+BBbE1V61NJQ8c/1bYbkX8mXv6NcuCgOEv52o3ieV3+7VxhxVBIkKg0bpdN2m2y2M+vfnXBTV04ckd68xddVNG1+gb8GBk+oqChcVNF+8rL+QcXDzHWowOlS/gmn8n4pVH7BMeWeKFLeCafyTzjlMk3P/+SKn0seJV8qnv8Rlp0vo+S5dH7Z6QrWU9G8MssUb+Dk9VZzGyXtOnm9ldT5a7/i9ZimqSKXWRIkygeIIpdbBSXD56XB49cA8utyRSXLFboq7ltQ+toTVoq3UVCyjfq9BZO95FG//O2G/Gw2+dkM2cu89rcZ8rcbsttK2uyG/GyG/Oy2krbi13624j7FfYuni+eVWbZk3f62KpYt2UZx26n72qrxOfq1vfLP48nLeqYrmV/6
kTz5v4WTP98yvD+zJYtV+vn3rrXsf49nVmBoSgg3dcUeILXrV/za84E3vF97zVMl807uV9k6qtuvgm3V9XbLzjOMk0JHJaMdFQUTW/1/Wbjdpo4WOpV3oviRX1Ck3JJQUtxWpPyCX+fnlYaWgl/n5Z5wqtDJuQ5NiWGo3F+2/ic9e4bm7bZyf3Wf3NfhZ/MM8/uf9Fezv700CFQcCLyCRgUBxW4rDhpn2l/aQF0i3NSVoObSbWt9XYWlnShyeYJHfknwyC0bPE44lVfwayApG0pKl8kvdNbpX/RBAXaFBvopNNBfIQ4/hQb6yW4zis//LOlTepfb4jbz1x+3lL6Wio/wlZk2TbPkuXR+yToqXE8l2yizHp08XcF6VG67ZddTvJKK6tNJ25Ghki//MoGgJAQ4PG3GScGifIhweMKC4Rm2r6hvxcsbctjt8vcrDiB+9sZ2LgWA+kS4Qb1zu03lF3qHkryTgoinvaD8CErpqEqhq+5GS/zthlcgKX72LwkqpW2nmHb4K6QkyAAAGhfCDWrNNE0dyivQD1n52pWVrx+y8pXx87GTDuUUP9cl70Dip5CSUBJWJqSUzisdVTk5wDj8bBwvBwCLItygSm63qf1HjpeEmDz9UBJkdmXlK+9E9YNLgJ9NoZ5AUjz6Ufo6rOwoSplAEloaVkpHTwL8OBcBAHBKhBt4FLnc+unwUa/w8kNWvnYfyteJoooPCdkMqUNEsDq1DFGXyBBFRwQrPMjfc+jm17DiJ4cfN/oDANQ/ws0Z6HihS7sPFYeWH7Lytet/+frhUL72ZB+V013x2bYBdpvObRmsTq1C1KVViDqXPKJbBBNaAACNis/DzaJFi/TXv/5VmZmZ6t69u+bNm6eBAwdW2v+1117TU089pV27dik8PFxXXXWVnn76aUVERDRg1U1DzvGi4pGXrOLwsut/efrhUL72/XK80l8MBQfY1blVSEmICfWEmHZnN+MXJwCAJsGn4WbFihWaOnWqFi1apAEDBuiFF17Q8OHDtX37drVv375c/3//+98aP368nnnmGY0cOVL79+9XQkKCbr/9dq1atcoHe+B7pmkqO7+w5FBSntfhpKy8yu8HcnaQv7q0Ci03EhMVHsiJtgCAJs0wzcr+hq9/ffv21cUXX6zk5GRP2/nnn6/Ro0crKSmpXP+nn35aycnJ2r17t6dtwYIFeuqpp7R3795qbTM3N1fh4eHKyclRWFjY6e9EA3G7TR3IOe45H6bseTE5xyu/90jrsEB1iQxRp5bF4aU0yESEOBqwegAATk9Nvr99NnJTWFioLVu26MEHH/RqHzp0qDZu3FjhMv3799dDDz2kNWvWaPjw4crKytLf//53XX311ZVup6CgQAUFv45g5Obm1s0O1BOny62ffj5WLsT8kJWv40UV39TPMKT2zYPUpeRwUueWIeoSGapOLYMVGsh9jwAAZxafhZvs7Gy5XC5FRkZ6tUdGRurgwYMVLtO/f3+99tprGjt2rE6cOCGn06lrr71WCxYsqHQ7SUlJmjlzZp3WXhdOFLn046Gj+uFQvn4oORfmh6x8pWcfrfRW9f52Q9EtgksOIZWcD9MyROe2DFagPyf1AgAgNYITik8+v8M0zUrP+di+fbsmT56shx9+WMOGDVNmZqbuv/9+JSQkaPHixRUuM336dCUmJnqmc3Nz1a5du7rbgSrknSj6dfTlUL5+KPllUsbPxyo9qTcowO45jFT20aF5ECf1AgBQBZ+FmxYtWshut5cbpcnKyio3mlMqKSlJAwYM0P333y9J6tmzp4KDgzVw4EDNnj1bUVFR5ZZxOBxyOOr//JITRS6l7T1S7lDSwdwTlS4T3szf62Te0keb8GZcqA4AgFryWbgJCAhQ7969lZqaquuuu87TnpqaqlGjRlW4zLFjx+Tn512y3V58OMaH50VLkg4cOa5xL/6nwnmRYQ7PIaTOkaHFz61C1CIkgF8mAQBQx3x6WCoxMVG33HKLYmNjFRcX
pxdffFEZGRlKSEiQVHxIaf/+/Vq2bJkkaeTIkbrjjjuUnJzsOSw1depUXXLJJWrTpo0vd0Xtmwfp3JbB6hgRXCbIFP9KKbwZJ/UCANBQfBpuxo4dq8OHD2vWrFnKzMxUjx49tGbNGnXo0EGSlJmZqYyMDE//+Ph45eXlaeHChbr33nt11llnafDgwXryySd9tQsefnab1t17ma/LAADgjOfT69z4QlO9zg0AAGeymnx/89MbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKYQbAABgKT4PN4sWLVJ0dLQCAwPVu3dvbdiw4ZT9CwoK9NBDD6lDhw5yOBzq1KmTUlJSGqhaAADQ2Pn5cuMrVqzQ1KlTtWjRIg0YMEAvvPCChg8fru3bt6t9+/YVLjNmzBj973//0+LFi9W5c2dlZWXJ6XQ2cOUAAKCxMkzTNH218b59++riiy9WcnKyp+3888/X6NGjlZSUVK7/+++/r3HjxunHH39U8+bNa7XN3NxchYeHKycnR2FhYbWuHQAANJyafH/77LBUYWGhtmzZoqFDh3q1Dx06VBs3bqxwmdWrVys2NlZPPfWUzjnnHHXt2lX33Xefjh8/Xul2CgoKlJub6/UAAADW5bPDUtnZ2XK5XIqMjPRqj4yM1MGDBytc5scff9S///1vBQYGatWqVcrOztakSZP0888/V3reTVJSkmbOnFnn9QMAgMbJ5ycUG4bhNW2aZrm2Um63W4Zh6LXXXtMll1yiESNGaO7cuVq6dGmlozfTp09XTk6O57F379463wcAANB4+GzkpkWLFrLb7eVGabKyssqN5pSKiorSOeeco/DwcE/b+eefL9M0tW/fPnXp0qXcMg6HQw6Ho26LBwAAjZbPRm4CAgLUu3dvpaamerWnpqaqf//+FS4zYMAAHThwQPn5+Z6277//XjabTW3btq3XegEAQNPg08NSiYmJevnll5WSkqIdO3bonnvuUUZGhhISEiQVH1IaP368p/+NN96oiIgI3Xrrrdq+fbvWr1+v+++/XxMnTlSzZs18tRsAAKAR8el1bsaOHavDhw9r1qxZyszMVI8ePbRmzRp16NBBkpSZmamMjAxP/5CQEKWmpuruu+9WbGysIiIiNGbMGM2ePdtXuwAAABoZn17nxhe4zg0AAE1PvV7nJj09vdaFAQAA1Lcah5vOnTvr8ssv1/Lly3XixIn6qAkAAKDWahxutm3bposuukj33nuvWrdurT/+8Y/64osv6qM2AACAGqtxuOnRo4fmzp2r/fv3a8mSJTp48KB+85vfqHv37po7d64OHTpUH3UCAABUS61/Cu7n56frrrtOb775pp588knt3r1b9913n9q2bavx48crMzOzLusEAACollqHm82bN2vSpEmKiorS3Llzdd9992n37t1at26d9u/fr1GjRtVlnQAAANVS4+vczJ07V0uWLNHOnTs1YsQILVu2TCNGjJDNVpyToqOj9cILL+i8886r82IBAACqUuNwk5ycrIkTJ+rWW29V69atK+zTvn17LV68
+LSLAwAAqCku4gcAABq9er2I35IlS/TWW2+Va3/rrbf0yiuv1HR1AAAAdarG4eaJJ55QixYtyrW3atVKjz/+eJ0UBQAAUFs1Djc//fSToqOjy7V36NDB6yaXAAAAvlDjcNOqVSt9/fXX5dq3bdumiIiIOikKAACgtmocbsaNG6fJkyfr448/lsvlksvl0rp16zRlyhSNGzeuPmoEAACothr/FHz27Nn66aefdMUVV8jPr3hxt9ut8ePHc84NAADwuVr/FPz777/Xtm3b1KxZM11wwQXq0KFDXddWL/gpOAAATU9Nvr9rPHJTqmvXruratWttFwcAAKgXtQo3+/bt0+rVq5WRkaHCwkKveXPnzq2TwgAAAGqjxuHmo48+0rXXXqvo6Gjt3LlTPXr00J49e2Sapi6++OL6qBEAAKDaavxrqenTp+vee+/Vt99+q8DAQK1cuVJ79+7VoEGDdP3119dHjQAAANVW43CzY8cOTZgwQZLk5+en48ePKyQkRLNmzdKTTz5Z5wUCAADURI3DTXBwsAoKCiRJbdq00e7duz3zsrOz664yAACAWqjxOTf9+vXTZ599ppiYGF199dW699579c033+jtt99Wv3796qNGAACAaqtxuJk7d67y8/MlSTNmzFB+fr5WrFihzp0765lnnqnzAgEAAGqiRuHG5XJp79696tmzpyQpKChIixYtqpfCAAAAaqNG59zY7XYNGzZMR44cqa96AAAATkuNTyi+4IIL9OOPP9ZHLQAAAKetxuHmscce03333ad//vOfyszMVG5urtcDAADAl2p840yb7dc8ZBiG57VpmjIMQy6Xq+6qqwfcOBMAgKanXm+c+fHHH9e6MAAAgPpW43AzaNCg+qgDAACgTtQ43Kxfv/6U8y+99NJaFwMAAHC6ahxuLrvssnJtZc+9aezn3AAAAGur8a+lfvnlF69HVlaW3n//ffXp00cffPBBfdQIAABQbTUeuQkPDy/XNmTIEDkcDt1zzz3asmVLnRQGAABQGzUeualMy5YttXPnzrpaHQAAQK3UeOTm66+/9po2TVOZmZl64okn1KtXrzorDAAAoDZqHG4uvPBCGYahk6/9169fP6WkpNRZYQAAALVR43CTnp7uNW2z2dSyZUsFBgbWWVEAAAC1VeNw06FDh/qoAwAAoE7U+ITiyZMna/78+eXaFy5cqKlTp9ZJUQAAALVV43CzcuVKDRgwoFx7//799fe//71OigIAAKitGoebw4cPV3itm7CwMGVnZ9dJUQAAALVV43DTuXNnvf/+++Xa33vvPZ177rl1UhQAAEBt1fiE4sTERN111106dOiQBg8eLEn66KOPNGfOHM2bN6/OCwQAAKiJGoebiRMnqqCgQI899pgeffRRSVLHjh2VnJys8ePH13mBAAAANWGYJ1+NrwYOHTqkZs2aKSQkpC5rqle5ubkKDw9XTk6OwsLCfF0OAACohpp8f9fqIn5Op1NdunRRy5YtPe27du2Sv7+/OnbsWOOCAQAA6kqNTyiOj4/Xxo0by7V//vnnio+Pr4uaAAAAaq3G4Wbr1q0VXuemX79+SktLq5OiAAAAaqvG4cYwDOXl5ZVrz8nJkcvlqpOiAAAAaqvG4WbgwIFKSkryCjIul0tJSUn6zW9+U6fFAQAA1FSNTyh+6qmndOmll6pbt24aOHCgJGnDhg3KycnRxx9/XOcFAgAA1ESNR25iYmL09ddfa8yYMcrKylJeXp7Gjx+v77//Xk6nsz5qBAAAqLbTus6NJB05ckSvvfaaUlJSlJaW1ujPu+E6NwAAND01+f6u8chNqXXr1unmm29WmzZttHDhQg0fPlybN2+u7eoAAADqRI3Oudm3b5+WLl2qlJQUHT16VGPGjFFRUZFWrlypmJiY+qoRAACg2qo9cjNixAjFxMRo+/btWrBggQ4cOKAFCxbUZ20AAAA1Vu2Rmw8++ECTJ0/Wn/70J3Xp0qU+awIAAKi1ao/cbNiwQXl5
eYqNjVXfvn21cOFCHTp0qD5rAwAAqLFqh5u4uDi99NJLyszM1B//+Ee98cYbOuecc+R2u5WamlrhVYsBAAAa2mn9FHznzp1avHixXn31VR05ckRDhgzR6tWr67K+OsdPwQEAaHoa5KfgktStWzc99dRT2rdvn15//fXTWRUAAECdOK1wU8put2v06NG1GrVZtGiRoqOjFRgYqN69e2vDhg3VWu6zzz6Tn5+fLrzwwhpvEwAAWFedhJvaWrFihaZOnaqHHnpIW7du1cCBAzV8+HBlZGSccrmcnByNHz9eV1xxRQNVCgAAmorTvv3C6ejbt68uvvhiJScne9rOP/98jR49WklJSZUuN27cOHXp0kV2u13vvPOO0tLSqr1NzrkBAKDpabBzbk5HYWGhtmzZoqFDh3q1Dx06VBs3bqx0uSVLlmj37t165JFHqrWdgoIC5ebmej0AAIB1+SzcZGdny+VyKTIy0qs9MjJSBw8erHCZXbt26cEHH9Rrr70mP7/qXX8wKSlJ4eHhnke7du1Ou3YAANB4+fScG0kyDMNr2jTNcm2S5HK5dOONN2rmzJnq2rVrtdc/ffp05eTkeB579+497ZoBAEDjVaMbZ9alFi1ayG63lxulycrKKjeaI0l5eXnavHmztm7dqrvuukuS5Ha7ZZqm/Pz89MEHH2jw4MHllnM4HHI4HPWzEwAAoNHx2chNQECAevfurdTUVK/21NRU9e/fv1z/sLAwffPNN0pLS/M8EhIS1K1bN6Wlpalv374NVToAAGjEfDZyI0mJiYm65ZZbFBsbq7i4OL344ovKyMhQQkKCpOJDSvv379eyZctks9nUo0cPr+VbtWqlwMDAcu0AAODM5dNwM3bsWB0+fFizZs1SZmamevTooTVr1qhDhw6SpMzMzCqveQMAAFCWT69z4wtc5wYAgKanSVznBgAAoD4QbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKX4PNwsWrRI0dHRCgwMVO/evbVhw4ZK+7799tsaMmSIWrZsqbCwMMXFxWnt2rUNWC0AAGjsfBpuVqxYoalTp+qhhx7S1q1bNXDgQA0fPlwZGRkV9l+/fr2GDBmiNWvWaMuWLbr88ss1cuRIbd26tYErBwAAjZVhmqbpq4337dtXF198sZKTkz1t559/vkaPHq2kpKRqraN79+4aO3asHn744Wr1z83NVXh4uHJychQWFlarugEAQMOqyfe3z0ZuCgsLtWXLFg0dOtSrfejQodq4cWO11uF2u5WXl6fmzZtX2qegoEC5ubleDwAAYF0+CzfZ2dlyuVyKjIz0ao+MjNTBgwertY45c+bo6NGjGjNmTKV9kpKSFB4e7nm0a9futOoGAACNm89PKDYMw2vaNM1ybRV5/fXXNWPGDK1YsUKtWrWqtN/06dOVk5Pjeezdu/e0awYAAI2Xn6823KJFC9nt9nKjNFlZWeVGc062YsUK3XbbbXrrrbd05ZVXnrKvw+GQw+E47XoBAEDT4LORm4CAAPXu3Vupqale7ampqerfv3+l
y73++uuKj4/X3/72N1199dX1XSYAAGhifDZyI0mJiYm65ZZbFBsbq7i4OL344ovKyMhQQkKCpOJDSvv379eyZcskFQeb8ePH69lnn1W/fv08oz7NmjVTeHi4z/YDAAA0Hj4NN2PHjtXhw4c1a9YsZWZmqkePHlqzZo06dOggScrMzPS65s0LL7wgp9OpO++8U3feeaenfcKECVq6dGlDlw8AABohn17nxhe4zg0AAE1Pk7jODQAAQH0g3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEvxebhZtGiRoqOjFRgYqN69e2vDhg2n7P/pp5+qd+/eCgwM1Lnnnqvnn3++gSoFAABNgU/DzYoVKzR16lQ99NBD2rp1qwYOHKjhw4crIyOjwv7p6ekaMWKEBg4cqK1bt+rPf/6zJk+erJUrVzZw5QAAoLEyTNM0fbXxvn376uKLL1ZycrKn7fzzz9fo0aOVlJRUrv8DDzyg1atXa8eOHZ62hIQEbdu2TZs2barWNnNzcxUeHq6cnByFhYWd/k4AAIB6V5Pvb5+N3BQWFmrLli0aOnSoV/vQoUO1cePGCpfZtGlTuf7Dhg3T5s2bVVRUVG+1AgCApsPPVxvOzs6Wy+VSZGSkV3tkZKQOHjxY4TIHDx6ssL/T6VR2draioqLKLVNQUKCCggLPdE5OjqTiBAgAAJqG0u/t6hxw8lm4KWUYhte0aZrl2qrqX1F7qaSkJM2cObNce7t27WpaKgAA8LG8vDyFh4efso/Pwk2LFi1kt9vLjdJkZWWVG50p1bp16wr7+/n5KSIiosJlpk+frsTERM+02+3Wzz//rIiIiFOGqNrIzc1Vu3bttHfv3jPyfJ4zff8l3oMzff8l3gP2/8zef6n+3gPTNJWXl6c2bdpU2ddn4SYgIEC9e/dWamqqrrvuOk97amqqRo0aVeEycXFxevfdd73aPvjgA8XGxsrf37/CZRwOhxwOh1fbWWeddZrVn1pYWNgZ+6GW2H+J9+BM33+J94D9P7P3X6qf96CqEZtSPv0peGJiol5++WWlpKRox44duueee5SRkaGEhARJxaMu48eP9/RPSEjQTz/9pMTERO3YsUMpKSlavHix7rvvPl/tAgAAaGR8es7N2LFjdfjwYc2aNUuZmZnq0aOH1qxZow4dOkiSMjMzva55Ex0drTVr1uiee+7Rc889pzZt2mj+/Pn63e9+56tdAAAAjYzPTyieNGmSJk2aVOG8pUuXlmsbNGiQvvrqq3quqnYcDoceeeSRcofBzhRn+v5LvAdn+v5LvAfs/5m9/1LjeA98ehE/AACAuubze0sBAADUJcINAACwFMINAACwFMINAACwFMJNHVm0aJGio6MVGBio3r17a8OGDb4uqcGsX79eI0eOVJs2bWQYht555x1fl9SgkpKS1KdPH4WGhqpVq1YaPXq0du7c6euyGlRycrJ69uzpuWhXXFyc3nvvPV+X5TNJSUkyDENTp071dSkNZsaMGTIMw+vRunVrX5fVoPbv36+bb75ZERERCgoK0oUXXqgtW7b4
uqwG07Fjx3KfAcMwdOeddzZ4LYSbOrBixQpNnTpVDz30kLZu3aqBAwdq+PDhXtfosbKjR4+qV69eZbTFtQAACPtJREFUWrhwoa9L8YlPP/1Ud955p/7zn/8oNTVVTqdTQ4cO1dGjR31dWoNp27atnnjiCW3evFmbN2/W4MGDNWrUKP33v//1dWkN7ssvv9SLL76onj17+rqUBte9e3dlZmZ6Ht98842vS2owv/zyiwYMGCB/f3+999572r59u+bMmVPvV8RvTL788kuvf//U1FRJ0vXXX9/wxZg4bZdccomZkJDg1XbeeeeZDz74oI8q8h1J5qpVq3xdhk9lZWWZksxPP/3U16X41Nlnn22+/PLLvi6jQeXl5ZldunQxU1NTzUGDBplTpkzxdUkN5pFHHjF79erl6zJ85oEHHjB/85vf+LqMRmXKlClmp06dTLfb3eDbZuTmNBUWFmrLli0aOnSoV/vQoUO1ceNGH1UFX8rJyZEkNW/e3MeV+IbL5dIbb7yho0ePKi4uztflNKg777xTV199ta688kpfl+ITu3btUps2bRQdHa1x48bpxx9/9HVJDWb16tWKjY3V9ddfr1atWumiiy7SSy+95OuyfKawsFDLly/XxIkT6/wm1dVBuDlN2dnZcrlc5e5kHhkZWe4O5rA+0zSVmJio3/zmN+rRo4evy2lQ33zzjUJCQuRwOJSQkKBVq1YpJibG12U1mDfeeENfffWVkpKSfF2KT/Tt21fLli3T2rVr9dJLL+ngwYPq37+/Dh8+7OvSGsSPP/6o5ORkdenSRWvXrlVCQoImT56sZcuW+bo0n3jnnXd05MgRxcfH+2T7Pr/9glWcnExN0/RJWoVv3XXXXfr666/173//29elNLhu3bopLS1NR44c0cqVKzVhwgT9f3v3F9JU38AB/Ds3t7YxYrrMjdIkK7NMykVMu6ndbEVQGUYsWUTISockXvWHLMLuioIYDGx0YQiD/iwiNcu8EMIIVkOWFUQFISu6yBXtov2ei3jHO/a8z/s+72Pnp8fvBw6cnbM/3x948eV3fsczNja2IArOhw8f0NnZieHhYSxatEh2HCm8Xm9uv66uDi6XCytXrsT169fR1dUlMZkystksnE4nent7AQAbN27E5OQkQqFQ3gOgF4q+vj54vV44HA4pv8+Zm3/IZrNBq9UWzNKkUqmC2RxSt2AwiFgshtHRUSxbtkx2HMXp9XpUV1fD6XTiwoULqK+vx+XLl2XHUsSzZ8+QSqXQ0NAAnU4HnU6HsbExXLlyBTqdDj9//pQdUXFmsxl1dXV4/fq17CiKsNvtBUV+7dq1C+bGkn/37t07jIyM4MiRI9IysNz8Q3q9Hg0NDblV4f/y4MEDNDY2SkpFShJCoKOjAzdv3sSjR49QVVUlO9KcIIRAJpORHUMRbrcbiUQC8Xg8tzmdTvh8PsTjcWi1WtkRFZfJZJBMJmG322VHUURTU1PBv4B49eoVKisrJSWSJxKJoKysDDt37pSWgZelZkFXVxdaW1vhdDrhcrkQDofx/v17BAIB2dEUkU6n8ebNm9zrt2/fIh6Po6SkBBUVFRKTKaO9vR03btzAnTt3YLFYcrN4ixcvhtFolJxOGSdOnIDX68Xy5csxMzODgYEBPH78GIODg7KjKcJisRSssTKbzSgtLV0wa6+6u7uxa9cuVFRUIJVK4fz58/j69Sv8fr/saIo4fvw4Ghsb0dvbi5aWFkxMTCAcDiMcDsuOpqhsNotIJAK/3w+dTmLFUPz+LJW6evWqqKysFHq9XmzatGlB3QY8OjoqABRsfr9fdjRF/NnYAYhIJCI7mmIOHz6c+/tfsmSJcLvdYnh4WHYsqRbareD79+8XdrtdFBcXC4fDIfbu3SsmJydlx1LU3bt3xfr164XBYBA1NTUiHA7LjqS4oaEhAUBMTU1JzaERQgg5tYqIiIho9nHNDREREakKyw0RERGpCssNERERqQrLDREREakKyw0RERGp
CssNERERqQrLDREREakKyw0REX49/Pb27duyYxDRLGC5ISLpDh06BI1GU7B5PB7Z0YhoHuKzpYhoTvB4PIhEInnHDAaDpDRENJ9x5oaI5gSDwYDy8vK8zWq1Avh1ySgUCsHr9cJoNKKqqgrRaDTv84lEAtu3b4fRaERpaSna2tqQTqfz3nPt2jWsW7cOBoMBdrsdHR0deec/f/6MPXv2wGQyYdWqVYjFYr930ET0W7DcENG8cPr0aTQ3N+P58+c4ePAgDhw4gGQyCQD4/v07PB4PrFYrnj59img0ipGRkbzyEgqF0N7ejra2NiQSCcRiMVRXV+f9xtmzZ9HS0oIXL15gx44d8Pl8+PLli6LjJKJZIPWxnUREQgi/3y+0Wq0wm81527lz54QQv568HggE8j6zZcsWcfToUSGEEOFwWFitVpFOp3Pn7927J4qKisT09LQQQgiHwyFOnjz5HzMAEKdOncq9TqfTQqPRiPv378/aOIlIGVxzQ0RzwrZt2xAKhfKOlZSU5PZdLlfeOZfLhXg8DgBIJpOor6+H2WzOnW9qakI2m8XU1BQ0Gg0+fvwIt9v9lxk2bNiQ2zebzbBYLEilUv/3mIhIDpYbIpoTzGZzwWWi/0aj0QAAhBC5/T97j9Fo/J++r7i4uOCz2Wz2b2UiIvm45oaI5oUnT54UvK6pqQEA1NbWIh6P49u3b7nz4+PjKCoqwurVq2GxWLBixQo8fPhQ0cxEJAdnbohoTshkMpiens47ptPpYLPZAADRaBROpxNbt25Ff38/JiYm0NfXBwDw+Xw4c+YM/H4/enp68OnTJwSDQbS2tmLp0qUAgJ6eHgQCAZSVlcHr9WJmZgbj4+MIBoPKDpSIfjuWGyKaEwYHB2G32/OOrVmzBi9fvgTw606mgYEBHDt2DOXl5ejv70dtbS0AwGQyYWhoCJ2dndi8eTNMJhOam5tx8eLF3Hf5/X78+PEDly5dQnd3N2w2G/bt26fcAIlIMRohhJAdgojor2g0Gty6dQu7d++WHYWI5gGuuSEiIiJVYbkhIiIiVeGaGyKa83j1nIj+Ds7cEBERkaqw3BAREZGqsNwQERGRqrDcEBERkaqw3BAREZGqsNwQERGRqrDcEBERkaqw3BAREZGqsNwQERGRqvwBHswZvSIlO3YAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Overlay the train and validation accuracy curves on a single figure.\n",
    "for _curve, _curve_label in (\n",
    "  (train_a, 'train accuracy'),\n",
    "  (val_a, 'validation accuracy'),\n",
    "):\n",
    "  plt.plot(_curve, label=_curve_label)\n",
    "\n",
    "plt.title('Training history')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Accuracy')\n",
    "# Fix the y-axis to the full [0, 1] accuracy range.\n",
    "plt.ylim(0, 1)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "Accuracy and macro-F1 of positive/negative sentiment classification on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 28664,
     "status": "ok",
     "timestamp": 1695329081374,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "56411273-fc8c-476f-f660-21239bd321ad",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.8307692307692307\n",
      "F1-Macro:  0.8086956989232211\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the model on the test split. eval_model returns three values;\n",
    "# the middle one is not needed here and is discarded.\n",
    "test_acc, _, test_f1 = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "# .item() turns the single-element results into plain Python numbers for printing.\n",
    "print(\"Accuracy: \",test_acc.item())\n",
    "print(\"F1-Macro: \",test_f1.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695329212064,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "d7DaFLMiAWps",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  \"\"\"Run `model` over `data_loader` and collect per-sample results.\n",
    "\n",
    "  Relies on the notebook-level `device` variable for tensor placement.\n",
    "\n",
    "  Args:\n",
    "    model: called as `model(input_ids=..., attention_mask=...,\n",
    "      return_scores=True)` and expected to return a `(scores, outputs)` pair.\n",
    "    data_loader: iterable of batch dicts providing the keys `review`,\n",
    "      `input_ids`, `attention_mask` and `sentiments`.\n",
    "\n",
    "  Returns:\n",
    "    Tuple `(review, predictions, prediction_probs, real_values, pos_scores,\n",
    "    neg_scores)`: the collected `review` entries as a list, followed by five\n",
    "    CPU tensors covering every sample seen in `data_loader`.\n",
    "  \"\"\"\n",
    "  model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "  pos_scores = []\n",
    "  neg_scores = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "\n",
    "      scores, outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_scores=True,\n",
    "      )\n",
    "      # Predicted class = index of the largest value along dim 1.\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(d[\"review\"])\n",
    "      # Store one CPU tensor per batch (not one device tensor per sample) so\n",
    "      # accelerator memory is released as the loop advances.\n",
    "      predictions.append(preds.cpu())\n",
    "      prediction_probs.append(probs.cpu())\n",
    "      real_values.append(d[\"sentiments\"].cpu())\n",
    "\n",
    "      # NOTE(review): the container type of scores[0] / scores[1] is not visible\n",
    "      # in this cell, so they are still collected sample by sample.\n",
    "      pos_scores.extend(scores[0])\n",
    "      neg_scores.extend(scores[1])\n",
    "\n",
    "  # Concatenating the per-batch tensors along dim 0 yields the same layout as\n",
    "  # stacking the individual samples.\n",
    "  predictions = torch.cat(predictions)\n",
    "  prediction_probs = torch.cat(prediction_probs)\n",
    "  real_values = torch.cat(real_values)\n",
    "  pos_scores = torch.stack(pos_scores).cpu()\n",
    "  neg_scores = torch.stack(neg_scores).cpu()\n",
    "  return review, predictions, prediction_probs, real_values, pos_scores, neg_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "executionInfo": {
     "elapsed": 27254,
     "status": "ok",
     "timestamp": 1695329239311,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "kZhYtki1AYx8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Collect texts, predictions, probabilities, labels and scores for the test loader.\n",
    "(\n",
    "  y_review_texts,\n",
    "  y_pred,\n",
    "  y_pred_probs,\n",
    "  y_test,\n",
    "  pos_scores,\n",
    "  neg_scores,\n",
    ") = get_predictions(model, test_data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "mO1FHn0GAbCU",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix\n",
    "class_names = ['negative', 'positive']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
    "Positive/negative sentiment classification report (test set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "4d6f7fff-9a14-49bf-c847-d2e8ba79f5f7",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.86      0.88      0.87      1603\n",
      "    positive       0.79      0.76      0.77       932\n",
      "\n",
      "    accuracy                           0.83      2535\n",
      "   macro avg       0.82      0.82      0.82      2535\n",
      "weighted avg       0.83      0.83      0.83      2535\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-class precision / recall / F1 for the test-set predictions.\n",
    "sentiment_report = classification_report(\n",
    "  y_test,\n",
    "  y_pred,\n",
    "  target_names=class_names,\n",
    ")\n",
    "print(sentiment_report)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
