{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695324378707,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import RobertaTokenizer, RobertaModel, RobertaConfig\n",
    "\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "import textwrap\n",
    "import math\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695323483676,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/IMDB/train.csv')\n",
    "df_val = pd.read_csv('/home/m_nsu/ICLR/Datasets/IMDB/val.csv')\n",
    "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/IMDB/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1695323754433,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "vX1DfZpN1W01",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train.dropna(inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 49,
     "status": "ok",
     "timestamp": 1695323754437,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "f1bb3869-364e-4c7a-8c3f-a62ce4b182c7",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>review</th>\n",
       "      <th>sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>48539</td>\n",
       "      <td>This was a very funny movie not Oscarworthy bu...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>26308</td>\n",
       "      <td>I know this film has had a fairly rough ride f...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>14603</td>\n",
       "      <td>Being stuck in bed with the flu and feeling to...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>45794</td>\n",
       "      <td>This isnt exactly a great film but I admire th...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>21688</td>\n",
       "      <td>It is difficult to rate a writerdirectors firs...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0                                             review  sentiment\n",
       "0       48539  This was a very funny movie not Oscarworthy bu...          0\n",
       "1       26308  I know this film has had a fairly rough ride f...          0\n",
       "2       14603  Being stuck in bed with the flu and feeling to...          1\n",
       "3       45794  This isnt exactly a great film but I admire th...          0\n",
       "4       21688  It is difficult to rate a writerdirectors firs...          1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756839,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "device = torch.device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695323756840,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "cb12350b-753a-4c61-c7e8-3c9427f000e9",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 'roberta-large'\n",
    "config = RobertaConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "tokenizer = RobertaTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770760,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, reviews, sentiments, tokenizer, max_len):\n",
    "    self.reviews = reviews\n",
    "    self.sentiments = sentiments\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.reviews)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    review = str(self.reviews[item])\n",
    "    sentiment = self.sentiments[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      review,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'review': review,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'sentiments': torch.tensor(sentiment, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325189268,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    reviews=df.review.to_numpy(),\n",
    "    sentiments=df['sentiment'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695323770761,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695325187523,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "4cc31257-8fe1-4cce-883e-8aff4ebf2185",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['review', 'input_ids', 'attention_mask', 'sentiments'])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = next(iter(train_data_loader))\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695325688712,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "  def __init__(self):\n",
    "    super(Classifier, self).__init__()\n",
    "    self.bert = RobertaModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "    self.FC = nn.Linear(config.hidden_size,2, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "    \n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "id": "HWZ37gsztWzL",
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = Classifier()\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695325221207,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c0ba4f97-cad5-46b8-e1ff-e818ff5b4035",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "sentiments = data['sentiments'].to(device)\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1695325221208,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "executionInfo": {
     "elapsed": 603,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs = model(input_ids, attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1695325221800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NHFM3QUhg3n0",
    "outputId": "602bd955-db78-44a6-931e-e5e901bb1856",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.7528,  0.6537],\n",
       "        [-0.7695,  0.5806],\n",
       "        [-0.7913,  0.5575],\n",
       "        [-0.7840,  0.5914],\n",
       "        [-0.7829,  0.5687],\n",
       "        [-0.8008,  0.6121],\n",
       "        [-0.7681,  0.6100],\n",
       "        [-0.7864,  0.5589],\n",
       "        [-0.7426,  0.5783],\n",
       "        [-0.7371,  0.5387],\n",
       "        [-0.7932,  0.5525],\n",
       "        [-0.7728,  0.5873],\n",
       "        [-0.8216,  0.5862],\n",
       "        [-0.7651,  0.5535],\n",
       "        [-0.8383,  0.5820],\n",
       "        [-0.7597,  0.6138],\n",
       "        [-0.8121,  0.5654],\n",
       "        [-0.8729,  0.5414],\n",
       "        [-0.8110,  0.5213],\n",
       "        [-0.8057,  0.5856],\n",
       "        [-0.7508,  0.5852],\n",
       "        [-0.8092,  0.5742],\n",
       "        [-0.7899,  0.6172],\n",
       "        [-0.7634,  0.6138],\n",
       "        [-0.7900,  0.5054],\n",
       "        [-0.7907,  0.6035],\n",
       "        [-0.7974,  0.5155],\n",
       "        [-0.7545,  0.5729],\n",
       "        [-0.7413,  0.5482],\n",
       "        [-0.7482,  0.6084],\n",
       "        [-0.7687,  0.5188],\n",
       "        [-0.7727,  0.5884]], device='cuda:0', grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1022,
     "status": "ok",
     "timestamp": 1695325845248,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf",
    "outputId": "bae8764e-0ce8-4118-e676-267b2c3ac06b",
    "tags": []
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695325845969,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == sentiments)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695325845971,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  f1s = []\n",
    "\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, sentiments)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == sentiments)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "      f1s.append(f1_score(sentiments.cpu(), preds.cpu(), average='macro'))\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses), np.mean(f1s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "executionInfo": {
     "elapsed": 3203929,
     "status": "ok",
     "timestamp": 1695329051836,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "ee79ec9a-a033-47e0-bc8b-a7b71698465c",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 0.5224748132705689 accuracy 0.7841750000000001\n",
      "Val   loss 0.3476291074874295 accuracy 0.8688\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 0.297849018996954 accuracy 0.88365\n",
      "Val   loss 0.28951823398186144 accuracy 0.8814000000000001\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.26300194618701933 accuracy 0.894\n",
      "Val   loss 0.274369531162795 accuracy 0.8874000000000001\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.254089415243268 accuracy 0.8965000000000001\n",
      "Val   loss 0.2619149535894394 accuracy 0.8918\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.2481885947406292 accuracy 0.900225\n",
      "Val   loss 0.25632380136543775 accuracy 0.8962\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.24490812806487083 accuracy 0.9011750000000001\n",
      "Val   loss 0.25351111716620484 accuracy 0.8976000000000001\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.24209747795164585 accuracy 0.90375\n",
      "Val   loss 0.2520762600574144 accuracy 0.8986000000000001\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.2414363041251898 accuracy 0.9029\n",
      "Val   loss 0.25222110926250746 accuracy 0.8986000000000001\n",
      "\n",
      "CPU times: user 2h 32min 47s, sys: 23.9 s, total: 2h 33min 10s\n",
      "Wall time: 2h 33min 16s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss, val_f1 = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_bert_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "executionInfo": {
     "elapsed": 51,
     "status": "ok",
     "timestamp": 1695329137842,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "FowMSU5U7SDQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 480
    },
    "executionInfo": {
     "elapsed": 2888,
     "status": "ok",
     "timestamp": 1695329143450,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "b103ca88-1886-4f16-ea21-5b966a08fe7c",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHFCAYAAAAOmtghAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOzdeXwU9f3H8ffs5oRAuEO4I5cE5AxCoJFDAUEQrBXQKkTUFtECIlIpvyogGrSCCEi0LQERq0hBizaKqaBQoXJI8CAcxWA4QiNXEhBy7M7vjyRLNgnkIMkmw+v5cB/JfOc7M58ZeLhvvvPdWcM0TVMAAAAWYfN0AQAAAOWJcAMAACyFcAMAACyFcAMAACyFcAMAACyFcAMAACyFcAMAACyFcAMAACyFcAMAACyFcANYiGEYJXp9/vnn5XK8S5cuyTAMzZ8/v0zb9+7dW7fffnu51FJajRs31q9+9ati+33yyScyDEP/+c9/SrX/xYsXa/Xq1WUtD8A18PJ0AQDKz/bt292Wn3vuOW3evFmbNm1yaw8NDS2X4/n6+mr79u1q0aJFmbZfvny57HZ7udRSUcLDw7V9+3Z16tSpVNstXrxYbdq00f33319BlQG4EsINYCG9e/d2W27YsKFsNluh9ivJzMyU3W4vceAwDKPE+y5Kx44dy7xtZQkMDLymcyxPDodDDodDPj4+ni4FqNK4LQVcp/Jut6xZs0aTJ09WcHCw/Pz8dPToUSUnJ2vixInq0KGDatasqaCgIN12222FRoaKui31+uuvyzAMffnll3rkkUdUv359NWjQQPfcc4/+97//uW1f8LbU/v37ZRiGlixZohdffFEtW7ZUQECA+vbtq927dxc6h2XLlqlNmzby9fXVTTfdpLVr12rs2LG68cYbS3wdPvzwQ3Xt2lX+/v4KDQ0tdCupqNtSBw8e1D333KPg4GD5+vqqcePGGjRokL7//ntJObe8Dh8+rI0bN7puBeavKTExUffee68aNmwoX19fhYaGavHixcr/PcZ512LRokWaPXu2WrVqJR8fH8XGxiogIEBTpkwpdC4HDhyQzWbTkiVLSnz+gBUxcgNc55588kndcsst+utf/yqn06m6desqKSlJ3t7emjNnjoKCgpSenq61a9cqIiJCW7duVXh4eLH7HT9+vO6880698847SkxM1IwZM/Tggw8qNja22G0XLlyom266SUuWLJHD4dCsWbM0dOhQJSYmqmbNmpJybvtMmTJFY8eO1eLFi3XmzBnNnDlTWVlZ8vf3L9G579y5UwcOHNDTTz+tBg0aKDo6Wg888IDatWunm2++uchtTNPU7bffLl9fX7388stq3ry5Tp06pa1bt+rcuXOSpNjYWI0cOVLNmjXTK6+8IkmumpKTkxUeHi7DMBQVFaVmzZrpgw8+0JQpU3TkyBEtXLjQ7Xgvv/yyQkNDtXDhQgUEBCg0NFTjxo3Tm2++qRdeeMF1PSTptddeU82aNTV+/PgSnT9gWSYAyxo/frxZs2bNItd9/PHHpiRz8ODBxe4nOzvbzMrKMvv27Wvee++9rvaLFy+aksyoqChXW3R0tCnJnDZtmts+5s6da0oyz5w542rr1auXOWTIENdyQkKCKckMCwsznU6nq33Lli2mJPP99983TdM0MzMzzfr165v9+vVzO8Z///tf0263m+3bty/2nIKCgsyaNWuaJ06ccLWdP3/erFWrljllyhRXW9512r59u2mapnns2DFTkvn6669fdf+tW7d2O7c8U6dONQ3DMOPj493aH3zwQdNms5mJiYlu16JDhw5mdna2W9+EhATTMAwzOjra1Zaenm7Wrl3bfPTRR4s9d8DquC0FXOfuvvvuQm2maWrJkiXq1q2b/Pz85OXlJW9vb3355ZdKSEgo0X7vvPNOt+XOnTtLkpKSkorddvjw4TIMo9C2P/74oyTpu+++0+nTpzV69Gi37Vq3bq2ePXuWqD5J6tmzp4KDg13LNWvWVOvWrV3HKUrjxo3VokULvfDCC3r11Ve1d+9eOZ3OEh9z06ZN6tatm7p06eLWHhkZKafTWeiTbKNGjSo0B+rGG2/UoEGD9Nprr7naVq1apbS0ND322GMlrgWwKsINcJ3L/+aeJyoqSpMnT1ZERITWr1+vr776Sjt37tTAgQN18eLFEu23fv36bsu+vr6SVKLti9v29OnTkqSgoKBC2xbVVtLj5B3rajXa7XZt3rxZAwYM0PPPP6+uXbsqKChI06ZN04ULF4o95unTp4u85k2aNHGtz6+ovpI0ZcoUfffdd9qyZYuknFtSAwYMqBaTtIGKxpwb4DqXf4Qkz+rVq3X77bdr8eLFbu2pqamVVdZV5YWSghOUJenkyZMVfvwbbrhBK1eulJQz8ffdd9/Vc889J6fTqUWLFl112/r16ys5OblQ+4kTJyRJDRo0cGsv6s9HkoYOHaq2bdtq6dKlys7O1r59+zR37twynA1gPYzcACjEMAzXaEmeXbt26euvv/ZQRe46deqkevXqac2aNW7thw8f1q5duyq1lhtvvFGzZ89Wu3bt3K7PlUaAbr31VsXHx7s+WZVn1apVstls6t+/f4mOaxiGfve73+n999/Xs88+q2bNmmnUqFHXdC6AVRBuABQyfPhwffjhh5o3b542bdqkpUuX6o477lCrVq08XZokydvbW88++6y2bNmie++9Vx9//LFWr16tIUOGqEmTJrLZKu5/bTt27NCAAQP02muvaePGjdq0aZOefvppHThwQIMGDXL1u+mmm7Rr1y79/e9/165du1xh5qmnnlLDhg01ZMgQxcTEaOPGjXrssce0fPlyTZ06VS1btixxLZGRkapRo4b+/e9/a+LEiVX+gYhAZeG2FIBCZs+erczMTC1btkzPP/+8OnXqpBUrVmjVqlWKj4/3dHmSpMmTJ8tut2vhwoVav369brjhBs2ZM0dvvfWW0tLSKuy4zZo1U4sWLbRkyRIdO3ZMNptNrVu31uLFizVp0iRXv+eff16nTp3Sgw8+qPPnz6t9+/bav3+/goODtX37ds2cOVNPPfWU0tPT1bp1ay1atEiTJ08uVS21atXSsGHDtH79ej3yyCPlfapAtWWYZr6nRgFANXb69Gm1bdtW999/f6H5QlZ08eJFtWjRQkOHDtWqVas8XQ5QZTByA6BaSkpK0sKFC9WvXz/Vq1dPiYmJWrBggTIyMvS73/3O0+VVqJSUFB08eFBvvPGGzp49qxkzZni6JKBKIdwAqJb8/Px06NAhvfPOOzpz5owCAgLUp08frVy5Um3btvV0eRVq/fr1evTRR9W0aVP95S9/KfWXegJWx20pAABgKR79tNSWLVs0YsQINWnSRIZh6IMPPih2my+++EI9evSQn5+fbrjhBr3++uuVUCkAAKguPBpuLly4oC5dumjp0qUl6p+YmKhhw4YpIiJCe/bs0R/+8AdNnjxZ69atq+BKAQBAdVFlbksZhqH333//qg+h+v3vf68NGza4fbfNxIkTtXfvXm3fvr0yygQAAFVctZpQvH37dg0ePNitbciQIVq+fLmysrLk7e1daJuMjAxlZGS4lp1Op86cOaP69etf8bHmAACgajFNU+np6SV6UGe1CjcnT54s9KV4QUFBys7O1qlTp674BYBz5syprBIBAEAFOnr0qJo1a3bVPtUq3EiFv0Qu767alUZhZs6cqWnTprmWU1NT1aJFCx09elS1a9euuEIBAEC5SUtLU/PmzVWrVq1i+1arcNO4ceNC3/ibkpIiLy8v17cEF+Tr61voCwAlqXbt2oQbAACqmZJMKalWX5wZHh6uuLg4t7ZPP/1UYWFhRc63AQAA1x+Phpvz588rPj7e9UV8iYmJio+PV1JSkqScW0rjxo1z9Z84caJ+/PFHTZs2TQkJCYqJidHy5cs1ffp0j9QPAACqHo/eltq1a5cGDBjgWs6bGzN+/HitXLlSycnJrqAjSSEhIYqNjdUTTzyh1157TU2aNNHixYt19913V3rtAACgaqoyz7mpLGlpaQoMDFRqaipzbgAAqCZK8/5drebcAAAAFIdwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALMXL0wUAAFDZnE5T2U5TDqcph2nK4TCV7XTKkb+9wO85y84Cy0X0M01Jks2QDBkyDMmQZBjuv9sM5S7nNNoMI3ddvu0K7MNmu9xHytvH5TZbzgrXdjZbwRry98/5acutS/n2cbUaXLXn1m1coXY/b3sl/om6I9wAQDkwTVNZuW+QTjNn2WlKMiWnacrU5TZTpkxTOS+Zrv5FtTlNSbn987Z1Oq+yD13eV/42Z05Ht+M7C/Qvsi1327x1Oed6uc1935eP5XCaynZcfuN3mnnLzpwg4AoUBQPG5T6Xtym4j3whI19bweXL2xQILbm1o+I0quWrHbNu89jxCTcAqgwz900xy2Eqy+lUVnbOm1xm7s8shzP3lfOGleXIact2OpWZnfOGl5W/3WEWWHYq07WtU1lO8/IxctdnOS4fJ9th5utTeF1mvmNkO6+Hd0tThkzZXC9nvmWnbK71eb9Lhpw5owP5fxqF243cbfP2l/9nwfa87eySvF37MWUYV96H8rcZ+fu612YzJC9DstskuyHZDTN32ZTdkLxspmzKazdkN0zZXP1z+hj5rldeiMrNqHLqckDM7aHc/yQpNxCbrm2K6pfXNy+45jWYrvX59+leQ14YdtXgduzLf4cvH9v9GCVld9aURLgBkE/ev9iLGgJ3G06/wrB5/n/hOszcPo7cfx0X2o9TDqcu/4v6SsfK969xp+k+PO+2ncN0CwKuEOB0Kis7N7QUGTxy+lXiVZa3HPJRlryVLR9ly8fIdlv2VrZqGDnL+dt8jSz3bY1s+Xhd3jbvTdQuZ4E3/dw3fuNyQLAbl8NA/nabnO7r8geKAv3yfjfcli//zH/sy2/47mGkYLvNdObrc3n9dcmU5PB0ERXAUP4kVq6c/kGSXqqYnZcA4Qa4iiyHU+mXspV+KUvpl7KVdjFLafmW0y9lK+1SlttyZr4hcPcAcZXQkDccb14OCVZglyM3FGTJJy8MGNnyyxcU8tp8lC0fW7ZbYPBWtvyMbPnZHK6fvka2fI28nzn9fPP2lS90eJnZ8s7dl5eZ5XrZzWzZc3+vtqrVXw9DMmy5kzFsV1lW4fWl6XvFZaOYfRXRt0T71tX7XudsfoEePT7hBpblcJo6nxs+0vKFj/RLWUq7mLuckbecna9Pliu0XMqqzJGEkrHbDNkNQ3abIW+bKV/DIX+bQ362vDd/h/xt2fIxcgKBr80hXyMrZ1mXw4CPUXi0Im90wku5wSA3IOQFBXu+kJA/KNidWbk/M2U3s2VzZsruzJKhcr5++cfky5thl7x8Jbu3ZPeR7Lm/u9p8c9q9fHLX574Krrd7STav3De70ryM0vex2a99H24veznsI/963uThGYQbVElOp6nzmfnDSP7RkpzRk4KBJW9kJW/5Qmb5jSPX8LGrlp+X6voaqu/rVD1fp+r6OFTH26k6PtkK9HKolle2AuzZ8jNy3vS9nFmyK0tezrwgkBMCbG4/M2VzZsmW76fhzJLNkfPTcGTKcGRKziwZjgzJkSUjO+enHBmSMzv3guW+qroSBYS8QJF/fYH+JdqmlEHF5rlPdgAoX4QblDvTNPVzpsNtJCQt7xbORffRkcvr3G/znM/ILubTDDnzJXyVKT9lyc/IzP09U01yl/1sOcsB9mzV8XaotrdDtb2yVcuerQB7lmrasuRvZMvfyO2vTPmYGfI2M+XlzJSXM0N2xyUZjksysi5JWZekTIeUXllXsgxsXiUICPne9Au2FbeNa31JQkWB9TYv/iUPoFIQblAuTp/P0JJPv9e/vzuszEs/y9vMkF9u2MgJHln5lnN+r6VMNTCyXKHEra9XTmjxNzJVw5aVG0By+vqYGfIxM2UrzVCFU1JG7qu8ePnlvLz98/30zfn9arcxrhgaitrmCuuLChV2H8nGczkBgHCDwhxZ0sVz0sUz0sWz0s+5P4tYdv58Vj+f+0n+GWc1Oy85+FRATcXNtfDyl7z9cn56+RYIHHkhxC9fv6LaCq4rGFoKHINRCACokgg3VuZ0SJdSrxpOCi+fkzLSSnwIm6SAfMumDLfQYFw1NFwpSFwheFwptBA0AAD5EG6qA6dTykjNDSBnpZ/PliysXEpV2T9aYkh+gVKNepJ/Xck/92eNevrJUUOx/83Q7hTpnAIk/7oafUtnDe0ZKrtfILdGAAAeRbipTKYpZaRfIZicvXJQuXROMq/hozC+tXMDSt0iw0qRy36BhT49cvp8hhbGHdQ7O5LkNCUfu00PRYRoUv/WquXnfY0XBwCA8kG4KS+XUqXv1hUfVvI+ulsW3jVzw0ed4sNJ3rJ/nZzJp9cgM9upN7cd0eLPDik9I6f+oZ0aa+bQDmpRv8Y17RsAgPJGuCkvmRekj54oWV+7b24IyQskda8cUFy/182ZW1KJTNNU3L7/6YXYBB05/bMkqWOT2vrj8FD1vqF+pdYCAEBJEW7Ki389qf0dxYcV/7qST9Uf7UhITtO8f+7Tl/89LUlqEOCrGUPa6+4ezWS3MXkXAFB1EW7Ki7efdO/fPF3FNTuVO6/m3bx5NV42PfyLEE0a0EYBvvx1AQBUfbxbQZKUke3Qm9uOaMln/3XNq7njpmA9PfRGNa9X9UeaAADIQ7i5zpmmqU9z59X8mDuvplPT2npmeEfdHFLPw9UBAFB6hJvrWEJymuZ+uE/bf8iZV9OwVu68mu7NZGNeDQCgmiLcXIdOnc/Qgk8Pas3Oy/NqHokI0aP9mVcDAKj+eCe7jmRkO7TyyyNasum/Op83r6ZzsJ6+nXk1AADrINxcB0zT1Mbvc+bVJJ3JmVdzU9NAPTMiVD1bMa8GAGAthBuL+/5Eqp77aJ/+88MZSVKjWr6acfuN+mW3psyrAQBYEuHGon5Kz9DCuAN6d+dRmbnzan4TcYMe7d9aNZlXAwCwMN7lLCYj26EVXx7R0nzzaoZ3znleTbO6zKsBAFgf4cYicubVnNQLsftd82o6NwvUM8NDFca8GgDAdYRwYwHfHc+ZV/NV4uV5Nb+//UbdxbwaAMB1iHBTjaWkX9KCjQf13u6ceTW+Xjb99pYb9Nt+zKsBAFy/eAeshi5l5cyreW3z5Xk1I7o00dNDb1TTOv4erg4AAM8i3FQjpmnqk+9O6oWPE3T0zEVJUpdmOc+r6dGSeTUAAEiEm2rju+OpmvvRPu3InVcTVDtnXs2orsyrAQAgP8JNFZeSfkkvbzygtbuPXZ5X06+1Jva7QTV8+OMDAKAg3h2rqEtZDi3/d6KWbf6vLmQ6JEkjuzbRjNuZVwMAwNUQbqoY0zT18Xcn9UJsgo6dzZ1X07yOnhkeqh4t63q4OgAAqj7CTRXy3fFUzf1wn3YcyZlX07i2n34/tL1GdmFeDQAAJWXzdAHLli1TSEiI/Pz81KNHD23duvWq/d9++2116dJFNWrUUHBwsB588EGdPn26kqqtGClpl/TU2r0asfTf2nHkjPy8bZpya1ttmt5Pd3VrRrABAKAUPBpu1qxZo6lTp2rWrFnas2ePIiIiNHToUCUlJRXZ/9///rfGjRunhx56SN9//73Wrl2rnTt36uGHH67kysvHpSyHXtv8X/V/+XPXhOFRXZto05P99cSgdkwYBgCgDAzTNE1PHbxXr17q3r27oqOjXW0dOnTQqFGjFBUVVaj/yy+/rOjoaB0+fNjVtmTJEr300ks6evRoiY6ZlpamwMBApaamqnbt2td+EmVgmqZiv82ZV3P8XM68mq7N6+iZEaHq3oJ5NQAAFFSa92+PjdxkZmZq9+7dGjx4sFv74MGDtW3btiK36dOnj44dO6bY2FiZpqn//e9/+vvf/6477rjjisfJyMhQWlqa28uTvj2WqtFvbNdjf/tax89dVHCgn14d21XrH+1DsAEAoBx47L7HqVOn5HA4FBQU5NYeFBSkkydPFrlNnz599Pbbb2vMmDG6dOmSsrOzdeedd2rJkiVXPE5UVJTmzJlTrrWXxf/SLulPGw9o3dc5t5/8vG2a2K+1fntLa/n72D1dHgAAluHxCcWG4T5Z1jTNQm159u3bp8mTJ+uZZ57R7t279cknnygxMVETJ0684v5nzpyp1NRU16ukt6/Ky6Ush5ZuOqQBL3+uv+fOq7mrW1Ntnt5fU29rR7ABAKCceWzkpkGDBrLb7YVGaVJSUgqN5uSJiopS37599dRTT0mSOnfurJo1ayoiIkLz5s1TcHBwoW18fX3l6+tb/idQDNM09dE3yZr/8X7XvJpuLXKeV9ON208AAFQYj4UbHx8f9ejRQ3Fxcbrrrrtc7XFxcRo5cmSR2/z888/y8nIv2W7PGfnw4LzoQr45dk5zP9ynXT+elSQ1CfTT74feqDu7NLniqBQAACgfHv2s8bRp0/TAAw8oLCxM4eHh+vOf/6ykpCTXbaaZM2fq+PHjWrVqlSRpxIgReuSRRxQdHa0hQ4YoOTlZU6dO1c0336wmTZp48lQk5cyreemTnHk1kuTvbdej/VvrkYgbuP0EAEAl8Wi4GTNmjE6fPq25c+cqOTlZnTp1UmxsrFq2bClJSk5OdnvmTWRkpNLT07V06VI9+eSTqlOnjgYOHKgXX3zRU6fgsuvIGT2wfIcuZuV8D9QvuzfVjCE3qnGgn4crAwDg+uLR59x4QkU95+ZSlkO3LvhCQbV99cyIjuravE657RsAgOtdad6/eQRuOfHztmvtxHAFB/oxrwYAAA8i3JSjJnX8PV0CAADXPY8/5wYAAKA8EW4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAICleDzcLFu2TCEhIfLz81OPHj20devWq/bPyMjQrFmz1LJlS/n6+qp169aKiYmppGoBAEBV5+XJg69Zs0ZTp07VsmXL1LdvX73xxhsaOnSo9u3bpxYtWhS5zejRo/W///1Py5cvV5s2bZSSkqLs7OxKrhwAAFRVhmmapqcO3qtXL3Xv3l3R0dGutg4dOmjUqFGKiooq1P+TTz7R2LFj9cMPP6hevXplOmZaWpoCAwOVmpqq2rVrl7l2AABQeUrz/u2x21KZmZnavXu3Bg8e7NY+ePBgbdu2rchtNmzYoLCwML300ktq2rSp2rVrp+nTp+vixYtXPE5GRobS0tLcXgAAwLo8dlvq1KlTcjgcCgoKcmsPCgrSyZMni9zmhx9+0L///W/5+fnp/fff16lTpzRp0iSdOXPmivNuoqKiNGfOnHKvHwAAVE0en1BsGIbbsmmahdryOJ1OGYaht99+WzfffLOGDRumhQsXauXKlVccvZk5c6ZSU1Ndr6NHj5b7OQAAgKrDYyM3DRo0kN1uLzRKk5KSUmg0J09wcLCaNm2qwMBAV1uHDh1kmqaOHTumtm3bFtrG19dXvr6+5Vs8AACosjw2cuPj46MePXooLi7OrT0uLk59+vQpcpu+ffvqxIkTOn/+vKvt4MGDstlsatasWYXWCwAAqgeP3paaNm2a/vrXvyomJkYJCQl64oknlJSUpIkTJ0rKuaU0btw4V//77rtP9evX14MPPqh9+/Zpy5YteuqppzRhwgT5+/t76jQAAEAV4tHn3IwZM0anT5/W3LlzlZycrE6dOik2NlYtW7aUJCUnJyspKcnVPyAgQHFxcfrd736nsLAw1a9fX6NHj9a8efM8dQoAAKCK8ehzbjyB59wAAFD9VOhzbhITE8tcGAAAQEUrdbhp06aNBgwYoNWrV+vSpUsVURMAAECZlTrc7N27V926ddOTTz6pxo0b67e//a127NhREbUBAACUWqnDTadOnbRw4UIdP35cK1as0MmTJ/WLX/xCHTt21MKFC/XTTz9VRJ0AAAAlUuaPgnt5eemuu+7Se++9pxdffFGHDx/W9OnT1axZM40bN07JycnlWScAAECJlDnc7Nq1S5MmTVJwcLAWLlyo6dOn6/Dhw9q0aZOOHz+ukSNHlmedAAAAJVLq59wsXLhQK1as0IEDBzRs2DCtWrVKw4YNk82Wk5NCQkL0xhtv6MYbbyz3YgEAAIpT6nATHR2tCRMm6MEHH1Tjxo2L7NOiRQstX778mosDAAAoLR7iBwAAqrwKfYjfihUrtHbt2kLta9eu1Ztvvlna3QEAAJSrUoeb+fPnq0GDBoXaGzVqpBdeeKFcigIAACirUoebH3/8USEhIYXaW7Zs6fYllwAAAJ5Q6nDTqFEjffPNN4Xa9+7dq/r165dLUQAAAGVV6nAzduxYTZ48WZs3b5bD4ZDD4dCmTZs0ZcoUjR07tiJqBAAAKLFSfxR83rx5+vHHH3XrrbfKyytnc6fTqXHjxjHnBgAAeFyZPwp+8OBB7d27V/7+/rrpppvUsmXL8q6tQvBRcAAAqp/SvH+XeuQmT7t27dSuXbuybg4AAFAhyhRujh07pg0bNigpKUmZmZlu6xYuXFguhQEAAJRFqYhZLnsAACAASURBVMPNZ599pjvvvFMhISE6cOCAOnXqpCNHjsg0TXXv3r0iagQAACixUn9aaubMmXryySf13Xffyc/PT+vWrdPRo0fVr18/3XPPPRVRIwAAQImVOtwkJCRo/PjxkiQvLy9dvHhRAQEBmjt3rl588cVyLxAAAKA0Sh1uatasqYyMDElSkyZNdPjwYde6U6dOlV9lAAAAZVDqOTe9e/fWl19+qdDQUN1xxx168skn9e2332r9+vXq3bt3RdQIAABQYqUONwsXLtT58+clSbNnz9b58+e1Zs0atWnTRq+88kq5FwgAAFAapQo3DodDR48eVefOnSVJNWrU0LJlyyqkMAAAgLIo1Zwbu92uIUOG6Ny5cxVVDwAAwDUp9YTim266ST/88ENF1AIAAHDNSh1unn/+eU2fPl0fffSRkpOTlZaW5vYCAADwpFJ/cabNdjkPGYbh+t00TRmGIYfDUX7VVQC+OBMAgOqnQr84c/PmzWUuDAAAoKKVOtz069evIuoAAAAoF6UON1u2bLnq+ltuuaXMxQAAAFyrUoeb/v37F2rLP/emqs+5AQAA1lbqT0udPXvW7ZWSkqJPPvlEPXv21KeffloRNQIAAJRYqUduAgMDC7UNGjRIvr6+euKJJ7R79+5yKQwAAKAsSj1ycyUNGzbUgQMHymt3AAAAZVLqkZtvvvnGbdk0TSUnJ2v+/Pnq0qVLuRUGAABQFqUON127dpVhGCr47L/evXsrJiam3AoDAAAoi1KHm8TERLdlm82mhg0bys/Pr9yKAgAAKKtSh5uWLVtWRB0AAADlotQTiidPnqzFixcXal+6dKmmTp1aLkUBAACUVanDzbp169S3b99C7X369NHf//73cikKAACgrEodbk6fPl3ks25q166tU6dOlUtRAAAAZVXqcNOmTRt98sknhdo//vhj3XDDDeVSFAAAQFmVekLxtGnT9Pjjj+unn37SwIEDJUmfffaZFixYoEWLFpV7gQAAAKVR6nAzYcIEZWRk6Pnnn9dzzz0nSWrVqpWio6M1bty4ci8QAACgNAyz4NP4SuGnn36Sv7+/AgICyrOmCpWWlqbAwEClpqaqdu3ani4HAACUQGnev8v0EL/s7Gy1bdtWDRs2dLUfOnRI3t7eatWqVakLBgAAKC+lnlAcGRmpbdu2FWr/6quvFBkZWR41AQAAlFmpw82ePXuKfM5N7969FR8fXy5FAQAAlFWpw41hGEpPTy/UnpqaKofDUS5FAQAAlFWpw01ERISioqLcgozD4VBUVJR+8YtflGtxAAAApVXqCcUvvfSSbrnlFrVv314RERGSpK1btyo1NVWbN28u9wIBAABKo9QjN6Ghofrmm280evRopaSkKD09XePGjdPBgweVnZ1dETUCAACU2DU950aSzp07p7ffflsxMTGKj4+v8vNueM4NAADVT2nev0s9cpNn06ZNuv/++9WkSRMtXbpUQ4cO1a5du8q6OwAAgHJRqjk3x44d08qVKxUTE6MLFy5o9OjRysrK0rp16xQaGlpRNQIAAJRYiUduhg0bptDQUO3bt09LlizRiRMntGTJkoqsDQAAoNRKPHLz6aefavLkyXr00UfVtm3biqwJAACgzEo8crN161alp6crLCxMvXr10tKlS/XTTz9VZG0AAAClVuJwEx4err/85S9KTk7Wb3/7W7377rtq2rSpnE6n4uLiinxqMQAAQGW7po+CHzhwQMuXL9dbb72lc+fOadCgQdqwYUN51lfu+Cg4AADVT6V8FFyS2rdvr5deeknHjh3TO++8cy27AgAAKBfXFG7y2O12jRo1qkyjNsuWLVNISIj8/PzUo0cPbd26tUTbffnll/Ly8lLXrl1LfUwAAGBd5RJuymrNmjWaOnWqZs2apT179igiIkJDhw5VUlLSVbdLTU3VuHHjdOutt1ZSpQAAoLq45q9fuBa9evVS9+7dFR0d7Wrr0KGDRo0apaioqCtuN3bsWLVt21Z2u10ffPCB4uPjS3xM5twAAFD9VNqcm2uRmZmp3bt3a/DgwW7tgwcP1rZt26643YoVK3T48GE9++yzJTpORkaG0tLS3F4AAMC6PBZuTp06JYfDoaCgILf2oKAgnTx5sshtDh06pKefflpvv/22vLxK9vzBqKgoBQYGul7Nmze/5toBAEDV5dE5N5JkGIbbsmmahdokyeFw6L777tOcOXPUrl27Eu9/5syZSk1Ndb2OHj16zTUDAICqq1RfnFmeGjRoILvdXmiUJiUlpdBojiSlp6dr165d2rNnjx5//HFJktPplGma8vLy0qeffqqBAwcW2s7X11e+vr4VcxIAAKDK8djIjY+Pj3r06KG4uDi39ri4OPXp06dQ/9q1a+vbb79VfHy86zVx4kS1b99e8fHx6tWrV2WVDgAAqjCPjdxI0rRp0/TAAw8oLCxM4eHh+vOf/6ykpCRNnDhRUs4tpePHj2vVqlWy2Wzq1KmT2/aNGjWSn59foXYAAHD98mi4GTNmjE6fPq25c+cqOTlZnTp1UmxsrFq2bClJSk5OLvaZNwAAAPl59Dk3nsBzbgAAqH6qxXNuAAAAKgLhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWIrHw82yZcsUEhIiPz8/9ejRQ1u3br1i3/Xr12vQoEFq2LChateurfDwcG3cuLESqwUAAFWdR8PNmjVrNHXqVM2aNUt79uxRRESEhg4dqqSkpCL7b9myRYMGDVJsbKx2796tAQMGaMSIEdqzZ08lVw4AAKoqwzRN01MH79Wrl7p3767o6GhXW4cOHTRq1ChFRUWVaB8dO3bUmDFj9Mwzz5Sof1pamgIDA5WamqratWuXqW4AAFC5SvP+7bGRm8zMTO3evVuDBw92ax88eLC2bdtWon04nU6lp6erXr16V+yTkZGhtLQ0txcAALAuj4WbU6dOyeFwKCgoyK09KChIJ0+eLNE+FixYoAsXLmj06NFX7BMVFaXAwEDXq3nz5tdUNwAAqNo8PqHYMAy3ZdM0C7UV5Z133tHs2bO1Zs0aNWrU6Ir9Zs6cqdTUVNfr6NGj11wzAACourw8deAGDRrIbrcXGqVJSUkpNJpT0Jo1a/TQQw9p7dq1uu22267a19fXV76+vtdcLwAAqB48NnLj4+OjHj16KC4uzq09Li5Offr0ueJ277zzjiIjI/W3v/1Nd9xxR0WXCQAAqhmPjdxI0rRp0/TAAw8oLCxM4eHh+vOf/6ykpCRNnDhRUs4tpePHj2vVqlWScoLNuHHj9Oqrr6p3796uUR9/f38FBgZ67DwAAEDV4dFwM2bMGJ0+fVpz585VcnKyOnXqpNjYWLVs2VKSlJyc7PbMmzfeeEPZ2dl67LHH9Nhjj7nax48fr5UrV1Z2+QAAoAry6HNuPIHn3AAAUP1Ui+fcAAAAVATCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCDQAAsBQvTxdQVTkcDmVlZXm6DKDc+fj4yGbj3zUArItwU4Bpmjp58qTOnTvn6VKACmGz2RQSEiIfHx9PlwIAFYJwU0BesGnUqJFq1KghwzA8XRJQbpxOp06cOKHk5GS1aNGCv98ALIlwk4/D4XAFm/r163u6HKBCNGzYUCdOnFB2dra8vb09XQ4AlDtuvOeTN8emRo0aHq4EqDh5t6McDoeHKwGAikG4KQJD9bAy/n4DsDrCDYrUqlUrLVq0yNNlAABQasy5sYj+/fura9eu5RZIdu7cqZo1a5bLvgAAqEyEm+uIaZpyOBzy8ir+j71hw4aVUFHlKs35AwCqL25LWUBkZKS++OILvfrqqzIMQ4Zh6MiRI/r8889lGIY2btyosLAw+fr6auvWrTp8+LBGjhypoKAgBQQEqGfPnvrXv/7lts+Ct6UMw9Bf//pX3XXXXapRo4batm2rDRs2XLWu1atXKywsTLVq1VLjxo113333KSUlxa3P999/rzvuuEO1a9dWrVq1FBERocOHD7vWx8TEqGPHjvL19VVwcLAef/xxSdKRI0dkGIbi4+Ndfc+dOyfDMPT5559L0jWdf0ZGhmbMmKHmzZvL19dXbdu21fLly2Waptq0aaOXX37Zrf93330nm83mVjsAwDMIN8UwTVM/Z2Z75GWaZolqfPXVVxUeHq5HHnlEycnJSk5OVvPmzV3rZ8yYoaioKCUkJKhz5846f/68hg0bpn/961/as2ePhgwZohEjRigpKemqx5kzZ45Gjx6tb775RsOGDdOvf/1rnTlz5or9MzMz9dxzz2nv3r364IMPlJiYqMjISNf648eP65ZbbpGfn582bdqk3bt3a8KECcrOzpYkRUdH67HHHtNvfvMbffvtt9qwYYPatGlTomuSX1nOf9y4cXr33Xe1ePFiJSQk6PXXX1dAQIAMw9CECRO0YsUKt2PExMQoIiJCrVu3LnV9AIDyxfh8MS5mORT6zEaPHHvf3CGq4VP8H1FgYKB8fHxUo0YNNW7cuND6uXPnatCgQa7l+vXrq0uXLq7lefPm6f3339eGDRtcIyNFiYyM1L333itJeuGFF7RkyRLt2LFDt99+e5H9J0yY4Pr9hhtu0OLFi3XzzTfr/PnzCggI0GuvvabAwEC9++67ruettGvXzq2uJ598UlOmTHG19ezZs7jLUUhpz//gwYN67733FBcXp9tuu81Vf54HH3xQzzzzjHbs2KGbb75ZWVlZWr16tf70pz+VujYAQPlj5OY6EBYW5rZ84cIFzZgxQ6GhoapTp44CAgK0f//+YkduOnfu7Pq9Zs2aqlWrVqHbTPnt2bNHI0eOVMuWLVWrVi31799fklzHiY+PV0RERJEPkktJSdGJEyd06623lvQ0r6i05x8fHy+73a5+/foVub/g4GDdcccdiomJkSR99NFHunTpku65555rrhUAcO0YuSmGv7dd++YO8dixy0PBTz099dRT2rhxo15++WW1adNG/v7++tWvfqXMzMyr7qdgCDEMQ06ns8i+Fy5c0ODBgzV48GCtXr1aDRs2VFJSkoYMGeI6jr+//xWPdbV1klxf/Jj/1t2Vvui0tOdf3LEl6eGHH9YDDzygV155RStWrNCYMWN4+CMAVBGEm2IYhlGiW0Oe5uPjU+Inzm7dulWRkZG66667JEnnz5/XkSNHyrWe/fv369SpU5o/f75r/s+uXbvc+nTu3FlvvvmmsrKyCgWnWrVqqVWrVvrss880YMCAQvvP+zRXcnKyunXrJkluk4uvprjzv+mmm+R0OvXFF1+4bksVNGzYMNWsWVPR0dH6+OOPtWXLlhIdGwBQ8bgtZRGtWrXSV199pSNHjujUqVNXHFGRpDZt2mj9+vWKj4/X3r17dd999121f1m0aNFCPj4+WrJkiX744Qdt2LBBzz33nFufxx9/XGlpaRo7dqx27dqlQ4cO6a233tKBAwckSbNnz9aCBQu0ePFiHTp0SF9//bWWLFkiKWd0pXfv3po/f7727dunLVu26P/+7/9KVFtx59+qVSuNHz9eEyZMcE2E/vzzz/Xee++5+tjtdkVGRmrmzJlq06aNwsPDr/WSAQDKCeHGIqZPny673a7Q0FDXLaAreeWVV1S3bl316dNHI0aM0JAhQ9S9e/dyradhw4ZauXKl1q5dq9DQUM2fP7/Qx6fr16+vTZs26fz58+rXr5969Oihv/zlL65RnPHjx2vRokVatmyZOnbsqOHDh+vQoUOu7WNiYpSVlaWwsDBNmTJF8+bNK1FtJTn/6Oho/epXv9KkSZN044036pFHHtGFCxfc+jz00EPKzMx0mzgNAPA8wyzp540tIi0tTYGBgUpNTVXt2rXd1l26dEmJiYkKCQmRn5+fhypEdfHll1+qf//+OnbsmIKCgjxdTonx9xxAdXS19++Cqv5kEqCKycjI0NGjR/XHP/5Ro0ePrlbBBgCuB9yWAkrpnXfeUfv27ZWamqqXXnrJ0+UAAAog3AClFBkZKYfDod27d6tp06aeLgcAUADhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBi6tWrXSokWLXMuGYeiDDz64Yv8jR47IMIwSf2FlRe8HAACJJxTjKpKTk1W3bt1y3WdkZKTOnTvnFpqaN2+u5ORkNWjQoFyPBQC4PhFucEWNGzeulOPY7fZKO1ZVk5WV5fqiUABA+eC2lAW88cYbatq0qZxOp1v7nXfeqfHjx0uSDh8+rJEjRyooKEgBAQHq2bOn/vWvf111vwVvS+3YsUPdunWTn5+fwsLCtGfPHrf+DodDDz30kEJCQuTv76/27dvr1Vdfda2fPXu23nzzTf3jH/+QYRgyDEOff/55kbelvvjiC918883y9fVVcHCwnn76aWVnZ7vW9+/fX5MnT9aMGTNUr149NW7cWLNnz77q+ezcuVODBg1SgwYNFBgYqH79+unrr79263Pu3Dn95je/UVBQkPz8/NSpUyd99NFHrvVffvml+vXrpxo1aqhu3boaMmSIzp49K6nwbT1J6tq1q1tdhmHo9ddf18iRI1WzZk3Nmzev2OuWJyYmRh07dnRdk8cff1ySNGHCBA0fPtytb3Z2tho3bqyYmJirXhMAsCJGbopjmlLWz545tncNyTCK7XbPPfdo8uTJ2rx5s2699VZJ0tmzZ7Vx40Z9+OGHkqTz589r2LBhmjdvnvz8/PTmm29qxIgROnDggFq0aFHsMS5cuKDhw4dr4MCBWr16tRITEzVlyhS3Pk6nU82aNdN7772nBg0aaNu2bfrNb36j4OBgjR49WtOnT1dCQoLS0tK0YsUKSVK9evV04sQJt/0cP35cw4YNU2RkpFatWqX9+/frkUcekZ+fn1tQePPNNzVt2jR99dVX2r59uyIjI9W3b18NGjSoyHNIT0/X+PHjtXjxYknSggULNGzYMB06dEi1atWS0+nU0KFDlZ6ertWrV6t169bat2+f7Ha7JCk+Pl633nqrJkyYoMWLF8vLy0ubN2+Ww+Eo9vrl9+yzzyoqKkqvvPKK7HZ7sddNkqKjozVt2jTNnz9fQ4cOVWpqqr788ktJ0sMPP6xbbrlFycnJCg4OliTFxsbq/Pnzru0B4HpCuClO1s/SC008c+w/nJB8ahbbrV69err99tv1t7/9zRVu1q5dq3r16rmWu3Tpoi5duri2mTdvnt5//31t2LDBNQJwNW+//bYcDodiYmJUo0YNdezYUceOHdOjjz7q6uPt7a05c+a4lkNCQrRt2za99957Gj16tAICAuTv76+MjIyr3oZatmyZmjdvrqVLl8owDN144406ceKEfv/73+uZZ56RzZYz4Ni5c2c9++yzkqS2bdtq6dKl+uyzz64YbgYOHOi2/MYbb6hu3br64osvNHz4cP3rX//Sjh07lJCQoHbt2kmSbrjhBlf/l156SWFhYVq2bJmrrWPHjsVeu4Luu+8+TZgwwa3tatdNyvnzevLJJ90CZc+ePSVJffr0Ufv27fXWW29pxowZkqQVK1bonnvuUUBAQKnrA4DqjttSFvHrX/9a69atU0ZGhqScMDJ27FjXqMOFCxc0Y8YMhYaGqk6dOgoICND+/fuVlJRUov0nJCSoS5cuqlGjhqstPDy8UL/XX39dYWFhatiwoQICAvSXv/ylxMfIf6zw8HAZ+Uat+vbtq/Pnz+vYsWOuts6dO7ttFxwcrJSUlCvuNyUlRRMnTlS7du0UGBiowMBAnT9/3lVffHy8mjVr5go2BeWN3FyrsLCwQm1Xu24pKSk6ceLEVY/98MMPu0bDUlJS9M9//rNQgAKA6wUjN8XxrpEzguKpY5fQiBEj5HQ69c9//lM9e/bU1q1btXDhQtf6p556Shs3btTLL7+sNm3ayN/fX7/61a+UmZlZov2bpllsn/fee09PPPGEFixYoPDwcNWqVUt/+tOf9NVXX5X4PPKOZRS4HZd3/PztBSfiGoZRaN5RfpGRkfrpp5+0aNEitWzZUr6+vgoPD3ddA39//6vWVdx6m81W6DplZWUV6lezpvtoXHHXrbjjStK4ceP09NNPa/v27dq+fbtatWqliIiIYrcDACsi3BTHMEp0a8jT/P399ctf/lJvv/22/vvf/6pdu3bq0aOHa/3WrVsVGRmpu+66S1LOHJwjR46UeP+hoaF66623dPHiRdeb7X/+8x+3Plu3blWfPn00adIkV9vhw4fd+vj4+BQ7RyU0NFTr1q1zCznbtm1TrVq11LRp0xLXXNDWrVu1bNkyDRs2TJJ09OhRnTp1yrW+c+fOOnbsmA4ePFjk6E3nzp312Wefud1Cyq9hw4ZKTk52LaelpSkxMbFEdV3tutWqVUutWrXSZ599pgEDBhS5j/r162vUqFFasWKFtm/frgcffLDY4wKAVXFbykJ+/etf65///KdiYmJ0//33u61r06aN1q9fr/j4eO3du1f33XffVUc5Crrvvvtks9n00EMPad++fYqNjdXLL79c6Bi7du3Sxo0bdfDgQf3xj3/Uzp073fq0atVK33zzjQ4cOKBTp04VObIxadIkHT16VL/73e+0f/9+/eMf/9Czzz6radOmuebblEWbNm301ltvKSEhQV999ZV+/etfu42K9OvXT7fccovuvvtuxcXFKTExUR9//LE++eQTSdLMmTO1c+dOTZo0Sd98843279+v6OhoV0AaOHCg3nrrLW3dulXfffedxo8f77otWFxdxV232bNna8GCBVq8eLEOHTqkr7/+WkuWLHHr8/DDD+vNN99UQkKC61NyAHA9ItxYyMCBA1WvXj0dOHBA9913n9u6V155RXXr1lWfPn00YsQIDRkyRN27dy/xvgMCAvThhx9q37596tatm2bNmqUXX3zRrc/EiRP1y1/+UmPGjFGvXr10+vRpt9EISXrkkUfUvn171/ySvE/85Ne0aVPFxsZqx44d6tKliyZOnKiHHnpI//d//1eKq1FYTEyMzp49q27duumBBx7Q5MmT1ahRI7c+69atU8+ePXXvvfcqNDRUM2bMcI00tWvXTp9++qn27t2rm2++WeHh4frHP/4hL6+cAdCZM2fqlltu0fDhwzVs2DCNGjVKrVu3Lraukly38ePHa9GiRVq2bJk6duyo4cOH69ChQ259brvtNgUHB2vIkCFq0sRDk+ABoAowzJJMprCQtLQ0BQYGKjU1VbVr13Zbd+nSJSUmJiokJER+fn4eqhAom59//llNmjRRTEyMfvnLX16xH3/PAVRHV3v/Log5N0A153Q6dfLkSS1YsECBgYG68847PV0SAHgU4Qao5pKSkhQSEqJmzZpp5cqVrttkAHC94v+CQDXXqlWrEn1UHwCuF0woBgAAlkK4AQAAlkK4KQJD/LAy/n4DsDrCTT55j/P/+WcPfQs4UAnyvm6iJA8YBIDqiAnF+djtdtWpU8f15Ys1atQo9B1HQHXmdDr1008/qUaNGnyqCoBl8X+3Aho3bixJV/12aaA6s9lsatGiBcEdgGURbgowDEPBwcFq1KhRkd97BFR3Pj4+1/QdXQBQ1Xk83Cxbtkx/+tOflJycrI4dO2rRokWKiIi4Yv8vvvhC06ZN0/fff68mTZpoxowZmjhxYrnXZbfbmZMAAEA15NF/vq1Zs0ZTp07VrFmztGfPHkVERGjo0KFKSkoqsn9iYqKGDRumiIgI7dmzR3/4wx80efJkrVu3rpIrBwAAVZVHvzizV69e6t69u6Kjo11tHTp00KhRoxQVFVWo/+9//3tt2LBBCQkJrraJEydq79692r59e4mOWZov3gIAAFVDad6/PTZyk5mZqd27d2vw4MFu7YMHD9a2bduK3Gb79u2F+g8ZMkS7du1ifgwAAJDkwTk3p06dksPhUFBQkFt7UFCQTp48WeQ2J0+eLLJ/dna2Tp06peDg4ELbZGRkKCMjw7WcmpoqKScBAgCA6iHvfbskN5w8PqG44MdRTdO86kdUi+pfVHueqKgozZkzp1B78+bNS1sqAADwsPT0dAUGBl61j8fCTYMGDWS32wuN0qSkpBQancnTuHHjIvt7eXmpfv36RW4zc+ZMTZs2zbXsdDp15swZ1a9fv9yf85GWlqbmzZvr6NGj1+V8nuv9/CWuwfV+/hLXgPO/vs9fqrhrYJqm0tPT1aRJk2L7eizc+Pj4qEePHoqLi9Ndd93lao+Li9PIkSOL3CY8PFwffvihW9unn36qsLAw11cnFOTr6ytfX1+3tjp16lxj9VdXu3bt6/YvtcT5S1yD6/38Ja4B5399n79UMdeguBGbPB79KPi0adP017/+VTExMUpISNATTzyhpKQk13NrZs6cqXHjxrn6T5w4UT/++KOmTZumhIQExcTEaPny5Zo+fbqnTgEAAFQxHp1zM2bMGJ0+fVpz585VcnKyOnXqpNjYWLVs2VKSlJyc7PbMm5CQEMXGxuqJJ57Qa6+9piZNLXERBAAAChxJREFUmmjx4sW6++67PXUKAACgivH4hOJJkyZp0qRJRa5buXJlobZ+/frp66+/ruCqysbX11fPPvtsodtg14vr/fwlrsH1fv4S14Dzv77PX6oa18CjD/EDAAAob3x7HgAAsBTCDQAAsBTCDQAAsBTCDQAAsBTCTTlZtmyZQkJC5Ofnpx49emjr1q2eLqnSbNmyRSNGjFCTJk1kGIY++OADT5dUqaKiotSzZ0/VqlVLjRo10qhRo3TgwAFPl1WpoqOj1blzZ9dDu8LDw/Xxxx97uiyPiYqKkmEYmjp1qqdLqTSzZ8+WYRhur8aNG3u6rEp1/Phx3X///apfv75q1Kihrl27avfu3Z4uq9K0atWq0N8BwzD02GOPVXothJtysGbNGk2dOlWzZs3Snj17FBERoaFDh7o9o8fKLly4oC5dumjp0qWeLsUjvvjiCz322GP6z3/+o7i4OGVnZ2vw4MG6cOGCp0urNM2aNdP8+fO1a9cu7dq1SwMHDtTIkSP1/fffe7q0Srdz5079+c9/VufOnT1dSqXr2LGjkpOTXa9vv/3W0yVVmrNnz6pv377y9vbWxx9/rH379mnBggUV/kT8qmTnzp1uf/5xcXGSpHv+v737DWmqbcAAfi2XS9eIqWkb/ZM0zTQpVzIVogxxSVBZVphMJGSlJolQ9Icsor4VBTFYmSQZgpS1iDItEzLCMJYiZoVRQciysNTIwN3Ph3jHu3fP+7zP82bntnX94MB2zrZzHfDDxX3u4715s/JhBP2wFStWCJvN5rMvPj5e7Nu3T1IieQCIxsZG2TGkcrvdAoBoa2uTHUUqvV4vzp8/LzuGooaHh0VsbKxobm4WK1euFOXl5bIjKebw4cMiOTlZdgxp9u7dKzIyMmTHmFTKy8vFggULhMfjUfzcHLn5Qd++fUNnZyeysrJ89mdlZeHhw4eSUpFMnz59AgCEhYVJTiLH+Pg46uvrMTo6CrPZLDuOokpKSpCTk4M1a9bIjiLFixcvYDQaER0dja1bt6K/v192JMU4nU6YTCZs3rwZkZGRWLp0Kc6dOyc7ljTfvn3DpUuXUFRUNOGLVP8dLDc/aHBwEOPj434rmUdFRfmtYE6BTwiBiooKZGRkIDExUXYcRXV3d2P69OnQaDSw2WxobGxEQkKC7FiKqa+vx5MnT3DixAnZUaRITU1FbW0tmpqacO7cOQwMDCAtLQ0fPnyQHU0R/f39sNvtiI2NRVNTE2w2G3bv3o3a2lrZ0aS4du0ahoaGUFhYKOX80pdfCBT/2UyFEFLaKslVWlqKrq4uPHjwQHYUxcXFxcHlcmFoaAhXrlyB1WpFW1vbb1Fw3r59i/Lycty5cwfTpk2THUcKi8XifZ2UlASz2YwFCxbg4sWLqKiokJhMGR6PByaTCcePHwcALF26FD09PbDb7T4LQP8uqqurYbFYYDQapZyfIzc/KCIiAkFBQX6jNG632280hwJbWVkZnE4nWltbMXv2bNlxFBccHIyYmBiYTCacOHECycnJOH36tOxYiujs7ITb7UZKSgrUajXUajXa2tpw5swZqNVqjI+Py46oOK1Wi6SkJLx48UJ2FEUYDAa/Ir9o0aLf5sGSf/f69Wu0tLRgx44d0jKw3Pyg4OBgpKSkeGeF/0tzczPS0tIkpSIlCSFQWlqKq1ev4t69e4iOjpYdaVIQQmBsbEx2DEVkZmaiu7sbLpfLu5lMJuTn58PlciEoKEh2RMWNjY2ht7cXBoNBdhRFpKen+/0LiOfPn2PevHmSEslTU1ODyMhI5OTkSMvA21IToKKiAgUFBTCZTDCbzXA4HHjz5g1sNpvsaIoYGRnBy5cvve9fvXoFl8uFsLAwzJ07V2IyZZSUlODy5cu4fv06dDqddxRvxowZCAkJkZxOGfv374fFYsGcOXMwPDyM+vp63L9/H7dv35YdTRE6nc5vjpVWq0V4ePhvM/eqsrIS69atw9y5c+F2u3Hs2DF8/vwZVqtVdjRF7NmzB2lpaTh+/Djy8vLQ0dEBh8MBh8MhO5qiPB4PampqYLVaoVZLrBiKP58VoM6ePSvmzZsngoODxbJly36rx4BbW1sFAL/NarXKjqaIP7t2AKKmpkZ2NMUUFRV5//5nzpwpMjMzxZ07d2THkup3exR8y5YtwmAwiKlTpwqj0Sg2btwoenp6ZMdS1I0bN0RiYqLQaDQiPj5eOBwO2ZEU19TUJACIvr4+qTlUQgghp1YRERERTTzOuSEiIqKAwnJDREREAYXlhoiIiAIKyw0REREFFJYbIiIiCigsN0RERBRQWG6IiIgooLDcEBHh++K3165dkx2DiCYAyw0RSVdYWAiVSuW3ZWdny45GRL8gri1FRJNCdnY2ampqfPZpNBpJaYjoV8aRGyKaFDQaDWbNmuWz6fV6AN9vGdntdlgsFoSEhCA6OhoNDQ0+3+/u7sbq1asREhKC8PBwFBcXY2RkxOczFy5cwOLFi6HRaGAwGFBaWupzfHBwEBs2bEBoaChiY2PhdDp/7kUT0U/BckNEv4RDhw4hNzcXT58+xfbt27Ft2zb09vYCAL58+YLs7Gzo9Xo8fvwYDQ0NaGlp8SkvdrsdJSUlKC4uRnd3N5xOJ2JiYnzOceTIEeTl5aGrqwtr165Ffn4+Pn78qOh1EtEEkLpsJxGREMJqtYqgoCCh1Wp9tqNHjwohvq+8brPZfL6Tmpoqdu7cKYQQwuFwCL1eL0ZGRrzHb968KaZMmSIGBgaEEEIYjUZx4MCB/5oBgDh48KD3/cjIiFCpVOLWrVsTdp1EpAzOuSGiSWHVqlWw2+0++8LCwryvzWazzzGz2QyXywUA6O3tRXJyMrRarfd4eno6PB4P+vr6oFKp8O7dO2RmZv5lhiVLlnhfa7Va6HQ6uN3u//uaiEgOlhsimhS0Wq3fbaL/RaVSAQCEEN7Xf/aZkJCQv/V7U6dO9fuux+P5R5mISD7OuSGiX8KjR4/83sfHxwMAEhIS4HK5MDo66j3e3t6OKVOmYOHChdDpdJg/fz7u3r2raGYikoMjN0Q0KYyNjWFgYMBnn1qtRkREBACgoaEBJpMJGRkZqKurQ0dHB6qrqwEA+fn5OHz4MKxWK6qqqvD+/XuUlZWhoKAAUVFRAICqqirYbDZERkbCYrFgeHgY7e3tKCsrU/ZCieinY7khoknh9u3bMBgMPvvi4uLw7NkzAN+fZKqvr8euXbswa9Ys1NXVISEhAQAQGhqKpqYmlJeXY/ny5QgNDUVubi5Onjzp/S2r1YqvX7/i1KlTqKysREREBDZt2qTcBRKRYlRCCCE7BBHRX1GpVGhsbMT69etlRyGiXwDn3BAREVFAYbkhIiKigMI5N0Q06fHuORH9Exy5ISIiooDCckNEREQBheWGiIiIAgrLDREREQUUlhsiIiIKKCw3REREFFBYboiIiCigsNwQERFRQGG5ISIiooDyBxmRZHmcsLeuAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(train_a, label='train accuracy')\n",
    "plt.plot(val_a, label='validation accuracy')\n",
    "\n",
    "plt.title('Training history')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()\n",
    "plt.ylim([0, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "Accuracy of Pos/Neg on Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 28664,
     "status": "ok",
     "timestamp": 1695329081374,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "56411273-fc8c-476f-f660-21239bd321ad",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.8974000000000001\n",
      "F1-Macro:  0.8934379584843231\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "test_acc, _, test_f1 = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "\n",
    "print(\"Accuracy: \",test_acc.item())\n",
    "print(\"F1-Macro: \",test_f1.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695329212064,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "d7DaFLMiAWps",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  model = model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "\n",
    "  pos_scores = []\n",
    "  neg_scores = []\n",
    "\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "\n",
    "      reviews = d[\"review\"]\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      sentiments = d[\"sentiments\"].to(device)\n",
    "\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(reviews)\n",
    "      predictions.extend(preds)\n",
    "\n",
    "      prediction_probs.extend(probs)\n",
    "      real_values.extend(sentiments)\n",
    "\n",
    "      pos_scores.extend(scores[0])\n",
    "      neg_scores.extend(scores[1])\n",
    "\n",
    "\n",
    "  predictions = torch.stack(predictions).cpu()\n",
    "  prediction_probs = torch.stack(prediction_probs).cpu()\n",
    "  real_values = torch.stack(real_values).cpu()\n",
    "  pos_scores = torch.stack(pos_scores).cpu()\n",
    "  neg_scores = torch.stack(neg_scores).cpu()\n",
    "  return review, predictions, prediction_probs, real_values, pos_scores, neg_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "executionInfo": {
     "elapsed": 27254,
     "status": "ok",
     "timestamp": 1695329239311,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "kZhYtki1AYx8",
    "tags": []
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'scores' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[52], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m y_review_texts, y_pred, y_pred_probs, y_test, pos_scores, neg_scores \u001b[38;5;241m=\u001b[39m \u001b[43mget_predictions\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m      2\u001b[0m \u001b[43m  \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m      3\u001b[0m \u001b[43m  \u001b[49m\u001b[43mtest_data_loader\u001b[49m\n\u001b[1;32m      4\u001b[0m \u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[51], line 38\u001b[0m, in \u001b[0;36mget_predictions\u001b[0;34m(model, data_loader)\u001b[0m\n\u001b[1;32m     35\u001b[0m     prediction_probs\u001b[38;5;241m.\u001b[39mextend(probs)\n\u001b[1;32m     36\u001b[0m     real_values\u001b[38;5;241m.\u001b[39mextend(sentiments)\n\u001b[0;32m---> 38\u001b[0m     pos_scores\u001b[38;5;241m.\u001b[39mextend(\u001b[43mscores\u001b[49m[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m     39\u001b[0m     neg_scores\u001b[38;5;241m.\u001b[39mextend(scores[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m     42\u001b[0m predictions \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(predictions)\u001b[38;5;241m.\u001b[39mcpu()\n",
      "\u001b[0;31mNameError\u001b[0m: name 'scores' is not defined"
     ]
    }
   ],
   "source": [
    "y_review_texts, y_pred, y_pred_probs, y_test, pos_scores, neg_scores = get_predictions(\n",
    "  model,\n",
    "  test_data_loader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "mO1FHn0GAbCU",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix\n",
    "class_names = ['negative', 'positive']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
    "Pos/Neg Classification Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695329239312,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "4d6f7fff-9a14-49bf-c847-d2e8ba79f5f7",
    "tags": []
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'classification_report' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[50], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mclassification_report\u001b[49m(y_test, y_pred, target_names\u001b[38;5;241m=\u001b[39mclass_names))\n",
      "\u001b[0;31mNameError\u001b[0m: name 'classification_report' is not defined"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_test, y_pred, target_names=class_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
