{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 28131,
     "status": "ok",
     "timestamp": 1695497207856,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "import textwrap\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import BertModel, BertConfig\n",
    "\n",
    "\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 779,
     "status": "ok",
     "timestamp": 1695497208625,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9"
   },
   "outputs": [],
   "source": [
    "# Directory holding the Amazon review CSVs; edit this one constant to relocate the data.\n",
    "DATA_DIR = '/home/m_nsu/ICLR/Datasets/Amazon'\n",
    "\n",
    "# train_40k.csv is split into train/validation in a later cell;\n",
    "# val_10k.csv is loaded as the test frame.\n",
    "df_train = pd.read_csv(f'{DATA_DIR}/train_40k.csv')\n",
    "df_test = pd.read_csv(f'{DATA_DIR}/val_10k.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 36,
     "status": "ok",
     "timestamp": 1695497208626,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "7VLaSHYyzg58"
   },
   "outputs": [],
   "source": [
    "# Hold out 20% of the loaded training frame as a validation split; the fixed\n",
    "# random_state keeps the split reproducible.\n",
    "# NOTE: this rebinds df_train, so re-running the cell on its own splits the\n",
    "# already-reduced frame again.\n",
    "df_train, df_val = train_test_split(\n",
    "    df_train,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1695497208626,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NqwYxgPz1I-L"
   },
   "outputs": [],
   "source": [
    "# Keep only the review text and the three category levels in every split.\n",
    "df_train, df_val, df_test = (\n",
    "    frame[['Text', 'Cat1', 'Cat2', 'Cat3']]\n",
    "    for frame in (df_train, df_val, df_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1695497208627,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "c4b05378-9996-4ec2-dbd0-adc07eb031fd"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-94cfe069-71e8-4797-b1cc-fc5ef9cb93f9\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24990</th>\n",
       "      <td>Watchmaker offered to install a new battery in...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>household supplies</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-94cfe069-71e8-4797-b1cc-fc5ef9cb93f9')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-94cfe069-71e8-4797-b1cc-fc5ef9cb93f9 button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-94cfe069-71e8-4797-b1cc-fc5ef9cb93f9');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-cfafcc3c-0d38-4d54-b758-47da4798cde1\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-cfafcc3c-0d38-4d54-b758-47da4798cde1')\"\n",
       "            title=\"Suggest charts.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--button-hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-cfafcc3c-0d38-4d54-b758-47da4798cde1 button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "24990  Watchmaker offered to install a new battery in...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "24990  health personal care                household supplies   \n",
       "\n",
       "                     Cat3  \n",
       "14307                toys  \n",
       "17812       styling tools  \n",
       "11020          miniatures  \n",
       "15158  live indoor plants  \n",
       "24990             unknown  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the training split.\n",
    "df_train.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695497208628,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "ARIavVExPA5X"
   },
   "outputs": [],
   "source": [
    "# Drop rows whose Cat3 value is the placeholder \"unknown\" from every split.\n",
    "# .copy() gives each split its own independent frame, so the column\n",
    "# assignments made in later cells do not raise SettingWithCopyWarning.\n",
    "df_train, df_val, df_test = (\n",
    "    frame[frame['Cat3'] != 'unknown'].copy()\n",
    "    for frame in (df_train, df_val, df_test)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695497208630,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L5ofeF41UC3N",
    "outputId": "d38ddd54-7ab3-459e-f262-cf7366b4d755"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-cf8027a2-f000-406f-b930-254e26c817be\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5980</th>\n",
       "      <td>HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>action toy figures</td>\n",
       "      <td>playsets</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>Stays on continuously without shutting off! It...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>health care</td>\n",
       "      <td>pain relievers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>these look great in our 10 gallon tank- colors...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>fish aquatic pets</td>\n",
       "      <td>aquarium d cor</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>This works great, but needs a better way to at...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>carriers travel products</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>she absolutely LOVES this thing. I dice up gre...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>I hurt my neck and went to rehab. They had me ...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>medical supplies equipment</td>\n",
       "      <td>occupational physical therapy aids</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>47251 rows × 4 columns</p>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-cf8027a2-f000-406f-b930-254e26c817be')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-cf8027a2-f000-406f-b930-254e26c817be button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-cf8027a2-f000-406f-b930-254e26c817be');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-b422ab8d-d0e9-4ab2-9bd9-42651ec6ad08\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-b422ab8d-d0e9-4ab2-9bd9-42651ec6ad08')\"\n",
       "            title=\"Suggest charts.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--button-hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-b422ab8d-d0e9-4ab2-9bd9-42651ec6ad08 button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "5980   HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...   \n",
       "...                                                  ...   \n",
       "9995   Stays on continuously without shutting off! It...   \n",
       "9996   these look great in our 10 gallon tank- colors...   \n",
       "9997   This works great, but needs a better way to at...   \n",
       "9998   she absolutely LOVES this thing. I dice up gre...   \n",
       "9999   I hurt my neck and went to rehab. They had me ...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "5980             toys games                action toy figures   \n",
       "...                     ...                               ...   \n",
       "9995   health personal care                       health care   \n",
       "9996           pet supplies                 fish aquatic pets   \n",
       "9997           pet supplies                              dogs   \n",
       "9998           pet supplies                              dogs   \n",
       "9999   health personal care        medical supplies equipment   \n",
       "\n",
       "                                     Cat3  \n",
       "14307                                toys  \n",
       "17812                       styling tools  \n",
       "11020                          miniatures  \n",
       "15158                  live indoor plants  \n",
       "5980                             playsets  \n",
       "...                                   ...  \n",
       "9995                       pain relievers  \n",
       "9996                       aquarium d cor  \n",
       "9997             carriers travel products  \n",
       "9998                                 toys  \n",
       "9999   occupational physical therapy aids  \n",
       "\n",
       "[47251 rows x 4 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.concat([df_train, df_val, df_test], axis=0)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 404,
     "status": "ok",
     "timestamp": 1695497417800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "enhZGAmPg4gp"
   },
   "outputs": [],
   "source": [
    "# Label Encode Cat1\n",
    "df['Cat1-map'], map = pd.factorize(df['Cat1'])\n",
    "cat1_map = dict(zip(map, range(len(map))))\n",
    "map_cat1 = {v: k for k, v in cat1_map.items()}\n",
    "\n",
    "df_train['Cat1'] = df_train[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_val['Cat1'] = df_val[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_test['Cat1'] = df_test[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 426,
     "status": "ok",
     "timestamp": 1695497434768,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AHCP1kIPgmSv"
   },
   "outputs": [],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 'bert-base-uncased'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695497437148,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "xG4nYgPRrY2e"
   },
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "config = BertConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1695497437149,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz"
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "#device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
    "device = torch.device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1695497437582,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "a194ad10-d038-432b-d42a-ce9580d47159"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437583,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD"
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, texts, cats1, tokenizer, max_len):\n",
    "    self.texts = texts\n",
    "    self.cats1 = cats1\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.texts)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    text = str(self.texts[item])\n",
    "    cat1 = self.cats1[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      text,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'text': text,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'cat1': torch.tensor(cat1, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437584,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f"
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    texts=df.Text.to_numpy(),\n",
    "    cats1=df['Cat1'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437585,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj"
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 448,
     "status": "ok",
     "timestamp": 1695497438021,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "1522a9b5-73c6-4669-eb4d-2c71930b9a1a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['text', 'input_ids', 'attention_mask', 'cat1'])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = next(iter(train_data_loader))\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1695497438022,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF"
   },
   "outputs": [],
   "source": [
    "class IMDBClassifier(nn.Module):\n",
    "  def __init__(self, n_classes):\n",
    "    super(IMDBClassifier, self).__init__()\n",
    "    self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "\n",
    "    self.FC = nn.Linear(config.hidden_size,6, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "\n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "executionInfo": {
     "elapsed": 8266,
     "status": "ok",
     "timestamp": 1695497485319,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWZ37gsztWzL"
   },
   "outputs": [],
   "source": [
    "model = IMDBClassifier(6)\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695497485321,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM"
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695497486649,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux"
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695497487362,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c35b8c5d-43f7-49a3-ff8d-a4172979d24e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 200])\n",
      "torch.Size([16, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695497487363,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1"
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "executionInfo": {
     "elapsed": 3357,
     "status": "ok",
     "timestamp": 1695497500843,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q"
   },
   "outputs": [],
   "source": [
    "outs = model(input_ids, attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "executionInfo": {
     "elapsed": 410,
     "status": "ok",
     "timestamp": 1695498559469,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf"
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1695497553390,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W"
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, cat1)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == cat1)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1695497553909,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue"
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, cat1)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == cat1)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3729011,
     "status": "ok",
     "timestamp": 1695502293865,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "883126c6-e286-4c7a-ab04-60c2908e93dc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 0.9158945335499934 accuracy 0.7044881994166005\n",
      "Val   loss 0.7550778948556521 accuracy 0.7417437252311757\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 0.708059277002278 accuracy 0.7509281357730045\n",
      "Val   loss 0.6500870985320851 accuracy 0.7659180977542932\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.6452603197734593 accuracy 0.7684964200477328\n",
      "Val   loss 0.6223940697148882 accuracy 0.7774108322324966\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.624240598324311 accuracy 0.7757889154070539\n",
      "Val   loss 0.6097577590542503 accuracy 0.7788639365918097\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.6096156636887049 accuracy 0.7794683107928931\n",
      "Val   loss 0.6035292828259086 accuracy 0.7815059445178335\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.6004666390089063 accuracy 0.7833465924158048\n",
      "Val   loss 0.5991094943906184 accuracy 0.7829590488771466\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.5944655489014543 accuracy 0.787291169451074\n",
      "Val   loss 0.5967095028058889 accuracy 0.7825627476882431\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.5909900062639476 accuracy 0.7880535666931848\n",
      "Val   loss 0.5953855391764691 accuracy 0.7841479524438573\n",
      "\n",
      "CPU times: user 56min 42s, sys: 22.1 s, total: 57min 4s\n",
      "Wall time: 1h 2min 9s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_bert_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FowMSU5U7SDQ"
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 472
    },
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1693911460382,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "c5781d52-da67-4451-9200-f67ac20b6c33"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABFTUlEQVR4nO3dfXzN9f/H8efZxTnbsBljRmOucn2VizWS0nwXpXRF+DIqunBRLb+QawqJUq766ot8fV2Vb6RIX63UNymFSRkRoti0ZBuyzTmf3x9ycuzyzLazfTzut9u5tfM+78/n8/p8sPPs/Xl/Ph+LYRiGAAAATMLL0wUAAAAUJcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINgEIZMGCAIiIiCrXsxIkTZbFYiragArrlllvUtGnTfPsdOXJEFotFb775ZvEXBaBIEW4Ak7FYLAV6bdmyxdOlmtL8+fMJRICHWXi2FGAu//73v13e/+tf/9LmzZu1bNkyl/YuXbooNDS00NvJysqSw+GQzWZze9kLFy7owoUL8vPzK/T2C+uWW25RSkqKvvvuuzz7GYahjIwM+fr6ytvbu8Drb9q0qUJCQgiPgAf5eLoAAEXr73//u8v7L7/8Ups3b87WfqVz584pICCgwNvx9fUtVH2S5OPjIx+f0v3rx2KxeCR85eT8+fOyWq3y8mKwHSgI/qUA16BL80527Nihm2++WQEBAXruueckSe+++67uuOMOVa9eXTabTXXr1tWUKVNkt9td1nHlnJtLc1RmzpyphQsXqm7durLZbGrbtq2+/vprl2VzmnNjsVg0dOhQrVu3Tk2bNpXNZlOTJk20adOmbPVv2bJFbdq0kZ+fn+rWrat//OMfbs/j2bt3r2699VYFBASoRo0amjFjhsvnOc25SUpK0sCBA3XdddfJZrMpLCxMd999t44cOSJJioiI0Pfff69PP/3UefrvlltucS5/6NAhPfDAA6pUqZICAgJ04403asOGDdn2zWKxaNWqVRo7dqxq1KihgIAAJSQkyGKx6JVXXsm2L1988YUsFotWrlxZ4P0HzKx0/68TgGLz22+/qWvXrnrwwQf197//3XmK6s0331T58uUVFxen8uXL6+OPP9b48eOVlpaml156Kd/1rlixQunp6Xr00UdlsVg0Y8YM3XvvvTp06FC+oz2ff/653nnnHT3xxBOqUKGCXnvtNd133306evSoKleuLEnatWuXbr/9doWFhWnSpEmy2+2aPHmyqlSpUuB9//3333X77bfr3nvvVc+ePbVmzRqNHDlSzZo1U9euXXNd7r777tP333+vYcOGKSIiQidPntTmzZt19OhRRUREaPbs2Ro2bJjKly+vMWPGSJLzuCYnJ6t9+/Y6d+6chg8frsqVK2vp0qW66667tGbNGt1zzz0u25oyZYqsVqtGjBihjIwMNWzYUB06dNDy5cv19NNPu/Rdvny5KlSooLvvvrvAxwAwNQOAqQ0ZMsS48p96p06dDEnG66+/nq3/uXPnsrU9+uijRkBAgHH+/HlnW2xsrFGrVi3n+8OHDxuSjMqVKxunTp1ytr/77ruGJOO9995ztk2YMCFbTZIMq9VqHDx40Nm2e/duQ5IxZ84cZ1v37t2NgIAA45dffnG2HThwwPDx8cm2zpxc2vd//etfzraMjAyjWrVqxn333Zdtf5YsWWIYhmH8/vvvhiTjpZdeynP9TZo0MTp16pSt/amnnjIkGf/73/+cbenp6Ubt2rWNiIgIw263G4ZhGJ988okhyahTp062P4t//OMfhiQjMTHR2ZaZmWmEhIQYsbGx+e47cK3gtBRwjbLZbBo4cGC2dn9/f+fP6enpSklJUceOHXXu3Dnt27cv3/X26tVLwcHBzvcdO3aUdPGUTH6io6NVt25d5/vmzZsrMDDQuazdbtdHH32kHj16qHr16s5+
9erVy3PE5Urly5d3mYNktVrVrl27PGv09/eX1WrVli1b9Pvvvxd4W5ds3LhR7dq100033eRSx+DBg3XkyBHt3bvXpX9sbKzLn4Uk9ezZU35+flq+fLmz7cMPP1RKSkq+c6qAawnhBrhG1ahRQ1arNVv7999/r3vuuUdBQUEKDAxUlSpVnF+cqamp+a63Zs2aLu8vBZ2CBIIrl720/KVlT548qT/++EP16tXL1i+nttxcd9112ebnXL6dnNhsNr344ov64IMPFBoaqptvvlkzZsxQUlJSgbb5008/qUGDBtnaGzVq5Pz8crVr187Wt2LFiurevbtWrFjhbFu+fLlq1Kihzp07F6gO4FpAuAGuUVeOCkjS6dOn1alTJ+3evVuTJ0/We++9p82bN+vFF1+UJDkcjnzXm9tl00YB7jpxNcu6o7Dbeeqpp/TDDz9o2rRp8vPz07hx49SoUSPt2rWrSOuTcv7zkaT+/fvr0KFD+uKLL5Senq7169erd+/eXEkFXIYJxQCctmzZot9++03vvPOObr75Zmf74cOHPVjVX6pWrSo/Pz8dPHgw22c5tRWHunXr6plnntEzzzyjAwcOqGXLlpo1a5bz/kK5XbFVq1Yt7d+/P1v7pVN9tWrVKtD2b7/9dlWpUkXLly9XZGSkzp07p379+hVybwBzIuoDcLo0onH5CEZmZqbmz5/vqZJceHt7Kzo6WuvWrdPx48ed7QcPHtQHH3xQrNs+d+6czp8/79JWt25dVahQQRkZGc62cuXK6fTp09mW79atm7Zv365t27Y5286ePauFCxcqIiJCjRs3LlAdPj4+6t27t9566y29+eabatasmZo3b164nQJMipEbAE7t27dXcHCwYmNjNXz4cFksFi1btqzITwtdjYkTJ+q///2vOnTooMcff1x2u11z585V06ZNlZCQUGzb/eGHH3TbbbepZ8+eaty4sXx8fLR27VolJyfrwQcfdPZr3bq1FixYoOeff1716tVT1apV1blzZ40aNUorV65U165dNXz4cFWqVElLly7V4cOH9Z///Met00r9+/fXa6+9pk8++cR5yhDAXwg3AJwqV66s999/X88884zGjh2r4OBg/f3vf9dtt92mmJgYT5cn6WJ4+OCDDzRixAiNGzdO4eHhmjx5shITEwt0NVdhhYeHq3fv3oqPj9eyZcvk4+Ojhg0b6q233tJ9993n7Dd+/Hj99NNPmjFjhtLT09WpUyd17txZoaGh+uKLLzRy5EjNmTNH58+fV/PmzfXee+/pjjvucKuW1q1bq0mTJkpMTFTfvn2LeleBMo9nSwEwhR49euj777/XgQMHPF1KiWjVqpUqVaqk+Ph4T5cClDrMuQFQ5vzxxx8u7w8cOKCNGze6POrAzL755hslJCSof//+ni4FKJUYuQFQ5oSFhWnAgAGqU6eOfvrpJy1YsEAZGRnatWuX6tev7+nyis13332nHTt2aNasWUpJSdGhQ4dKzcM9gdKEOTcAypzbb79dK1euVFJSkmw2m6KiojR16lRTBxtJWrNmjSZPnqwGDRpo5cqVBBsgFx4dufnss8/00ksvaceOHTpx4oTWrl2rHj165LnMli1bFBcXp++//17h4eEaO3asBgwYUCL1AgCA0s+jc27Onj2rFi1aaN68eQXqf/jwYd1xxx269dZblZCQoKeeekqPPPKIPvzww2KuFAAAlBWlZs6NxWLJd+Rm5MiR2rBhg7777jtn24MPPqjTp09r06ZNJVAlAAAo7crUnJtt27YpOjrapS0mJkZPPfVUrstkZGS43D3U4XDo1KlTqly5cq63SQcAAKWLYRhKT09X9erV873pZZkKN0lJSQoNDXVpCw0NVVpamv74448cHzQ3bdo0TZo0qaRKBAAAxejYsWO67rrr8uxTpsJNYYwePVpxcXHO96mpqapZs6aOHTumwMBAD1YGAAAKKi0tTeHh4apQoUK+fctUuKlWrZqSk5Nd2pKTkxUYGJjj
qI0k2Ww22Wy2bO2BgYGEGwAAypiCTCkpU3cojoqKynar8c2bNysqKspDFQEAgNLGo+HmzJkzSkhIcD7J9/Dhw0pISNDRo0clXTyldPntxR977DEdOnRIzz77rPbt26f58+frrbfe0tNPP+2J8gEAQCnk0XDzzTffqFWrVmrVqpUkKS4uTq1atdL48eMlSSdOnHAGHUmqXbu2NmzYoM2bN6tFixaaNWuW/vnPf5aapxUDAADPKzX3uSkpaWlpCgoKUmpqKnNuAJRJdrtdWVlZni4DKHJWqzXXy7zd+f4uUxOKAeBaZhiGkpKSdPr0aU+XAhQLLy8v1a5dW1ar9arWQ7gBgDLiUrCpWrWqAgICuBEpTMXhcOj48eM6ceKEataseVV/vwk3AFAG2O12Z7CpXLmyp8sBikWVKlV0/PhxXbhwQb6+voVeT5m6FBwArlWX5tgEBAR4uBKg+Fw6HWW3269qPYQbAChDOBUFMyuqv9+EGwAAYCqEGwBAmRIREaHZs2d7ugyUYkwoBgAUq1tuuUUtW7YsskDy9ddfq1y5ckWyLpgT4QYA4HGGYchut8vHJ/+vpSpVqpRARSXLnf1H/jgtBQAoNgMGDNCnn36qV199VRaLRRaLRUeOHNGWLVtksVj0wQcfqHXr1rLZbPr888/1448/6u6771ZoaKjKly+vtm3b6qOPPnJZ55WnpSwWi/75z3/qnnvuUUBAgOrXr6/169fnWdeyZcvUpk0bVahQQdWqVVOfPn108uRJlz7ff/+97rzzTgUGBqpChQrq2LGjfvzxR+fnixcvVpMmTWSz2RQWFqahQ4dKko4cOSKLxeJ8bqIknT59WhaLRVu2bJGkq9r/jIwMjRw5UuHh4bLZbKpXr54WLVokwzBUr149zZw506V/QkKCLBaLDh48mOcxMRPCDQCUUYZh6FzmBY+8CvrknldffVVRUVEaNGiQTpw4oRMnTig8PNz5+ahRozR9+nQlJiaqefPmOnPmjLp166b4+Hjt2rVLt99+u7p37+7ynMGcTJo0ST179tS3336rbt26qW/fvjp16lSu/bOysjRlyhTt3r1b69at05EjRzRgwADn57/88otuvvlm2Ww2ffzxx9qxY4ceeughXbhwQZK0YMECDRkyRIMHD9aePXu0fv161atXr0DH5HKF2f/+/ftr5cqVeu2115SYmKh//OMfKl++vCwWix566CEtWbLEZRtLlizRzTffXKj6yirGvwCgjPojy67G4z/0yLb3To5RgDX/r5CgoCBZrVYFBASoWrVq2T6fPHmyunTp4nxfqVIltWjRwvl+ypQpWrt2rdavX+8cGcnJgAED1Lt3b0nS1KlT9dprr2n79u26/fbbc+z/0EMPOX+uU6eOXnvtNbVt21ZnzpxR+fLlNW/ePAUFBWnVqlXOm8ldf/31zmWef/55PfPMM3ryySedbW3bts3vcGTj7v7/8MMPeuutt7R582ZFR0c767/8OIwfP17bt29Xu3btlJWVpRUrVmQbzTE7Rm4AAB7Tpk0bl/dnzpzRiBEj1KhRI1WsWFHly5dXYmJiviM3zZs3d/5crlw5BQYGZjvNdLkdO3aoe/fuqlmzpipUqKBOnTpJknM7CQkJ6tixY453yT158qSOHz+u2267rcD7mRt39z8hIUHe3t7Oeq9UvXp13XHHHVq8eLEk6b333lNGRoYeeOCBq661LGHkBgDKKH9fb+2dHOOxbReFK696GjFihDZv3qyZM2eqXr168vf31/3336/MzMw813NlCLFYLHI4HDn2PXv2rGJiYhQTE6Ply5erSpUqOnr0qGJiYpzb8ff3z3VbeX0myflU68tP3eX2FHd39z+/bUvSI488on79+umVV17RkiVL1KtXr2vuztaEGwAooywWS4FODXma1Wot8O30t27dqgEDBuiee+6RdHEk48iRI0Vaz759+/Tbb79p+vTpzvk/33zzjUuf5s2ba+nSpcrKysoWnCpUqKCIiAjF
x8fr1ltvzbb+S1dznThxQq1atZIkl8nFeclv/5s1ayaHw6FPP/3UeVrqSt26dVO5cuW0YMECbdq0SZ999lmBtm0mnJYCABSriIgIffXVVzpy5IhSUlJyHVGRpPr16+udd95RQkKCdu/erT59+uTZvzBq1qwpq9WqOXPm6NChQ1q/fr2mTJni0mfo0KFKS0vTgw8+qG+++UYHDhzQsmXLtH//fknSxIkTNWvWLL322ms6cOCAdu7cqTlz5ki6OLpy4403OicKf/rppxo7dmyBastv/yMiIhQbG6uHHnpI69at0+HDh7Vlyxa99dZbzj7e3t4aMGCARo8erfr16ysqKupqD1mZQ7gBABSrESNGyNvbW40bN3aeAsrNyy+/rODgYLVv317du3dXTEyMbrjhhiKtp0qVKnrzzTf19ttvq3Hjxpo+fXq2CbeVK1fWxx9/rDNnzqhTp05q3bq13njjDecoTmxsrGbPnq358+erSZMmuvPOO3XgwAHn8osXL9aFCxfUunVrPfXUU3r++ecLVFtB9n/BggW6//779cQTT6hhw4YaNGiQzp4969Ln4YcfVmZmpgYOHFiYQ1TmWYyCXs9nEmlpaQoKClJqaqoCAwM9XQ4AFMj58+d1+PBh1a5dW35+fp4uB6Xc//73P9122206duyYQkNDPV1OgeX199yd7+/Sf7IWAAAUSEZGhn799VdNnDhRDzzwQJkKNkWJ01IAAJjEypUrVatWLZ0+fVozZszwdDkeQ7gBAMAkBgwYILvdrh07dqhGjRqeLsdjCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAgFIvIiJCs2fPdr63WCxat25drv2PHDkii8VS4AdWFvd6ULK4QzEAoMw5ceKEgoODi3SdAwYM0OnTp11CU3h4uE6cOKGQkJAi3RaKF+EGAFDmVKtWrUS24+3tXWLbKm2ysrKcDwotazgtBQAoNgsXLlT16tXlcDhc2u+++2499NBDkqQff/xRd999t0JDQ1W+fHm1bdtWH330UZ7rvfK01Pbt29WqVSv5+fmpTZs22rVrl0t/u92uhx9+WLVr15a/v78aNGigV1991fn5xIkTtXTpUr377ruyWCyyWCzasmVLjqelPv30U7Vr1042m01hYWEaNWqULly44Pz8lltu0fDhw/Xss8+qUqVKqlatmiZOnJjn/nz99dfq0qWLQkJCFBQUpE6dOmnnzp0ufU6fPq1HH31UoaGh8vPzU9OmTfX+++87P9+6datuueUWBQQEKDg4WDExMfr9998lZT+tJ0ktW7Z0qctisWjBggW66667VK5cOb3wwgv5HrdLFi9erCZNmjiPydChQyVJDz30kO68806XvllZWapataoWLVqU5zG5GozcAEBZZRhS1jnPbNs3QLJY8u32wAMPaNiwYfrkk0902223SZJOnTqlTZs2aePGjZKkM2fOqFu3bnrhhRdks9n0r3/9S927d9f+/ftVs2bNfLdx5swZ3XnnnerSpYv+/e9/6/Dhw3ryySdd+jgcDl133XV6++23VblyZX3xxRcaPHiwwsLC1LNnT40YMUKJiYlKS0vTkiVLJEmVKlXS8ePHXdbzyy+/qFu3bhowYID+9a9/ad++fRo0aJD8/PxcgsLSpUsVFxenr776Stu2bdOAAQPUoUMHdenSJcd9SE9PV2xsrObMmSPDMDRr1ix169ZNBw4cUIUKFeRwONS1a1elp6fr3//+t+rWrau9e/fK29tbkpSQkKDbbrtNDz30kF599VX5+Pjok08+kd1uz/f4XW7ixImaPn26Zs+eLR8fn3yPmyQtWLBAcXFxmj59urp27arU1FRt3bpVkvTII4/o5ptv1okTJxQWFiZJev/993Xu3Dn16tXLrdrcQbgBgLIq65w0tbpntv3ccclaLt9uwcHB6tq1q1asWOEMN2vWrFFISIhuvfVWSVKLFi3UokUL5zJTpkzR2rVrtX79eucI
QF5WrFghh8OhRYsWyc/PT02aNNHPP/+sxx9/3NnH19dXkyZNcr6vXbu2tm3bprfeeks9e/ZU+fLl5e/vr4yMjDxPQ82fP1/h4eGaO3euLBaLGjZsqOPHj2vkyJEaP368vLwunhBp3ry5JkyYIEmqX7++5s6dq/j4+FzDTefOnV3eL1y4UBUrVtSnn36qO++8Ux999JG2b9+uxMREXX/99ZKkOnXqOPvPmDFDbdq00fz5851tTZo0yffYXalPnz4aOHCgS1tex02Snn/+eT3zzDMugbJt27aSpPbt26tBgwZatmyZnn32WUnSkiVL9MADD6h8+fJu11dQnJYCABSrvn376j//+Y8yMjIkScuXL9eDDz7oDAJnzpzRiBEj1KhRI1WsWFHly5dXYmKijh49WqD1JyYmqnnz5vLz83O2RUVFZes3b948tW7dWlWqVFH58uW1cOHCAm/j8m1FRUXJctmoVYcOHXTmzBn9/PPPzrbmzZu7LBcWFqaTJ0/mut7k5GQNGjRI9evXV1BQkAIDA3XmzBlnfQkJCbruuuucweZKl0ZurlabNm2yteV13E6ePKnjx4/nue1HHnnEORqWnJysDz74wHlKsrgwcgMAZZVvwMURFE9tu4C6d+8uwzC0YcMGtW3bVv/73//0yiuvOD8fMWKENm/erJkzZ6pevXry9/fX/fffr8zMzCIrd9WqVRoxYoRmzZqlqKgoVahQQS+99JK++uqrItvG5a6ciGuxWLLNO7pcbGysfvvtN7366quqVauWbDaboqKinMfA398/z+3l97mXl5cMw3Bpy8rKytavXDnX0bj8jlt+25Wk/v37a9SoUdq2bZu++OIL1a5dWx07dsx3uatBuAGAsspiKdCpIU/z8/PTvffeq+XLl+vgwYNq0KCBbrjhBufnW7du1YABA3TPPfdIujiSc+TIkQKvv1GjRlq2bJnOnz/vHL358ssvXfps3bpV7du31xNPPOFs+/HHH136WK3WfOeoNGrUSP/5z39kGIZz9Gbr1q2qUKGCrrvuugLXfKWtW7dq/vz56tatmyTp2LFjSklJcX7evHlz/fzzz/rhhx9yHL1p3ry54uPjXU4hXa5KlSo6ceKE831aWpoOHz5coLryOm4VKlRQRESE4uPjnacZr1S5cmX16NFDS5Ys0bZt27Kd9ioOnJYCABS7vn37asOGDVq8eLH69u3r8ln9+vX1zjvvKCEhQbt371afPn3yHOW4Up8+fWSxWDRo0CDt3btXGzdu1MyZM7Nt45tvvtGHH36oH374QePGjdPXX3/t0iciIkLffvut9u/fr5SUlBxHNp544gkdO3ZMw4YN0759+/Tuu+9qwoQJiouLc55mK4z69etr2bJlSkxM1FdffaW+ffu6jIp06tRJN998s+677z5t3rxZhw8f1gcffKBNmzZJkkaPHq2vv/5aTzzxhL799lvt27dPCxYscAakzp07a9myZfrf//6nPXv2KDY21jkZOb+68jtuEydO1KxZs/Taa6/pwIED2rlzp+bMmePS55FHHtHSpUuVmJio2NjYQh+ngiLcAACKXefOnVWpUiXt379fffr0cfns5ZdfVnBwsNq3b6/u3bsrJibGZWQnP+XLl9d7772nPXv2qFWrVhozZoxefPFFlz6PPvqo7r33XvXq1UuRkZH67bffXEYjJGnQoEFq0KCB2rRpoypVqjiv+LlcjRo1tHHjRm3fvl0tWrTQY489pocfflhjx45142hkt2jRIv3++++64YYb1K9fPw0fPlxVq1Z16fOf//xHbdu2Ve/evdW4cWM9++yzzpGm66+/Xv/973+1e/dutWvXTlFRUXr33Xfl43PxBM3o0aPVqVMn3XnnnbrjjjvUo0cP1a1bN9+6CnLcYmNjNXv2bM2fP19NmjTRnXfeqQMHDrj0iY6OVlhYmGJiYlS9evFPgrcYV56EM7m0tDQFBQUpNTVVgYGBni4HAArk/PnzOnz4sGrXru0ycRYo
C86cOaMaNWpoyZIluvfee3Ptl9ffc3e+v5lzAwAAioXD4VBKSopmzZqlihUr6q677iqR7RJuAABAsTh69Khq166t6667Tm+++abzNFlxI9wAAIBiERERke0S9JLAhGIAAGAqhBsAKEOusWtAcI0pqr/fhBsAKAMu3fH23DkPPSgTKAGX7shckHvw5IU5NwBQBnh7e6tixYrO5xMFBAS4PN8IKOscDod+/fVXBQQEXPXEY8INAJQRl55WndcDGIGyzMvLSzVr1rzq4E64AYAywmKxKCwsTFWrVs3x0QBAWWe1Wq/qMRaXEG4AoIzx9va+6jkJgJkxoRgAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJgK4QYAAJiKx8PNvHnzFBERIT8/P0VGRmr79u159p89e7YaNGggf39/hYeH6+mnn9b58+dLqFoAAFDaeTTcrF69WnFxcZowYYJ27typFi1aKCYmRidPnsyx/4oVKzRq1ChNmDBBiYmJWrRokVavXq3nnnuuhCsHAACllUfDzcsvv6xBgwZp4MCBaty4sV5//XUFBARo8eLFOfb/4osv1KFDB/Xp00cRERH629/+pt69e+c72gMAAK4dHgs3mZmZ2rFjh6Kjo/8qxstL0dHR2rZtW47LtG/fXjt27HCGmUOHDmnjxo3q1q1brtvJyMhQWlqaywsAAJiXj6c2nJKSIrvdrtDQUJf20NBQ7du3L8dl+vTpo5SUFN10000yDEMXLlzQY489ludpqWnTpmnSpElFWjsAACi9PD6h2B1btmzR1KlTNX/+fO3cuVPvvPOONmzYoClTpuS6zOjRo5Wamup8HTt2rAQrBgAAJc1jIzchISHy9vZWcnKyS3tycrKqVauW4zLjxo1Tv3799Mgjj0iSmjVrprNnz2rw4MEaM2aMvLyyZzWbzSabzVb0OwAAAEolj43cWK1WtW7dWvHx8c42h8Oh+Ph4RUVF5bjMuXPnsgUYb29vSZJhGMVXLAAAKDM8NnIjSXFxcYqNjVWbNm3Url07zZ49W2fPntXAgQMlSf3791eNGjU0bdo0SVL37t318ssvq1WrVoqMjNTBgwc1btw4de/e3RlyAADAtc2j4aZXr1769ddfNX78eCUlJally5batGmTc5Lx0aNHXUZqxo4dK4vForFjx+qXX35RlSpV1L17d73wwgue2gUAAFDKWIxr7HxOWlqagoKClJqaqsDAQE+XAwAACsCd7+8ydbUUAABAfgg3AADAVAg3AADAVAg3AADAVAg3AADAVAg3AADAVAg3AADAVAg3AADAVAg3AADAVAg3AADAVAg3AADAVDz64EwAAFACDOPiS5f/15F/m+H4a3mXtos/OxwOZdrtysyyK+PCBWVmOZR54YK8fW2qGVHPQztLuAFQ2hXkl/Jlv2wL+ks5/1/uymfdua1HBVh3buuRmzVeWs+Vnxe07sIe24JsL6e6C1Oj47JjWnRfykVbYwGOY6H/3PKv0cjnz81y8QAWCy9Jfn++Lrfft5E05sti225+CDfFqdD/SK78R+nuL0nlsu781qMCrPsq/nFfc78kC7C9Iv3zz6/GQh6TovpFXtg/NwB5spTw9hzGxbh0MTZZ5JBF0mVtFovs3rYSrsoV4aaoHPtaWhwjfikDpZzFS5JFslj+/K/XZT/n1JbTMpZ81qMCrvvy9aiA676yTbksk9d6VIB68qrRnf3Kq8YCHHs3a7Qbhi44pAuGRRcchrLsUpbDuOznP19248+fjT9/vvh55qX/2i+2Z/7ZN/PP18WfpYzL3p+/cPFnh1y/9A1ZZPwZBByXtenPPpe3XVzO64rlJcdlbbrivctyxp99LBb5+njJ6uMjXx9vWX28ZfXxkq+Pr6w+XrL5+Fx87+stq4+PbL6X+njL6nupj7dsVh9Zfbxl873Yx+brI5uPt/ys
PrL5eMnm4yU/X++LP/t6y8/HSz7ef03jbezuv8siRrgpMoZk2It5GwX8peJskwryyyDXZQr8S6Wg9Vy5HhXwF3Ee2yuq/SrSGgt77C1XbM+dL51Ly7r7pfPnF4jbXzpX8+dRgL8zua4nj+OT33ouHSMUO4fDUMYFhzIu2HU+6+J/My44dD7r4n8zsi77OafPLtiVkZXH8hccynD2t+v8Ze8vOErH/1hafbzk9+cXv0sQcPnZW36+f4YJ39z72Hwv/28ey/t4y9fbIgt/1wk3RSaspRSXqGL5hc5fVABuMgwjexC4PCy4hAjXsJFbn+xh469+GZf1y7Q7PL37kiSrt9efIwtXhAQfr79CgcsIxMWAcKmfu30ubcPq7SUvL35vexLhpqj4WKXA6p6uAkApYhiXTlvkPBJx5ShFRi6jFOez8h/JOH/ZSMalttLAx8vicurCdtmpjPxGMnL97FJ7Hn2sPl7yJmBcswg3AEzvgt3hcuoit1CQW58rRyfyHfG4rN0oBWdJLBbJz43gYHNnJMMlaGT/7PJ5GEBJIdwAKBF2h+Fy6qIgp0HyCw4FGvG44JC9lMzDKPjpjeyTNd0Z8bgyePh4MQ8D1xbCDXANcTgMZdoLOmnTfsXEz5zCRs4jHpk59M+yl46AYfX2ynnuhUvYKNjpkNxGPFxHSC72sXp7ETCAEkK4AUqYYfwZMFzmUmQPFfldYZLX3I0cJ31ecCizlMzD8PW2ZDuFYb1itCL/K0z+HNXI8XRI9qtImOgJXDsIN7gmGcbF+17kNmkz38maec3dyOdqk9IyD8PLImcYyH1CZ0HmZ7h3hYnVm3kYAIoX4QYeZXcJGFeEgsIEh9xGPHIY1SgN0zAsFmULDgWel5HTqMYV8zMuH/m48jNfAgYAkyLcwOWGW27fTKsQN9y6/LOyfMOtqxrx4IZbAFBsCDelhLs33CrcFSal+4Zbvt4Wl1GKorrhVu6XsDIPAwDMiHBTRH7+/ZxWf32s4HM3LgWPMnjDrdwma3LDLQBAaUC4KSLJaRma8/HBIlnXpRtuFTQ45HbDrVxPr+R6hQkTPQEAZR/hpohUC/JTbFStIrnFOPMwAAAoPMJNEalR0V+T7m7q6TIAALjmcQ4CAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYCuEGAACYisfDzbx58xQRESE/Pz9FRkZq+/btefY/ffq0hgwZorCwMNlsNl1//fXauHFjCVULAABKOx9Pbnz16tWKi4vT66+/rsjISM2ePVsxMTHav3+/qlatmq1/ZmamunTpoqpVq2rNmjWqUaOGfvrpJ1WsWLHkiwcAAKWSxTAMw1Mbj4yMVNu2bTV37lxJksPhUHh4uIYNG6ZRo0Zl6//666/rpZde0r59++Tr61uobaalpSkoKEipqakKDAy8qvoBAEDJcOf722OnpTIzM7Vjxw5FR0f/VYyXl6Kjo7Vt27Ycl1m/fr2ioqI0ZMgQhYaGqmnTppo6darsdnuu28nIyFBaWprLCwAAmJfHwk1KSorsdrtCQ0Nd2kNDQ5WUlJTjMocOHdKaNWtkt9u1ceNGjRs3TrNmzdLzzz+f63amTZumoKAg5ys8PLxI9wMAAJQuHp9Q7A6Hw6GqVatq4cKFat26tXr16qUxY8bo9ddfz3WZ0aNHKzU11fk6duxYCVYMAABKmscmFIeEhMjb21vJycku7cnJyapWrVqOy4SFhcnX11fe3t7OtkaNGikpKUmZmZmyWq3ZlrHZbLLZbEVbPAAAKLU8NnJjtVrVunVr
xcfHO9scDofi4+MVFRWV4zIdOnTQwYMH5XA4nG0//PCDwsLCcgw2AADg2uPR01JxcXF64403tHTpUiUmJurxxx/X2bNnNXDgQElS//79NXr0aGf/xx9/XKdOndKTTz6pH374QRs2bNDUqVM1ZMgQT+0CAAAoZTx6n5tevXrp119/1fjx45WUlKSWLVtq06ZNzknGR48elZfXX/krPDxcH374oZ5++mk1b95cNWrU0JNPPqmRI0d6ahcAAEAp49H73HgC97kBAKDsKRP3uQEAACgOboebiIgITZ48WUePHi2OegAAAK6K2+Hmqaee0jvvvKM6deqoS5cuWrVqlTIyMoqjNgAAALcVKtwkJCRo+/btatSokYYNG6awsDANHTpUO3fuLI4aAQAACuyqJxRnZWVp/vz5GjlypLKystSsWTMNHz5cAwcOlMViKao6iwwTigEAKHvc+f4u9KXgWVlZWrt2rZYsWaLNmzfrxhtv1MMPP6yff/5Zzz33nD766COtWLGisKsHAAAoFLfDzc6dO7VkyRKtXLlSXl5e6t+/v1555RU1bNjQ2eeee+5R27Zti7RQAACAgnA73LRt21ZdunTRggUL1KNHD/n6+mbrU7t2bT344INFUiAAAIA73A43hw4dUq1atfLsU65cOS1ZsqTQRQEAABSW21dLnTx5Ul999VW29q+++krffPNNkRQFAABQWG6HmyFDhujYsWPZ2n/55RceYAkAADzO7XCzd+9e3XDDDdnaW7Vqpb179xZJUQAAAIXldrix2WxKTk7O1n7ixAn5+Hj0IeMAAADuh5u//e1vGj16tFJTU51tp0+f1nPPPacuXboUaXEAAADucnuoZebMmbr55ptVq1YttWrVSpKUkJCg0NBQLVu2rMgLBAAAcIfb4aZGjRr69ttvtXz5cu3evVv+/v4aOHCgevfuneM9bwAAAEpSoSbJlCtXToMHDy7qWgAAAK5aoWcA7927V0ePHlVmZqZL+1133XXVRQEAABRWoe5QfM8992jPnj2yWCy69FDxS08At9vtRVshAACAG9y+WurJJ59U7dq1dfLkSQUEBOj777/XZ599pjZt2mjLli3FUCIAAEDBuT1ys23bNn388ccKCQmRl5eXvLy8dNNNN2natGkaPny4du3aVRx1AgAAFIjbIzd2u10VKlSQJIWEhOj48eOSpFq1amn//v1FWx0AAICb3B65adq0qXbv3q3atWsrMjJSM2bMkNVq1cKFC1WnTp3iqBEAAKDA3A43Y8eO1dmzZyVJkydP1p133qmOHTuqcuXKWr16dZEXCAAA4A6Lcelyp6tw6tQpBQcHO6+YKs3S0tIUFBSk1NRUBQYGerocAABQAO58f7s15yYrK0s+Pj767rvvXNorVapUJoINAAAwP7fCja+vr2rWrMm9bAAAQKnl9tVSY8aM0XPPPadTp04VRz0AAABXxe0JxXPnztXBgwdVvXp11apVS+XKlXP5fOfOnUVWHAAAgLvcDjc9evQohjIAAACKRpFcLVWWcLUUAABlT7FdLQUAAFDauX1aysvLK8/LvrmSCgAAeJLb4Wbt2rUu77OysrRr1y4tXbpUkyZNKrLCAAAACqPI5tysWLFCq1ev1rvvvlsUqys2zLkBAKDs8cicmxtvvFHx8fFFtToAAIBCKZJw88cff+i1115TjRo1imJ1AAAAheb2nJsrH5BpGIbS09MVEBCgf//730VaHAAAgLvcDjevvPKKS7jx8vJSlSpVFBkZqeDg4CItDgAAwF1uh5sBAwYUQxkAAABFw+05N0uWLNHbb7+drf3tt9/W0qVLi6QoAACAwnI73EybNk0hISHZ2qtWraqpU6cWSVEAAACF5Xa4OXr0qGrXrp2tvVatWjp69GiRFAUAAFBYboebqlWr6ttvv83Wvnv3blWuXLlIigIAACgst8NN7969NXz4cH3yySey2+2y2+36+OOP9eSTT+rBBx8sjhoBAAAKzO2rpaZMmaIj
R47otttuk4/PxcUdDof69+/PnBsAAOBxhX621IEDB5SQkCB/f381a9ZMtWrVKuraigXPlgIAoOxx5/vb7ZGbS+rXr6/69esXdnEAAIBi4facm/vuu08vvvhitvYZM2bogQceKJKiAAAACsvtcPPZZ5+pW7du2dq7du2qzz77rEiKAgAAKCy3w82ZM2dktVqztfv6+iotLa1IigIAACgst8NNs2bNtHr16mztq1atUuPGjYukKAAAgMJye0LxuHHjdO+99+rHH39U586dJUnx8fFasWKF1qxZU+QFAgAAuMPtcNO9e3etW7dOU6dO1Zo1a+Tv768WLVro448/VqVKlYqjRgAAgAIr9H1uLklLS9PKlSu1aNEi7dixQ3a7vahqKxbc5wYAgLLHne9vt+fcXPLZZ58pNjZW1atX16xZs9S5c2d9+eWXhV0dAABAkXDrtFRSUpLefPNNLVq0SGlpaerZs6cyMjK0bt06JhMDAIBSocAjN927d1eDBg307bffavbs2Tp+/LjmzJlTnLUBAAC4rcAjNx988IGGDx+uxx9/nMcuAACAUqvAIzeff/650tPT1bp1a0VGRmru3LlKSUkpztoAAADcVuBwc+ONN+qNN97QiRMn9Oijj2rVqlWqXr26HA6HNm/erPT09OKsEwAAoECu6lLw/fv3a9GiRVq2bJlOnz6tLl26aP369UVZX5HjUnAAAMqeErkUXJIaNGigGTNm6Oeff9bKlSuvZlUAAABF4qrCzSXe3t7q0aNHoUdt5s2bp4iICPn5+SkyMlLbt28v0HKrVq2SxWJRjx49CrVdAABgPkUSbq7G6tWrFRcXpwkTJmjnzp1q0aKFYmJidPLkyTyXO3LkiEaMGKGOHTuWUKUAAKAs8Hi4efnllzVo0CANHDhQjRs31uuvv66AgAAtXrw412Xsdrv69u2rSZMmqU6dOiVYLQAAKO08Gm4yMzO1Y8cORUdHO9u8vLwUHR2tbdu25brc5MmTVbVqVT388MP5biMjI0NpaWkuLwAAYF4eDTcpKSmy2+0KDQ11aQ8NDVVSUlKOy3z++edatGiR3njjjQJtY9q0aQoKCnK+wsPDr7puAABQenn8tJQ70tPT1a9fP73xxhsKCQkp0DKjR49Wamqq83Xs2LFirhIAAHiSWw/OLGohISHy9vZWcnKyS3tycrKqVauWrf+PP/6oI0eOqHv37s42h8MhSfLx8dH+/ftVt25dl2VsNptsNlsxVA8AAEojj47cWK1WtW7dWvHx8c42h8Oh+Ph4RUVFZevfsGFD7dmzRwkJCc7XXXfdpVtvvVUJCQmccgIAAJ4duZGkuLg4xcbGqk2bNmrXrp1mz56ts2fPauDAgZKk/v37q0aNGpo2bZr8/PzUtGlTl+UrVqwoSdnaAQDAtcnj4aZXr1769ddfNX78eCUlJally5batGmTc5Lx0aNH5eVVpqYGAQAAD7qqZ0uVRTxbCgCAsqfEni0FAABQ2hBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqZSKcDNv3jxF
RETIz89PkZGR2r59e65933jjDXXs2FHBwcEKDg5WdHR0nv0BAMC1xePhZvXq1YqLi9OECRO0c+dOtWjRQjExMTp58mSO/bds2aLevXvrk08+0bZt2xQeHq6//e1v+uWXX0q4cgAAUBpZDMMwPFlAZGSk2rZtq7lz50qSHA6HwsPDNWzYMI0aNSrf5e12u4KDgzV37lz1798/3/5paWkKCgpSamqqAgMDr7p+AABQ/Nz5/vboyE1mZqZ27Nih6OhoZ5uXl5eio6O1bdu2Aq3j3LlzysrKUqVKlXL8PCMjQ2lpaS4vAABgXh4NNykpKbLb7QoNDXVpDw0NVVJSUoHWMXLkSFWvXt0lIF1u2rRpCgoKcr7Cw8Ovum4AAFB6eXzOzdWYPn26Vq1apbVr18rPzy/HPqNHj1ZqaqrzdezYsRKuEgAAlCQfT248JCRE3t7eSk5OdmlPTk5WtWrV8lx25syZmj59uj766CM1b9481342m002m61I6gUAAKWfR0durFarWrdurfj4eGebw+FQfHy8oqKicl1uxowZmjJlijZt2qQ2bdqURKkAAKCM8OjIjSTFxcUpNjZWbdq0Ubt27TR79mydPXtWAwcOlCT1799fNWrU0LRp0yRJL774osaPH68VK1YoIiLCOTenfPnyKl++vMf2AwAAlA4eDze9evXSr7/+qvHjxyspKUktW7bUpk2bnJOMjx49Ki+vvwaYFixYoMzMTN1///0u65kwYYImTpxYkqUDAIBSyOP3uSlp3OcGAICyp8zc5wYAAKCoEW4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICpEG4AAICplIpwM2/ePEVERMjPz0+RkZHavn17nv3ffvttNWzYUH5+fmrWrJk2btxYQpUCAIDSzuPhZvXq1YqLi9OECRO0c+dOtWjRQjExMTp58mSO/b/44gv17t1bDz/8sHbt2qUePXqoR48e+u6770q4cgAAUBpZDMMwPFlAZGSk2rZtq7lz50qSHA6HwsPDNWzYMI0aNSpb/169euns2bN6//33nW033nijWrZsqddffz3f7aWlpSkoKEipqakKDAwsuh0BAADFxp3vb4+O3GRmZmrHjh2Kjo52tnl5eSk6Olrbtm3LcZlt27a59JekmJiYXPsDAIBri48nN56SkiK73a7Q0FCX9tDQUO3bty/HZZKSknLsn5SUlGP/jIwMZWRkON+npqZKupgAAQBA2XDpe7sgJ5w8Gm5KwrRp0zRp0qRs7eHh4R6oBgAAXI309HQFBQXl2cej4SYkJETe3t5KTk52aU9OTla1atVyXKZatWpu9R89erTi4uKc7x0Oh06dOqXKlSvLYrFc5R64SktLU3h4uI4dO8Z8nmLEcS4ZHOeSwXEuORzrklFcx9kwDKWnp6t69er59vVouLFarWrdurXi4+PVo0cPSRfDR3x8vIYOHZrjMlFRUYqPj9dTTz3lbNu8ebOioqJy7G+z2WSz2VzaKlasWBTl5yowMJB/OCWA41wyOM4lg+NccjjWJaM4jnN+IzaXePy0VFxcnGJjY9Wm
TRu1a9dOs2fP1tmzZzVw4EBJUv/+/VWjRg1NmzZNkvTkk0+qU6dOmjVrlu644w6tWrVK33zzjRYuXOjJ3QAAAKWEx8NNr1699Ouvv2r8+PFKSkpSy5YttWnTJuek4aNHj8rL66+Lutq3b68VK1Zo7Nixeu6551S/fn2tW7dOTZs29dQuAACAUsTj4UaShg4dmutpqC1btmRre+CBB/TAAw8Uc1Xus9lsmjBhQrbTYChaHOeSwXEuGRznksOxLhml4Th7/CZ+AAAARcnjj18AAAAoSoQbAABgKoQbAABgKoQbAABgKoQbN82bN08RERHy8/NTZGSktm/fnmf/t99+Ww0bNpSfn5+aNWumjRs3llClZZs7x/mNN95Qx44dFRwcrODgYEVHR+f754KL3P37fMmqVatksVicN99E3tw9zqdPn9aQIUMUFhYmm82m66+/nt8dBeDucZ49e7YaNGggf39/hYeH6+mnn9b58+dLqNqy6bPPPlP37t1VvXp1WSwWrVu3Lt9ltmzZohtuuEE2m0316tXTm2++Wex1ykCBrVq1yrBarcbixYuN77//3hg0aJBRsWJFIzk5Ocf+W7duNby9vY0ZM2YYe/fuNcaOHWv4+voae/bsKeHKyxZ3j3OfPn2MefPmGbt27TISExONAQMGGEFBQcbPP/9cwpWXLe4e50sOHz5s1KhRw+jYsaNx9913l0yxZZi7xzkjI8No06aN0a1bN+Pzzz83Dh8+bGzZssVISEgo4crLFneP8/Llyw2bzWYsX77cOHz4sPHhhx8aYWFhxtNPP13ClZctGzduNMaMGWO88847hiRj7dq1efY/dOiQERAQYMTFxRl79+415syZY3h7exubNm0q1joJN25o166dMWTIEOd7u91uVK9e3Zg2bVqO/Xv27GnccccdLm2RkZHGo48+Wqx1lnXuHucrXbhwwahQoYKxdOnS4irRFApznC9cuGC0b9/e+Oc//2nExsYSbgrA3eO8YMECo06dOkZmZmZJlWgK7h7nIUOGGJ07d3Zpi4uLMzp06FCsdZpJQcLNs88+azRp0sSlrVevXkZMTEwxVmYYnJYqoMzMTO3YsUPR0dHONi8vL0VHR2vbtm05LrNt2zaX/pIUExOTa38U7jhf6dy5c8rKylKlSpWKq8wyr7DHefLkyapataoefvjhkiizzCvMcV6/fr2ioqI0ZMgQhYaGqmnTppo6darsdntJlV3mFOY4t2/fXjt27HCeujp06JA2btyobt26lUjN1wpPfQ+WijsUlwUpKSmy2+3Ox0JcEhoaqn379uW4TFJSUo79k5KSiq3Osq4wx/lKI0eOVPXq1bP9g8JfCnOcP//8cy1atEgJCQklUKE5FOY4Hzp0SB9//LH69u2rjRs36uDBg3riiSeUlZWlCRMmlETZZU5hjnOfPn2UkpKim266SYZh6MKFC3rsscf03HPPlUTJ14zcvgfT0tL0xx9/yN/fv1i2y8gNTGX69OlatWqV1q5dKz8/P0+XYxrp6enq16+f3njjDYWEhHi6HFNzOByqWrWqFi5cqNatW6tXr14aM2aMXn/9dU+XZipbtmzR1KlTNX/+fO3cuVPvvPOONmzYoClTpni6NBQBRm4KKCQkRN7e3kpOTnZpT05OVrVq1XJcplq1am71R+GO8yUzZ87U9OnT9dFHH6l58+bFWWaZ5+5x/vHHH3XkyBF1797d2eZwOCRJPj4+2r9/v+rWrVu8RZdBhfn7HBYWJl9fX3l7ezvbGjVqpKSkJGVmZspqtRZrzWVRYY7zuHHj1K9fPz3yyCOSpGbNmuns2bMaPHiwxowZ4/LAZhRebt+DgYGBxTZqIzFyU2BWq1WtW7dWfHy8s83hcCg+Pl5RUVE5LhMVFeXSX5I2b96ca38U7jhL0owZMzRlyhRt2rRJbdq0KYlSyzR3j3PDhg21Z88eJSQkOF933XWXbr31ViUkJCg8PLwkyy8zCvP3uUOHDjp48KAzPErS
Dz/8oLCwMIJNLgpznM+dO5ctwFwKlAaPXCwyHvseLNbpyiazatUqw2azGW+++aaxd+9eY/DgwUbFihWNpKQkwzAMo1+/fsaoUaOc/bdu3Wr4+PgYM2fONBITE40JEyZwKXgBuHucp0+fblitVmPNmjXGiRMnnK/09HRP7UKZ4O5xvhJXSxWMu8f56NGjRoUKFYyhQ4ca+/fvN95//32jatWqxvPPP++pXSgT3D3OEyZMMCpUqGCsXLnSOHTokPHf//7XqFu3rtGzZ09P7UKZkJ6ebuzatcvYtWuXIcl4+eWXjV27dhk//fSTYRiGMWrUKKNfv37O/pcuBf+///s/IzEx0Zg3bx6XgpdGc+bMMWrWrGlYrVajXbt2xpdffun8rFOnTkZsbKxL/7feesu4/vrrDavVajRp0sTYsGFDCVdcNrlznGvVqmVIyvaaMGFCyRdexrj79/lyhJuCc/c4f/HFF0ZkZKRhs9mMOnXqGC+88IJx4cKFEq667HHnOGdlZRkTJ0406tata/j5+Rnh4eHGE088Yfz+++8lX3gZ8sknn+T4+/bSsY2NjTU6deqUbZmWLVsaVqvVqFOnjrFkyZJir9NiGIy/AQAA82DODQAAMBXCDQAAMBXCDQAAMBXCDQAAMBXCDQAAMBXCDQAAMBXCDQAAMBXCDYBrnsVi0bp16zxdBoAiQrgB4FEDBgyQxWLJ9rr99ts9XRqAMoqnggPwuNtvv11LlixxabPZbB6qBkBZx8gNAI+z2WyqVq2ayys4OFjSxVNGCxYsUNeuXeXv7686depozZo1Lsvv2bNHnTt3lr+/vypXrqzBgwfrzJkzLn0WL16sJk2ayGazKSwsTEOHDnX5PCUlRffcc48CAgJUv359rV+/vnh3GkCxIdwAKPXGjRun++67T7t371bfvn314IMPKjExUZJ09uxZxcTEKDg4WF9//bXefvttffTRRy7hZcGCBRoyZIgGDx6sPXv2aP369apXr57LNiZNmqSePXvq22+/Vbdu3dS3b1+dOnWqRPcTQBEp9kdzAkAeYmNjDW9vb6NcuXIurxdeeMEwDMOQZDz22GMuy0RGRhqPP/64YRiGsXDhQiM4ONg4c+aM8/MNGzYYXl5eRlJSkmEYhlG9enVjzJgxudYgyRg7dqzz/ZkzZwxJxgcffFBk+wmg5DDnBoDH3XrrrVqwYIFLW6VKlZw/R0VFuXwWFRWlhIQESVJiYqJatGihcuXKOT/v0KGDHA6H9u/fL4vFouPHj+u2227Ls4bmzZs7fy5XrpwCAwN18uTJwu4SAA8i3ADwuHLlymU7TVRU/P39C9TP19fX5b3FYpHD4SiOkgAUM+bcACj1vvzyy2zvGzVqJElq1KiRdu/erbNnzzo/37p1q7y8vNSgQQNVqFBBERERio+PL9GaAXgOIzcAPC4jI0NJSUkubT4+PgoJCZEkvf3222rTpo1uuukmLV++XNu3b9eiRYskSX379tWECRMUGxuriRMn6tdff9WwYcPUr18/hYaGSpImTpyoxx57TFWrVlXXrl2Vnp6urVu3atiwYSW7owBKBOEGgMdt2rRJYWFhLm0NGjTQvn37JF28kmnVqlV64oknFBYWppUrV6px48aSpICAAH344Yd68skn1bZtWwUEBOi+++7Tyy+/7FxXbGyszp8/r1deeUUjRoxQSEiI7r///pLbQQAlymIYhuHpIgAgNxaLRWvXrlWPHj08XQqAMoI5NwAAwFQINwAAwFSYcwOgVOPMOQB3MXIDAABMhXADAABMhXADAABMhXADAABMhXADAABMhXADAABMhXADAABMhXADAABMhXADAABM5f8B+mmaGvF5smcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(train_a, label='train accuracy')\n",
    "plt.plot(val_a, label='validation accuracy')\n",
    "\n",
    "plt.title('Training history')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()\n",
    "plt.ylim([0, 1]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "Accuracy of Cat1 on Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 115057,
     "status": "ok",
     "timestamp": 1695498550868,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "8fc647f5-054b-4149-d279-4bef31d2bbb5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6194680962892883"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_acc, _ = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "test_acc.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d7DaFLMiAWps"
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  model = model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "\n",
    "      texts = d[\"text\"]\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "      cat3 = d[\"cat3\"].to(device)\n",
    "\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(texts)\n",
    "      predictions.extend(preds)\n",
    "\n",
    "      prediction_probs.extend(probs)\n",
    "      real_values.extend(cat1)\n",
    "\n",
    "\n",
    "  predictions = torch.stack(predictions).cpu()\n",
    "  prediction_probs = torch.stack(prediction_probs).cpu()\n",
    "  real_values = torch.stack(real_values).cpu()\n",
    "\n",
    "  return review, predictions, prediction_probs, real_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kZhYtki1AYx8"
   },
   "outputs": [],
   "source": [
    "y_review_texts, y_pred, y_pred_probs, y_test= get_predictions(\n",
    "  model,\n",
    "  test_data_loader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mO1FHn0GAbCU"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
    "Cat1 Classification Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 650,
     "status": "ok",
     "timestamp": 1693917674551,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "9540fd2f-e127-46dc-eea1-a5109958ddcf"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                      precision    recall  f1-score   support\n",
      "\n",
      "        pet supplies       0.90      0.68      0.77      1576\n",
      "              beauty       0.84      0.59      0.69      2027\n",
      "          toys games       0.69      0.83      0.75      1533\n",
      "grocery gourmet food       0.67      0.71      0.69       811\n",
      "health personal care       0.62      0.78      0.69      2936\n",
      "       baby products       0.59      0.48      0.53       630\n",
      "\n",
      "            accuracy                           0.70      9513\n",
      "           macro avg       0.72      0.68      0.69      9513\n",
      "        weighted avg       0.73      0.70      0.70      9513\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_test, y_pred, target_names=class_names))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
