{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 28131,
     "status": "ok",
     "timestamp": 1695497207856,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "import textwrap\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import T5EncoderModel, T5Config, T5Tokenizer\n",
    "\n",
    "\n",
    "\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 779,
     "status": "ok",
     "timestamp": 1695497208625,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9"
   },
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/train_40k.csv')\n",
    "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/val_10k.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 36,
     "status": "ok",
     "timestamp": 1695497208626,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "7VLaSHYyzg58"
   },
   "outputs": [],
   "source": [
    "df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1695497208626,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NqwYxgPz1I-L"
   },
   "outputs": [],
   "source": [
    "df_train = df_train[['Text','Cat1','Cat2','Cat3']]\n",
    "df_val = df_val[['Text','Cat1','Cat2','Cat3']]\n",
    "df_test = df_test[['Text','Cat1','Cat2','Cat3']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1695497208627,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "c4b05378-9996-4ec2-dbd0-adc07eb031fd"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24990</th>\n",
       "      <td>Watchmaker offered to install a new battery in...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>household supplies</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "24990  Watchmaker offered to install a new battery in...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "24990  health personal care                household supplies   \n",
       "\n",
       "                     Cat3  \n",
       "14307                toys  \n",
       "17812       styling tools  \n",
       "11020          miniatures  \n",
       "15158  live indoor plants  \n",
       "24990             unknown  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695497208628,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "ARIavVExPA5X"
   },
   "outputs": [],
   "source": [
    "df_test = df_test[df_test['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n",
    "df_train = df_train[df_train['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n",
    "df_val = df_val[df_val['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695497208630,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L5ofeF41UC3N",
    "outputId": "d38ddd54-7ab3-459e-f262-cf7366b4d755"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5980</th>\n",
       "      <td>HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>action toy figures</td>\n",
       "      <td>playsets</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>Stays on continuously without shutting off! It...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>health care</td>\n",
       "      <td>pain relievers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>these look great in our 10 gallon tank- colors...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>fish aquatic pets</td>\n",
       "      <td>aquarium d cor</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>This works great, but needs a better way to at...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>carriers travel products</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>she absolutely LOVES this thing. I dice up gre...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>I hurt my neck and went to rehab. They had me ...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>medical supplies equipment</td>\n",
       "      <td>occupational physical therapy aids</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>47251 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "5980   HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...   \n",
       "...                                                  ...   \n",
       "9995   Stays on continuously without shutting off! It...   \n",
       "9996   these look great in our 10 gallon tank- colors...   \n",
       "9997   This works great, but needs a better way to at...   \n",
       "9998   she absolutely LOVES this thing. I dice up gre...   \n",
       "9999   I hurt my neck and went to rehab. They had me ...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "5980             toys games                action toy figures   \n",
       "...                     ...                               ...   \n",
       "9995   health personal care                       health care   \n",
       "9996           pet supplies                 fish aquatic pets   \n",
       "9997           pet supplies                              dogs   \n",
       "9998           pet supplies                              dogs   \n",
       "9999   health personal care        medical supplies equipment   \n",
       "\n",
       "                                     Cat3  \n",
       "14307                                toys  \n",
       "17812                       styling tools  \n",
       "11020                          miniatures  \n",
       "15158                  live indoor plants  \n",
       "5980                             playsets  \n",
       "...                                   ...  \n",
       "9995                       pain relievers  \n",
       "9996                       aquarium d cor  \n",
       "9997             carriers travel products  \n",
       "9998                                 toys  \n",
       "9999   occupational physical therapy aids  \n",
       "\n",
       "[47251 rows x 4 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.concat([df_train, df_val, df_test], axis=0)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 404,
     "status": "ok",
     "timestamp": 1695497417800,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "enhZGAmPg4gp"
   },
   "outputs": [],
   "source": [
    "# Label Encode Cat1\n",
    "df['Cat1-map'], map = pd.factorize(df['Cat1'])\n",
    "cat1_map = dict(zip(map, range(len(map))))\n",
    "map_cat1 = {v: k for k, v in cat1_map.items()}\n",
    "\n",
    "df_train['Cat1'] = df_train[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_val['Cat1'] = df_val[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_test['Cat1'] = df_test[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 426,
     "status": "ok",
     "timestamp": 1695497434768,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AHCP1kIPgmSv"
   },
   "outputs": [],
   "source": [
    "PRE_TRAINED_MODEL_NAME = 't5-large'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1695497437148,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "xG4nYgPRrY2e"
   },
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "config = T5Config.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1695497437149,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz"
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "#device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
    "device = torch.device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1695497437582,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "a194ad10-d038-432b-d42a-ce9580d47159"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:1\n"
     ]
    }
   ],
   "source": [
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437583,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD"
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, texts, cats1, tokenizer, max_len):\n",
    "    self.texts = texts\n",
    "    self.cats1 = cats1\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.texts)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    text = str(self.texts[item])\n",
    "    cat1 = self.cats1[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      text,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'text': text,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'cat1': torch.tensor(cat1, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437584,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f"
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    texts=df.Text.to_numpy(),\n",
    "    cats1=df['Cat1'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695497437585,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj"
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 448,
     "status": "ok",
     "timestamp": 1695497438021,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "1522a9b5-73c6-4669-eb4d-2c71930b9a1a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['text', 'input_ids', 'attention_mask', 'cat1'])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = next(iter(train_data_loader))\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1695497438022,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF"
   },
   "outputs": [],
   "source": [
    "class IMDBClassifier(nn.Module):\n",
    "  def __init__(self, n_classes):\n",
    "    super(IMDBClassifier, self).__init__()\n",
    "    self.bert = T5EncoderModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "\n",
    "    self.FC = nn.Linear(config.hidden_size,6, bias=False)\n",
    "\n",
    "\n",
    "  def forward(self, input_ids, attention_mask):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    binary_out = self.FC(pooled_output)\n",
    "\n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 8266,
     "status": "ok",
     "timestamp": 1695497485319,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWZ37gsztWzL"
   },
   "outputs": [],
   "source": [
    "model = IMDBClassifier(6)\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695497485321,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM"
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695497486649,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux"
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695497487362,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "c35b8c5d-43f7-49a3-ff8d-a4172979d24e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1695497487363,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1"
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 3357,
     "status": "ok",
     "timestamp": 1695497500843,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0249,  0.0193,  0.0329, -0.1503,  0.1006,  0.0638],\n",
       "        [ 0.0591, -0.0151, -0.0414, -0.0412,  0.0110,  0.0158],\n",
       "        [ 0.0961,  0.0309,  0.0557, -0.1691,  0.0725,  0.0697],\n",
       "        [ 0.0705,  0.0090,  0.0810, -0.2398,  0.1624,  0.0076],\n",
       "        [ 0.0837, -0.0442,  0.0184, -0.1942,  0.1155, -0.0371],\n",
       "        [ 0.0668,  0.0120,  0.0150, -0.1447,  0.0347, -0.0112],\n",
       "        [ 0.0523, -0.0279, -0.0150, -0.0679,  0.0882,  0.0108],\n",
       "        [ 0.1011,  0.0453,  0.1575, -0.1325,  0.1529,  0.0866],\n",
       "        [ 0.1294,  0.0150,  0.1329, -0.1760,  0.1204,  0.1360],\n",
       "        [ 0.0766,  0.0692,  0.1277, -0.1734,  0.1420,  0.0151],\n",
       "        [ 0.0326,  0.0131, -0.0589, -0.1034,  0.0782,  0.0776],\n",
       "        [ 0.1268,  0.0259,  0.0868, -0.1840,  0.1307,  0.0224],\n",
       "        [ 0.0893,  0.0555, -0.0041, -0.1224,  0.1298,  0.0171],\n",
       "        [ 0.0446,  0.0081,  0.0749, -0.2381,  0.0906,  0.0702],\n",
       "        [ 0.0317,  0.0408,  0.0046, -0.0554,  0.0434,  0.0122],\n",
       "        [ 0.0358,  0.0134, -0.0512,  0.0043,  0.0241,  0.0210],\n",
       "        [ 0.0583, -0.0089, -0.0068, -0.0853,  0.1082,  0.0389],\n",
       "        [ 0.0759, -0.0328,  0.0651, -0.1620,  0.0590,  0.0658],\n",
       "        [-0.0077,  0.0812,  0.0652, -0.1031,  0.0901,  0.0278],\n",
       "        [ 0.0228,  0.0152, -0.0545, -0.0443,  0.0267,  0.0286],\n",
       "        [ 0.0522,  0.0375,  0.0908, -0.1739,  0.0450,  0.0784],\n",
       "        [ 0.1086,  0.0462,  0.0797, -0.2031,  0.0995, -0.0021],\n",
       "        [ 0.0474,  0.0192, -0.0433, -0.0625,  0.0304,  0.0318],\n",
       "        [ 0.0762,  0.0383,  0.0703, -0.1873,  0.0851,  0.0814],\n",
       "        [ 0.0781, -0.0208, -0.0086, -0.0664,  0.0565,  0.0793],\n",
       "        [ 0.0121,  0.0139,  0.0960, -0.1316,  0.1213,  0.0432],\n",
       "        [ 0.0217, -0.0138, -0.0892, -0.0030, -0.0064, -0.0129],\n",
       "        [ 0.0700,  0.0057,  0.0105, -0.1161,  0.0929,  0.0273],\n",
       "        [ 0.1121,  0.0655, -0.0164, -0.1368,  0.1089,  0.0046],\n",
       "        [ 0.0587,  0.0482,  0.0198, -0.1236,  0.0701,  0.0402],\n",
       "        [ 0.1243,  0.0604,  0.0567, -0.1441,  0.1275,  0.0356],\n",
       "        [ 0.0267,  0.0395,  0.0019, -0.1475,  0.0866, -0.0572]],\n",
       "       device='cuda:1', grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outs = model(input_ids, attention_mask)\n",
    "outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 410,
     "status": "ok",
     "timestamp": 1695498559469,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf"
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss().to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1695497553390,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W"
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, cat1)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == cat1)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1695497553909,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue"
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, cat1)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == cat1)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3729011,
     "status": "ok",
     "timestamp": 1695502293865,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "883126c6-e286-4c7a-ab04-60c2908e93dc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 1.6205975795727645 accuracy 0.3933969769291965\n",
      "Val   loss 1.400238669874296 accuracy 0.5027741083223249\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 1.2036537059038592 accuracy 0.622149297268629\n",
      "Val   loss 1.0362986182864709 accuracy 0.6672391017173052\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 0.9555213156838686 accuracy 0.7087311058074781\n",
      "Val   loss 0.9022605610296193 accuracy 0.7083223249669749\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 0.8553834918447921 accuracy 0.7347852028639619\n",
      "Val   loss 0.8388205477457006 accuracy 0.7256274768824307\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 0.8022850884928556 accuracy 0.7478122513922036\n",
      "Val   loss 0.8067228704816681 accuracy 0.7334214002642008\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 0.7706339136793151 accuracy 0.7565632458233891\n",
      "Val   loss 0.7877737359155582 accuracy 0.7373844121532365\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 0.7567904820371407 accuracy 0.758419517369398\n",
      "Val   loss 0.7781156912634645 accuracy 0.7393659180977543\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 0.7470274420014621 accuracy 0.7597454256165473\n",
      "Val   loss 0.7744350028440419 accuracy 0.7409511228533685\n",
      "\n",
      "CPU times: user 1h 1min 32s, sys: 15.6 s, total: 1h 1min 48s\n",
      "Wall time: 1h 1min 56s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'baseline_t5_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "id": "FowMSU5U7SDQ"
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 472
    },
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1693911460382,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "c5781d52-da67-4451-9200-f67ac20b6c33"
   },
   "outputs": [],
   "source": [
    "plt.plot(train_a, label='train accuracy')\n",
    "plt.plot(val_a, label='validation accuracy')\n",
    "\n",
    "plt.title('Training history')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()\n",
    "plt.ylim([0, 1]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
    "Accuracy of Cat1 on Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 115057,
     "status": "ok",
     "timestamp": 1695498550868,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "8fc647f5-054b-4149-d279-4bef31d2bbb5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6867444549563755"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_acc, _ = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "test_acc.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "d7DaFLMiAWps"
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  model = model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "\n",
    "      texts = d[\"text\"]\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(texts)\n",
    "      predictions.extend(preds)\n",
    "\n",
    "      prediction_probs.extend(probs)\n",
    "      real_values.extend(cat1)\n",
    "\n",
    "\n",
    "  predictions = torch.stack(predictions).cpu()\n",
    "  prediction_probs = torch.stack(prediction_probs).cpu()\n",
    "  real_values = torch.stack(real_values).cpu()\n",
    "\n",
    "  return review, predictions, prediction_probs, real_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "id": "kZhYtki1AYx8"
   },
   "outputs": [],
   "source": [
    "y_review_texts, y_pred, y_pred_probs, y_test= get_predictions(\n",
    "  model,\n",
    "  test_data_loader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "id": "mO1FHn0GAbCU"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
    "Cat1 Classification Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 650,
     "status": "ok",
     "timestamp": 1693917674551,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "9540fd2f-e127-46dc-eea1-a5109958ddcf",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                      precision    recall  f1-score   support\n",
      "\n",
      "        pet supplies       0.79      0.62      0.69      1576\n",
      "              beauty       0.77      0.65      0.71      2027\n",
      "          toys games       0.66      0.83      0.73      1533\n",
      "grocery gourmet food       0.67      0.65      0.66       811\n",
      "health personal care       0.67      0.72      0.70      2936\n",
      "       baby products       0.45      0.51      0.48       630\n",
      "\n",
      "            accuracy                           0.69      9513\n",
      "           macro avg       0.67      0.66      0.66      9513\n",
      "        weighted avg       0.70      0.69      0.69      9513\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class_names = list(cat1_map.keys())\n",
    "print(classification_report(y_test, y_pred, target_names=class_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
