{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 28131,
     "status": "ok",
     "timestamp": 1695497207856,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "import textwrap\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import RobertaModel, RobertaConfig,RobertaTokenizer\n",
    "\n",
    "\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 779,
     "status": "ok",
     "timestamp": 1695497208625,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/train_40k.csv')\n",
    "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/val_10k.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 36,
     "status": "ok",
     "timestamp": 1695497208626,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "7VLaSHYyzg58",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1695497208626,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NqwYxgPz1I-L",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = df_train[['Text','Cat1','Cat2','Cat3']]\n",
    "df_val = df_val[['Text','Cat1','Cat2','Cat3']]\n",
    "df_test = df_test[['Text','Cat1','Cat2','Cat3']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1695497208627,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "c4b05378-9996-4ec2-dbd0-adc07eb031fd",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24990</th>\n",
       "      <td>Watchmaker offered to install a new battery in...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>household supplies</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "24990  Watchmaker offered to install a new battery in...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "24990  health personal care                household supplies   \n",
       "\n",
       "                     Cat3  \n",
       "14307                toys  \n",
       "17812       styling tools  \n",
       "11020          miniatures  \n",
       "15158  live indoor plants  \n",
       "24990             unknown  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695497208628,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "ARIavVExPA5X",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_test = df_test[df_test['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n",
    "df_train = df_train[df_train['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n",
    "df_val = df_val[df_val['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695497208630,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L5ofeF41UC3N",
    "outputId": "d38ddd54-7ab3-459e-f262-cf7366b4d755",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5980</th>\n",
       "      <td>HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>action toy figures</td>\n",
       "      <td>playsets</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>Stays on continuously without shutting off! It...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>health care</td>\n",
       "      <td>pain relievers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>these look great in our 10 gallon tank- colors...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>fish aquatic pets</td>\n",
       "      <td>aquarium d cor</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>This works great, but needs a better way to at...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>carriers travel products</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>she absolutely LOVES this thing. I dice up gre...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>I hurt my neck and went to rehab. They had me ...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>medical supplies equipment</td>\n",
       "      <td>occupational physical therapy aids</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>47251 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "5980   HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...   \n",
       "...                                                  ...   \n",
       "9995   Stays on continuously without shutting off! It...   \n",
       "9996   these look great in our 10 gallon tank- colors...   \n",
       "9997   This works great, but needs a better way to at...   \n",
       "9998   she absolutely LOVES this thing. I dice up gre...   \n",
       "9999   I hurt my neck and went to rehab. They had me ...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "5980             toys games                action toy figures   \n",
       "...                     ...                               ...   \n",
       "9995   health personal care                       health care   \n",
       "9996           pet supplies                 fish aquatic pets   \n",
       "9997           pet supplies                              dogs   \n",
       "9998           pet supplies                              dogs   \n",
       "9999   health personal care        medical supplies equipment   \n",
       "\n",
       "                                     Cat3  \n",
       "14307                                toys  \n",
       "17812                       styling tools  \n",
       "11020                          miniatures  \n",
       "15158                  live indoor plants  \n",
       "5980                             playsets  \n",
       "...                                   ...  \n",
       "9995                       pain relievers  \n",
       "9996                       aquarium d cor  \n",
       "9997             carriers travel products  \n",
       "9998                                 toys  \n",
       "9999   occupational physical therapy aids  \n",
       "\n",
       "[47251 rows x 4 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.concat([df_train, df_val, df_test], axis=0)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 404,
     "status": "ok",
     "timestamp": 1695497417800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "enhZGAmPg4gp",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Label Encode Cat1\n",
    "df['Cat1-map'], map = pd.factorize(df['Cat1'])\n",
    "cat1_map = dict(zip(map, range(len(map))))\n",
    "map_cat1 = {v: k for k, v in cat1_map.items()}\n",
    "\n",
    "df_train['Cat1'] = df_train[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_val['Cat1'] = df_val[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_test['Cat1'] = df_test[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 426,
     "status": "ok",
     "timestamp": 1695497434768,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AHCP1kIPgmSv",
    "tags": []
   },
   "outputs": [],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 'roberta-large'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695497437148,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "xG4nYgPRrY2e",
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer = RobertaTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "config = RobertaConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1695497437149,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "#device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
    "device = torch.device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1695497437582,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "a194ad10-d038-432b-d42a-ce9580d47159",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437583,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, texts, cats1, tokenizer, max_len):\n",
    "    self.texts = texts\n",
    "    self.cats1 = cats1\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.texts)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    text = str(self.texts[item])\n",
    "    cat1 = self.cats1[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      text,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'text': text,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'cat1': torch.tensor(cat1, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437584,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    texts=df.Text.to_numpy(),\n",
    "    cats1=df['Cat1'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437585,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 448,
     "status": "ok",
     "timestamp": 1695497438021,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "1522a9b5-73c6-4669-eb4d-2c71930b9a1a",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['text', 'input_ids', 'attention_mask', 'cat1'])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = next(iter(train_data_loader))\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1695497438022,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBClassifier(nn.Module):\n",
    "  def __init__(self, n_classes):\n",
    "    super(IMDBClassifier, self).__init__()\n",
    "    self.bert = RobertaModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "\n",
    "    self.FC = nn.Linear(config.hidden_size,6, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "\n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 8266,
     "status": "ok",
     "timestamp": 1695497485319,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWZ37gsztWzL",
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = IMDBClassifier(6)\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695497485321,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695497486649,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695497487362,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c35b8c5d-43f7-49a3-ff8d-a4172979d24e",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695497487363,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 3357,
     "status": "ok",
     "timestamp": 1695497500843,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs = model(input_ids, attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 410,
     "status": "ok",
     "timestamp": 1695498559469,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf",
    "tags": []
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1695497553390,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, cat1)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == cat1)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1695497553909,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, cat1)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == cat1)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3729011,
     "status": "ok",
     "timestamp": 1695502293865,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "883126c6-e286-4c7a-ab04-60c2908e93dc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 1.2588807522049639 accuracy 0.5660965261203925\n",
      "Val   loss 0.7849596917377746 accuracy 0.7799207397622192\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 0.6082705444521009 accuracy 0.8059533280297004\n",
      "Val   loss 0.5760448713342852 accuracy 0.8138705416116248\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.5146783138699395 accuracy 0.8223614425881729\n",
      "Val   loss 0.5351318836463654 accuracy 0.8225891677675032\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.4842343623823983 accuracy 0.8314439140811456\n",
      "Val   loss 0.5197842014867042 accuracy 0.8264200792602377\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.4755009430463595 accuracy 0.8335322195704058\n",
      "Val   loss 0.5091119658217651 accuracy 0.8282694848084544\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.46548849053499175 accuracy 0.8367806947759215\n",
      "Val   loss 0.5033322947437753 accuracy 0.8278731836195509\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.4612147368941049 accuracy 0.8396645452134712\n",
      "Val   loss 0.5004734313311959 accuracy 0.8294583883751651\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.4565511149762165 accuracy 0.8399297268629011\n",
      "Val   loss 0.499417490156894 accuracy 0.8309114927344782\n",
      "\n",
      "CPU times: user 3h 58min 59s, sys: 24.4 s, total: 3h 59min 24s\n",
      "Wall time: 3h 59min 25s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_roberta_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "id": "FowMSU5U7SDQ"
   },
   "outputs": [],
   "source": [
    "# Turn the per-epoch metrics gathered by the training loop into plain Python floats.\n",
    "# float() is used instead of `.item()` because it also accepts values that are\n",
    "# already floats, so re-running this cell no longer raises AttributeError.\n",
    "train_a = [float(i) for i in train_a]\n",
    "train_l = [float(i) for i in train_l]\n",
    "val_a = [float(i) for i in val_a]\n",
    "val_l = [float(i) for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 472
    },
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1693911460382,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "c5781d52-da67-4451-9200-f67ac20b6c33"
   },
   "outputs": [],
   "source": [
    "# Accuracy per epoch for the training and validation splits, on a fixed 0-1 scale.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(train_a, label='train accuracy')\n",
    "ax.plot(val_a, label='validation accuracy')\n",
    "\n",
    "ax.set_title('Training history')\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.legend()\n",
    "# The trailing semicolon keeps the return value out of the cell output.\n",
    "ax.set_ylim([0, 1]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "Accuracy of Cat1 on Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 115057,
     "status": "ok",
     "timestamp": 1695498550868,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "8fc647f5-054b-4149-d279-4bef31d2bbb5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.783349101229896"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the model on the test loader; the second return value (the loss) is not used.\n",
    "test_acc, _ = eval_model(model, test_data_loader, loss_fn, device, len(df_test))\n",
    "\n",
    "# Display the accuracy as a plain Python number.\n",
    "test_acc.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "d7DaFLMiAWps"
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  \"\"\"Run `model` over every batch of `data_loader` and collect the results.\n",
    "\n",
    "  Each batch is indexed with the keys \"text\", \"input_ids\", \"attention_mask\"\n",
    "  and \"cat1\"; the model inputs are moved to the notebook-level `device`.\n",
    "\n",
    "  Returns a 4-tuple, all in loader order:\n",
    "    review           -- list built from the \"text\" entry of each batch\n",
    "    predictions      -- CPU tensor, index of the largest model output per row\n",
    "    prediction_probs -- CPU tensor, softmax of the model outputs over dim 1\n",
    "    real_values      -- CPU tensor holding the \"cat1\" entries\n",
    "  \"\"\"\n",
    "  model.eval()\n",
    "\n",
    "  review = []\n",
    "  pred_batches = []\n",
    "  prob_batches = []\n",
    "  label_batches = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(d[\"text\"])\n",
    "\n",
    "      # Keep one tensor per batch, already on the CPU, instead of one tiny\n",
    "      # device tensor per example: far fewer Python objects, and the results\n",
    "      # do not pile up in accelerator memory until the end of the loop.\n",
    "      pred_batches.append(preds.cpu())\n",
    "      prob_batches.append(probs.cpu())\n",
    "      # The labels are only returned, never fed to the model, so they are\n",
    "      # not sent to `device` at all.\n",
    "      label_batches.append(d[\"cat1\"].cpu())\n",
    "\n",
    "  # Joining the batches along dim 0 gives the same tensors as stacking the\n",
    "  # individual rows did.\n",
    "  predictions = torch.cat(pred_batches)\n",
    "  prediction_probs = torch.cat(prob_batches)\n",
    "  real_values = torch.cat(label_batches)\n",
    "\n",
    "  return review, predictions, prediction_probs, real_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "id": "kZhYtki1AYx8"
   },
   "outputs": [],
   "source": [
    "# Collect texts, predicted classes, class probabilities and labels from the test loader.\n",
    "y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(model, test_data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "id": "mO1FHn0GAbCU"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
    "Cat1 Classification Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 650,
     "status": "ok",
     "timestamp": 1693917674551,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "9540fd2f-e127-46dc-eea1-a5109958ddcf",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                      precision    recall  f1-score   support\n",
      "\n",
      "        pet supplies       0.92      0.80      0.85      1576\n",
      "              beauty       0.83      0.73      0.78      2027\n",
      "          toys games       0.80      0.84      0.82      1533\n",
      "grocery gourmet food       0.77      0.75      0.76       811\n",
      "health personal care       0.72      0.81      0.76      2936\n",
      "       baby products       0.66      0.69      0.68       630\n",
      "\n",
      "            accuracy                           0.78      9513\n",
      "           macro avg       0.78      0.77      0.78      9513\n",
      "        weighted avg       0.79      0.78      0.78      9513\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): target_names are paired with the labels by position, so this\n",
    "# presumably relies on cat1_map's key order following the integer class ids -\n",
    "# confirm where cat1_map is built.\n",
    "class_names = [name for name in cat1_map.keys()]\n",
    "print(\n",
    "  classification_report(\n",
    "    y_test,\n",
    "    y_pred,\n",
    "    target_names=class_names,\n",
    "  )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
