{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 29034,
     "status": "ok",
     "timestamp": 1695646934854,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "F7toc08bpdQ1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "import textwrap\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from transformers import RobertaModel, RobertaConfig, RobertaTokenizer\n",
    "\n",
    "\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 622,
     "status": "ok",
     "timestamp": 1695647243491,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "MLk4JWizs4S9",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Load the Amazon review data. NOTE: val_10k.csv is used as the held-out *test* split;\n",
     "# a validation split is carved out of train_40k in the next cell.\n",
     "# NOTE(review): absolute, machine-specific paths - consider a configurable DATA_DIR.\n",
     "df_train = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/train_40k.csv')\n",
     "df_test = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/val_10k.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 19,
     "status": "ok",
     "timestamp": 1695647244030,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "7VLaSHYyzg58",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Hold out 20% of the training rows as a validation split (fixed random_state).\n",
     "# NOTE: not idempotent - re-running this cell splits the already-reduced df_train again;\n",
     "# re-run the loading cell first.\n",
     "df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695647244031,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NqwYxgPz1I-L",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_train = df_train[['Text','Cat1','Cat2','Cat3']]\n",
    "df_val = df_val[['Text','Cat1','Cat2','Cat3']]\n",
    "df_test = df_test[['Text','Cat1','Cat2','Cat3']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1695647244826,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e0NG9J-oq_wW",
    "outputId": "96182a88-b8f5-4426-df01-4e5375709d99",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24990</th>\n",
       "      <td>Watchmaker offered to install a new battery in...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>household supplies</td>\n",
       "      <td>unknown</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "24990  Watchmaker offered to install a new battery in...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "24990  health personal care                household supplies   \n",
       "\n",
       "                     Cat3  \n",
       "14307                toys  \n",
       "17812       styling tools  \n",
       "11020          miniatures  \n",
       "15158  live indoor plants  \n",
       "24990             unknown  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Peek at a few training rows (category labels are still strings here).\n",
     "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695647245529,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "ARIavVExPA5X",
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_test = df_test[df_test['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n",
    "df_train = df_train[df_train['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n",
    "df_val = df_val[df_val['Cat3'] != 'unknown']   #Dropping \"unknown\" rows\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1695647245530,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L5ofeF41UC3N",
    "outputId": "349fbefa-9f9c-46db-bed2-dabeebf6ba6a",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>beauty</td>\n",
       "      <td>hair care</td>\n",
       "      <td>styling tools</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>novelty gag toys</td>\n",
       "      <td>miniatures</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>grocery gourmet food</td>\n",
       "      <td>fresh flowers live indoor plants</td>\n",
       "      <td>live indoor plants</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5980</th>\n",
       "      <td>HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...</td>\n",
       "      <td>toys games</td>\n",
       "      <td>action toy figures</td>\n",
       "      <td>playsets</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>Stays on continuously without shutting off! It...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>health care</td>\n",
       "      <td>pain relievers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>these look great in our 10 gallon tank- colors...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>fish aquatic pets</td>\n",
       "      <td>aquarium d cor</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>This works great, but needs a better way to at...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>carriers travel products</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>she absolutely LOVES this thing. I dice up gre...</td>\n",
       "      <td>pet supplies</td>\n",
       "      <td>dogs</td>\n",
       "      <td>toys</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>I hurt my neck and went to rehab. They had me ...</td>\n",
       "      <td>health personal care</td>\n",
       "      <td>medical supplies equipment</td>\n",
       "      <td>occupational physical therapy aids</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>47251 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  \\\n",
       "14307  The concept of this toy is good. However, if y...   \n",
       "17812  This dryer ruined my hair!!! At first, after I...   \n",
       "11020  Much to my surprise after a year of waiting th...   \n",
       "15158  The tree is beautiful but upon arrival when I ...   \n",
       "5980   HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...   \n",
       "...                                                  ...   \n",
       "9995   Stays on continuously without shutting off! It...   \n",
       "9996   these look great in our 10 gallon tank- colors...   \n",
       "9997   This works great, but needs a better way to at...   \n",
       "9998   she absolutely LOVES this thing. I dice up gre...   \n",
       "9999   I hurt my neck and went to rehab. They had me ...   \n",
       "\n",
       "                       Cat1                              Cat2  \\\n",
       "14307          pet supplies                              dogs   \n",
       "17812                beauty                         hair care   \n",
       "11020            toys games                  novelty gag toys   \n",
       "15158  grocery gourmet food  fresh flowers live indoor plants   \n",
       "5980             toys games                action toy figures   \n",
       "...                     ...                               ...   \n",
       "9995   health personal care                       health care   \n",
       "9996           pet supplies                 fish aquatic pets   \n",
       "9997           pet supplies                              dogs   \n",
       "9998           pet supplies                              dogs   \n",
       "9999   health personal care        medical supplies equipment   \n",
       "\n",
       "                                     Cat3  \n",
       "14307                                toys  \n",
       "17812                       styling tools  \n",
       "11020                          miniatures  \n",
       "15158                  live indoor plants  \n",
       "5980                             playsets  \n",
       "...                                   ...  \n",
       "9995                       pain relievers  \n",
       "9996                       aquarium d cor  \n",
       "9997             carriers travel products  \n",
       "9998                                 toys  \n",
       "9999   occupational physical therapy aids  \n",
       "\n",
       "[47251 rows x 4 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Union of all filtered splits; used below to build label encodings shared by every split.\n",
     "df = pd.concat([df_train, df_val, df_test], axis=0)\n",
     "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695647249805,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "LYX4e7S_HPUl",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Label Encode Cat3\n",
    "df['Cat3-map'], map = pd.factorize(df['Cat3'])\n",
    "cat3_map = dict(zip(map, range(len(map))))\n",
    "map_cat3 = {v: k for k, v in cat3_map.items()}\n",
    "\n",
    "df_train['Cat3'] = df_train[\"Cat3\"].apply(lambda x: cat3_map[x])\n",
    "df_val['Cat3'] = df_val[\"Cat3\"].apply(lambda x: cat3_map[x])\n",
    "df_test['Cat3'] = df_test[\"Cat3\"].apply(lambda x: cat3_map[x])\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695647250453,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "ncW9QxKlVJ1Q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Label Encode Cat2\n",
    "df['Cat2-map'], map = pd.factorize(df['Cat2'])\n",
    "cat2_map = dict(zip(map, range(len(map))))\n",
    "map_cat2 = {v: k for k, v in cat2_map.items()}\n",
    "\n",
    "df_train['Cat2'] = df_train[\"Cat2\"].apply(lambda x: cat2_map[x])\n",
    "df_val['Cat2'] = df_val[\"Cat2\"].apply(lambda x: cat2_map[x])\n",
    "df_test['Cat2'] = df_test[\"Cat2\"].apply(lambda x: cat2_map[x])\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695647250454,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "5aI7RXTXWdaf",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Label Encode Cat1\n",
    "df['Cat1-map'], map = pd.factorize(df['Cat1'])\n",
    "cat1_map = dict(zip(map, range(len(map))))\n",
    "map_cat1 = {v: k for k, v in cat1_map.items()}\n",
    "\n",
    "df_train['Cat1'] = df_train[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_val['Cat1'] = df_val[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "df_test['Cat1'] = df_test[\"Cat1\"].apply(lambda x: cat1_map[x])\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1695647252282,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6tG-vU5kVhiT",
    "outputId": "83ea1cce-cf77-4cf8-9c74-b3682363fb04",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Cat1</th>\n",
       "      <th>Cat2</th>\n",
       "      <th>Cat3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14307</th>\n",
       "      <td>The concept of this toy is good. However, if y...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17812</th>\n",
       "      <td>This dryer ruined my hair!!! At first, after I...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11020</th>\n",
       "      <td>Much to my surprise after a year of waiting th...</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15158</th>\n",
       "      <td>The tree is beautiful but upon arrival when I ...</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5980</th>\n",
       "      <td>HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6265</th>\n",
       "      <td>This pump worked for about a week. It kept get...</td>\n",
       "      <td>5</td>\n",
       "      <td>9</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11284</th>\n",
       "      <td>The only problem I have with this product is t...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38158</th>\n",
       "      <td>I've had this for only a couple of weeks, but ...</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>860</th>\n",
       "      <td>My son loved this toy bar on his stroller. The...</td>\n",
       "      <td>2</td>\n",
       "      <td>17</td>\n",
       "      <td>255</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15795</th>\n",
       "      <td>I used to use only Johnson &amp; Johnson Easy Acce...</td>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>47</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>30168 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    Text  Cat1  Cat2  Cat3\n",
       "14307  The concept of this toy is good. However, if y...     0     0     0\n",
       "17812  This dryer ruined my hair!!! At first, after I...     1     1     1\n",
       "11020  Much to my surprise after a year of waiting th...     2     2     2\n",
       "15158  The tree is beautiful but upon arrival when I ...     3     3     3\n",
       "5980   HI MY NAME IS SHARON AND I JUST LOVE IT!!!!!!!...     2     4     4\n",
       "...                                                  ...   ...   ...   ...\n",
       "6265   This pump worked for about a week. It kept get...     5     9    10\n",
       "11284  The only problem I have with this product is t...     0     0   112\n",
       "38158  I've had this for only a couple of weeks, but ...     0     5     6\n",
       "860    My son loved this toy bar on his stroller. The...     2    17   255\n",
       "15795  I used to use only Johnson & Johnson Easy Acce...     4     6    47\n",
       "\n",
       "[30168 rows x 4 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Training split after label encoding: Cat1/Cat2/Cat3 now hold integer ids.\n",
     "df_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695647252830,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "qt56Ke3MrXDs",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Hugging Face checkpoint shared by the tokenizer, config and encoder loaded below.\n",
     "PRE_TRAINED_MODEL_NAME = 'roberta-large'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 978,
     "status": "ok",
     "timestamp": 1695647254371,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "xG4nYgPRrY2e",
    "tags": []
   },
   "outputs": [],
   "source": [
     "tokenizer = RobertaTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)\n",
     "config = RobertaConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)  # config.hidden_size is used by make_emb\n",
     "clear_output()  # presumably hides download/log output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1695647254372,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UmEtmaFNrakz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "MAX_LEN = 200\n",
    "RANDOM_SEED = 42\n",
    "#device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
    "device = torch.device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695647255297,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "6oWORty0p8Xo",
    "outputId": "c092d6c4-8082-4524-ad6b-c6d71e4a345d",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
     "print(device)  # sanity check: device used for the encoder and input tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 625,
     "referenced_widgets": [
      "371030e0df764b32b41dece67af8774a",
      "7a097b6b084f4fcdb80162b2eaf0318a",
      "f949ddf7ee5748328238adcb4e6be853",
      "8d1b4a1ce3f44d698c1d7873917091b5",
      "1c904dd6194b45ad80b5350554be7783",
      "fee8eb44bf6a452ab5b83eb74c76af41",
      "bf31fedc33724b2f98b8679ddc4f1f03",
      "bbd49f0e4a91459fa34d6d5cef8771df",
      "02c5d9c00c8146319f3bdbd2efb850e9",
      "bb66847397264ee3989babdc4ce0fa19",
      "26c2b2a6356f4c20828c4aa654b06187"
     ]
    },
    "executionInfo": {
     "elapsed": 4342,
     "status": "ok",
     "timestamp": 1695647260773,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "yDSL15DQGvrh",
    "outputId": "b3eac964-a1e6-4dd6-f86b-3277ef89cfc9",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Encoder used by get_cls() (which switches it to eval mode) to embed label definitions.\n",
     "pre_trained_model = RobertaModel.from_pretrained(PRE_TRAINED_MODEL_NAME).to(device)\n",
     "clear_output()  # presumably hides download/log output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1695647260774,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "R2PbApHJGstz",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_cls(sent):\n",
    "  encoded_review = tokenizer.encode_plus(\n",
    "  sent,\n",
    "  max_length=MAX_LEN,\n",
    "  add_special_tokens=True,\n",
    "  return_token_type_ids=False,\n",
    "  truncation = True,\n",
    "  padding='max_length',\n",
    "  return_attention_mask=True,\n",
    "  return_tensors='pt',\n",
    "  )\n",
    "  input_ids = encoded_review['input_ids'].to(device)\n",
    "  attention_mask = encoded_review['attention_mask'].to(device)\n",
    "\n",
    "  pre_trained_model.eval()\n",
    "  with torch.no_grad():\n",
    "    output = pre_trained_model(input_ids, attention_mask)\n",
    "    return torch.mean(output[0], dim=1)[0] #Output Avg pooled vector embedding of last layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695647546534,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "DscmTjDNXY97",
    "tags": []
   },
   "outputs": [],
   "source": [
     "# Per-Cat3 definitions (columns 'Words', 'Definitions', 'Paraphrased Definitions').\n",
     "# 'Words' is converted from label strings to Cat3 ids so it can be matched against 'Cat3-map'.\n",
     "# NOTE(review): raises KeyError if cat3.csv lists a label absent from the data (e.g. 'unknown').\n",
     "df_cat3 = pd.read_csv('/home/m_nsu/ICLR/Datasets/Amazon/cat3.csv')\n",
     "df_cat3['Words'] = df_cat3[\"Words\"].apply(lambda x: cat3_map[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 27,
     "status": "ok",
     "timestamp": 1695647547233,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AyxIm5FPhdoC",
    "tags": []
   },
   "outputs": [],
   "source": [
    "label_cop_def = []\n",
    "for i in map_cat1.keys():\n",
    "  label_cop_def.append((i,list(df.loc[df['Cat1-map'] == i]['Cat3-map'].unique())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 26,
     "status": "ok",
     "timestamp": 1695647547237,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "XSsqe5yQkNlJ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_cop_def(a_list):\n",
    "  def_list = []\n",
    "  for i in a_list:\n",
    "    def_list.append(df_cat3[df_cat3['Words']==i]['Definitions'].item())\n",
    "  return def_list\n",
    "\n",
    "def get_cop_para_def(a_list):\n",
    "  def_list = []\n",
    "  for i in a_list:\n",
    "    def_list.append(df_cat3[df_cat3['Words']==i]['Paraphrased Definitions'].item().strip())\n",
    "  return def_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "executionInfo": {
     "elapsed": 711,
     "status": "ok",
     "timestamp": 1695647548484,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "crn5ud4miTob",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for i in range(len(label_cop_def)):\n",
    "  def_list = get_cop_def(label_cop_def[i][1])\n",
    "  label_cop_def[i] = (*label_cop_def[i], def_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1695647548485,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "vNvqE1l4dQl_",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for i in range(len(label_cop_def)):\n",
    "  def_list = get_cop_para_def(label_cop_def[i][1])\n",
    "  label_cop_def[i] = (*label_cop_def[i], def_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 11779,
     "status": "ok",
     "timestamp": 1695647580641,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "LNSmtyIyoOUv",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def make_emb(definition_list):\n",
    "  emb = torch.tensor([[1.]*config.hidden_size]).to(device)\n",
    "  for i in definition_list:\n",
    "    emb = torch.cat((emb,get_cls(i).unsqueeze(dim=0)), dim=0)\n",
    "  return emb[1:]\n",
    "\n",
    "\n",
    "zero_emb = make_emb(label_cop_def[0][2])\n",
    "one_emb = make_emb(label_cop_def[1][2])\n",
    "two_emb = make_emb(label_cop_def[2][2])\n",
    "three_emb = make_emb(label_cop_def[3][2])\n",
    "four_emb = make_emb(label_cop_def[4][2])\n",
    "five_emb = make_emb(label_cop_def[5][2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 32,
     "status": "ok",
     "timestamp": 1695647580642,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "V_qXeXCGgjtq",
    "outputId": "024bd15b-dc54-4763-98e1-13258644e52e",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([40, 1024])\n",
      "torch.Size([36, 1024])\n",
      "torch.Size([139, 1024])\n",
      "torch.Size([136, 1024])\n",
      "torch.Size([57, 1024])\n",
      "torch.Size([67, 1024])\n"
     ]
    }
   ],
   "source": [
    "print(zero_emb.shape)\n",
    "print(one_emb.shape)\n",
    "print(two_emb.shape)\n",
    "print(three_emb.shape)\n",
    "print(four_emb.shape)\n",
    "print(five_emb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "executionInfo": {
     "elapsed": 23,
     "status": "ok",
     "timestamp": 1695647580643,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "OtZt1p7ys7XD",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBDataset(Dataset):\n",
    "\n",
    "  def __init__(self, texts, cats1, cats3, tokenizer, max_len):\n",
    "    self.texts = texts\n",
    "    self.cats1 = cats1\n",
    "    self.cats3 = cats3\n",
    "    self.tokenizer = tokenizer\n",
    "    self.max_len = max_len\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.texts)\n",
    "\n",
    "  def __getitem__(self, item):\n",
    "    text = str(self.texts[item])\n",
    "    cat1 = self.cats1[item]\n",
    "    cat3 = self.cats3[item]\n",
    "\n",
    "\n",
    "    encoding = self.tokenizer.encode_plus(\n",
    "      text,\n",
    "      add_special_tokens=True,\n",
    "      max_length=self.max_len,\n",
    "      return_token_type_ids=False,\n",
    "      padding='max_length',\n",
    "      truncation = True,\n",
    "      return_attention_mask=True,\n",
    "      return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    return {\n",
    "      'text': text,\n",
    "      'input_ids': encoding['input_ids'].flatten(),\n",
    "      'attention_mask': encoding['attention_mask'].flatten(),\n",
    "      'cat1': torch.tensor(cat1, dtype=torch.long),\n",
    "      'cat3': torch.tensor(cat3, dtype=torch.long),\n",
    "\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695647580643,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "UuOujQajtL5f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def create_data_loader(df, tokenizer, max_len, batch_size):\n",
    "  ds = IMDBDataset(\n",
    "    texts=df.Text.to_numpy(),\n",
    "    cats1=df['Cat1'].to_numpy(),\n",
    "    cats3=df['Cat3'].to_numpy(),\n",
    "    tokenizer=tokenizer,\n",
    "    max_len=max_len\n",
    "  )\n",
    "\n",
    "  return DataLoader(\n",
    "    ds,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=8\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695647580644,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "3zzA4eBytOqj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "\n",
    "train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)\n",
    "test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1695647580644,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "AoUfRPy0tQgk",
    "outputId": "0882aedd-d5f6-4b6f-e03a-c994018a52dc",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['text', 'input_ids', 'attention_mask', 'cat1', 'cat3'])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Peek at one collated batch to check the keys produced by IMDBDataset.\n",
     "data = next(iter(train_data_loader))\n",
     "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1695647580644,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "Y3Nil-yatUqF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class IMDBClassifier(nn.Module):\n",
    "  def __init__(self, n_classes):\n",
    "    super(IMDBClassifier, self).__init__()\n",
    "    self.bert = RobertaModel.from_pretrained(PRE_TRAINED_MODEL_NAME,config=config)\n",
    "\n",
    "    self.FC = nn.Linear(config.hidden_size,config.hidden_size, bias=False)\n",
    "\n",
    "  def CosineNorm(self, c, b, n_words):\n",
    "    cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-8)\n",
    "    simi = torch.tensor([[1.]*n_words]).to(device)\n",
    "    for i in c:\n",
    "      temp = cos(i,b).to(device)\n",
    "      temp = torch.unsqueeze(temp,0)\n",
    "      simi = torch.cat((simi,temp), dim=0)\n",
    "\n",
    "    return simi[1:]\n",
    "\n",
    "  def binary_output(self,zero, one, two, three, four, five):\n",
    "    zero = torch.unsqueeze(torch.max(zero,1).values,0)\n",
    "    one = torch.unsqueeze(torch.max(one,1).values,0)\n",
    "    two = torch.unsqueeze(torch.max(two,1).values,0)\n",
    "    three = torch.unsqueeze(torch.max(three,1).values,0)\n",
    "    four = torch.unsqueeze(torch.max(four,1).values,0)\n",
    "    five = torch.unsqueeze(torch.max(five,1).values,0)\n",
    "\n",
    "    res = torch.cat((zero,one,two,three,four,five), dim=0)\n",
    "\n",
    "    return torch.t(res)\n",
    "\n",
    "  def forward(self, input_ids, attention_mask, return_scores=False):\n",
    "    with torch.no_grad():\n",
    "      pooled_output = self.bert(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_dict = False\n",
    "      )\n",
    "    pooled_output = torch.mean(pooled_output[0], dim=1) # Taking Averge pooled last layer embedding\n",
    "\n",
    "    x_sent = self.FC(pooled_output)\n",
    "\n",
    "    zeroT = self.FC(zero_emb)\n",
    "    oneT = self.FC(one_emb)\n",
    "    twoT = self.FC(two_emb)\n",
    "    threeT = self.FC(three_emb)\n",
    "    fourT = self.FC(four_emb)\n",
    "    fiveT = self.FC(five_emb)\n",
    "\n",
    "\n",
    "    zero = F.relu(self.CosineNorm(x_sent, zeroT, zero_emb.size()[0]))\n",
    "    one = F.relu(self.CosineNorm(x_sent, oneT, one_emb.size()[0]))\n",
    "    two = F.relu(self.CosineNorm(x_sent, twoT, two_emb.size()[0]))\n",
    "    three = F.relu(self.CosineNorm(x_sent, threeT, three_emb.size()[0]))\n",
    "    four = F.relu(self.CosineNorm(x_sent, fourT, four_emb.size()[0]))\n",
    "    five = F.relu(self.CosineNorm(x_sent, fiveT, five_emb.size()[0]))\n",
    "\n",
    "\n",
    "    binary_out = self.binary_output(zero,one,two,three,four,five)\n",
    "    if(return_scores==True):\n",
    "      return (zero,one,two,three,four,five) , binary_out\n",
    "    return binary_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "executionInfo": {
     "elapsed": 1533,
     "status": "ok",
     "timestamp": 1695647582164,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWZ37gsztWzL",
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = IMDBClassifier(6)\n",
    "model = model.to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1695647582165,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "HWy57v2CxxCM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "for name, param in model.named_parameters():\n",
    "    if name.startswith('bert'):\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "executionInfo": {
     "elapsed": 28,
     "status": "ok",
     "timestamp": 1695647582165,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "e5iLu13CYlux",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#for name, param in model.named_parameters():\n",
    "#    print(name, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 27,
     "status": "ok",
     "timestamp": 1695647582165,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "nZpqz6yDtYZ4",
    "outputId": "ece9dd17-1c93-4ef8-e271-148c41c590d0",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 200])\n",
      "torch.Size([32, 200])\n"
     ]
    }
   ],
   "source": [
    "input_ids = data['input_ids'].to(device)\n",
    "attention_mask = data['attention_mask'].to(device)\n",
    "\n",
    "\n",
    "print(input_ids.shape) # batch size x seq length\n",
    "print(attention_mask.shape) # batch size x seq length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "executionInfo": {
     "elapsed": 25,
     "status": "ok",
     "timestamp": 1695647582168,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "NwvA-zp7vqc1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#del test\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695647582168,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "-9Z37OXOtb0q",
    "tags": []
   },
   "outputs": [],
   "source": [
    "s, outs = model(input_ids, attention_mask, return_scores=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1695647582169,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cjMVWA5a_6lf",
    "tags": []
   },
   "outputs": [],
   "source": [
    "EPOCHS = 8\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=0.001)\n",
    "total_steps = len(train_data_loader) * EPOCHS\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "  optimizer,\n",
    "  num_warmup_steps=math.floor((1./5)*total_steps),\n",
    "  num_training_steps=total_steps\n",
    ")\n",
    "\n",
    "loss_fn = torch.nn.CrossEntropyLoss().to(device)\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "executionInfo": {
     "elapsed": 19,
     "status": "ok",
     "timestamp": 1695647582169,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "cLFDb4pzbx9W",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_epoch(\n",
    "  model,\n",
    "  data_loader,\n",
    "  loss_fn,\n",
    "  optimizer,\n",
    "  device,\n",
    "  scheduler,\n",
    "  n_examples\n",
    "):\n",
    "  model = model.train()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  for d in data_loader:\n",
    "    input_ids = d[\"input_ids\"].to(device)\n",
    "    attention_mask = d[\"attention_mask\"].to(device)\n",
    "    cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "    outputs = model(\n",
    "      input_ids=input_ids,\n",
    "      attention_mask=attention_mask\n",
    "    ).to(device)\n",
    "\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    loss = loss_fn(outputs, cat1)\n",
    "\n",
    "    correct_predictions += torch.sum(preds == cat1)\n",
    "    losses.append(loss.item())\n",
    "\n",
    "\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695647582169,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "z4GAdIawtUue",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_model(model, data_loader, loss_fn, device, n_examples, on_new=False):\n",
    "  model = model.eval()\n",
    "\n",
    "  losses = []\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "\n",
    "      outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "      ).to(device)\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      loss = loss_fn(outputs, cat1)\n",
    "\n",
    "      correct_predictions += torch.sum(preds == cat1)\n",
    "      losses.append(loss.item())\n",
    "\n",
    "  return correct_predictions.double() / n_examples, np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1695647582170,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "bSVNs5ZMpIsj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def max_emotion(zero,one,two,three,four,five):\n",
    "  max_a_values = torch.stack((torch.max(zero,1).values, torch.max(one,1).values, torch.max(two,1).values , torch.max(three,1).values, torch.max(four,1).values, torch.max(five,1).values))\n",
    "\n",
    "  max_a_indices = torch.stack((torch.max(zero,1).indices, torch.max(one,1).indices, torch.max(two,1).indices, torch.max(three,1).indices, torch.max(four,1).indices, torch.max(five,1).indices))\n",
    "\n",
    "  agg_max = torch.max(max_a_values,0)\n",
    "  emotions_indices = [i[1][agg_max.indices[i[0]]] for i in enumerate(max_a_indices.t())]\n",
    "  emotions_indices = torch.stack(emotions_indices)\n",
    "\n",
    "  emotion_list = []\n",
    "  for i in range(len(emotions_indices)):\n",
    "    emotion_list.append(label_cop_def[agg_max.indices[i]][1][emotions_indices[i]])\n",
    "  return torch.tensor(emotion_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "executionInfo": {
     "elapsed": 17,
     "status": "ok",
     "timestamp": 1695647582170,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "E3rc-qAmeZiM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def eval_emotions(model, data_loader, loss_fn, device, n_examples):\n",
    "  model = model.eval()\n",
    "  correct_predictions = 0\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat3 = d[\"cat3\"].to(device)\n",
    "\n",
    "      (zero,one,two,three,four,five), _ = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_scores = True\n",
    "      )\n",
    "\n",
    "      preds = max_emotion(zero,one,two,three,four,five).to(device)\n",
    "      correct_predictions += torch.sum(preds == cat3)\n",
    "\n",
    "  return correct_predictions.double() / n_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12899,
     "status": "ok",
     "timestamp": 1695647595052,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "IqdIHJsrANr0",
    "outputId": "ef188dc8-0aa1-418a-f645-482dd7280d8c",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "----------\n",
      "Train loss 1.405366555504177 accuracy 0.7334924423229913\n",
      "Val   loss 1.3264397448125267 accuracy 0.8095112285336856\n",
      "\n",
      "Epoch 2/8\n",
      "----------\n",
      "Train loss 1.2767546377910157 accuracy 0.8216321930522408\n",
      "Val   loss 1.2819424032661986 accuracy 0.820343461030383\n",
      "\n",
      "Epoch 3/8\n",
      "----------\n",
      "Train loss 1.259902387509796 accuracy 0.8279965526385574\n",
      "Val   loss 1.2870193745013532 accuracy 0.818229854689564\n",
      "\n",
      "Epoch 4/8\n",
      "----------\n",
      "Train loss 1.2523400793652013 accuracy 0.8304826306019624\n",
      "Val   loss 1.2753481472594828 accuracy 0.8244385733157199\n",
      "\n",
      "Epoch 5/8\n",
      "----------\n",
      "Train loss 1.2455537882993988 accuracy 0.8351233094669849\n",
      "Val   loss 1.269855693925785 accuracy 0.8261558784676354\n",
      "\n",
      "Epoch 6/8\n",
      "----------\n",
      "Train loss 1.240777533860879 accuracy 0.8384049323786794\n",
      "Val   loss 1.267178261330359 accuracy 0.8281373844121532\n",
      "\n",
      "Epoch 7/8\n",
      "----------\n",
      "Train loss 1.2379528438299103 accuracy 0.8412224874038717\n",
      "Val   loss 1.2677519985392123 accuracy 0.8285336856010568\n",
      "\n",
      "Epoch 8/8\n",
      "----------\n",
      "Train loss 1.2341702887008197 accuracy 0.8427472818880933\n",
      "Val   loss 1.2612958376920675 accuracy 0.8310435931307794\n",
      "\n",
      "CPU times: user 4h 8min 37s, sys: 1min 11s, total: 4h 9min 48s\n",
      "Wall time: 4h 9min 57s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "train_a = []\n",
    "train_l = []\n",
    "val_a = []\n",
    "val_l = []\n",
    "best_accuracy = 0\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "  print(f'Epoch {epoch + 1}/{EPOCHS}')\n",
    "  print('-' * 10)\n",
    "\n",
    "  train_acc, train_loss = train_epoch(\n",
    "    model,\n",
    "    train_data_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device,\n",
    "    scheduler,\n",
    "    len(df_train)\n",
    "  )\n",
    "\n",
    "  print(f'Train loss {train_loss} accuracy {train_acc}')\n",
    "\n",
    "  val_acc, val_loss = eval_model(\n",
    "    model,\n",
    "    val_data_loader,\n",
    "    loss_fn,\n",
    "    device,\n",
    "    len(df_val)\n",
    "  )\n",
    "\n",
    "  print(f'Val   loss {val_loss} accuracy {val_acc}')\n",
    "  print()\n",
    "\n",
    "  train_a.append(train_acc)\n",
    "  train_l.append(train_loss)\n",
    "  val_a.append(val_acc)\n",
    "  val_l.append(val_loss)\n",
    "\n",
    "  if val_acc > best_accuracy:\n",
    "    torch.save(model.state_dict(), 'roberta_best_model_state.bin')\n",
    "    best_accuracy = val_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1695647595053,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "FowMSU5U7SDQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_a = [i.item() for i in train_a]\n",
    "train_l = [i.item() for i in train_l]\n",
    "val_a = [i.item() for i in val_a]\n",
    "val_l = [i.item() for i in val_l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 472
    },
    "executionInfo": {
     "elapsed": 860,
     "status": "ok",
     "timestamp": 1695647595906,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aUQbxyTEAPhM",
    "outputId": "411eeaa8-5597-4197-e2fb-a0a8f34e5cb8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.plot(train_a, label='train accuracy')\n",
    "plt.plot(val_a, label='validation accuracy')\n",
    "\n",
    "plt.title('Training history')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()\n",
    "plt.ylim([0, 1]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbT4YzHFh1s7"
   },
   "source": [
     "## Accuracy of Cat1 on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1208,
     "status": "ok",
     "timestamp": 1695647597107,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "L1ZzFERMAQk9",
    "outputId": "ce282068-a74d-4ee4-e5f0-e32a1ee6d345",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7870282770945023"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_acc, _ = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "test_acc.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uSAQhl1kh5jw"
   },
   "source": [
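     "## Accuracy of Cat3 on the test set\n",
     "\n",
     "**Note:** the recorded Cat3 accuracy is about 0.4% while Cat1 accuracy is about 79%, and almost every Cat3 class in the report below has zero precision and recall. This pattern suggests checking whether the ids in `label_cop_def[k][1]` use the same encoding as the `Cat3` column (**TODO**)."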
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1576,
     "status": "ok",
     "timestamp": 1695647598676,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "aw1XJc94pY9f",
    "outputId": "e90362fa-dfce-48c0-bea0-2e0f202fc74e",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.00430989172711027"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emotion_test_acc = eval_emotions(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "emotion_test_acc.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1695647598676,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "d7DaFLMiAWps",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader):\n",
    "  model = model.eval()\n",
    "\n",
    "  review = []\n",
    "  predictions = []\n",
    "  emo_predictions = []\n",
    "\n",
    "  prediction_probs = []\n",
    "  real_values = []\n",
    "  emo_real_values = []\n",
    "\n",
    "  with torch.no_grad():\n",
    "    for d in data_loader:\n",
    "\n",
    "      texts = d[\"text\"]\n",
    "      input_ids = d[\"input_ids\"].to(device)\n",
    "      attention_mask = d[\"attention_mask\"].to(device)\n",
    "      cat1 = d[\"cat1\"].to(device)\n",
    "      cat3 = d[\"cat3\"].to(device)\n",
    "\n",
    "\n",
    "      (zero,one,two,three,four,five),outputs = model(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        return_scores=True,\n",
    "      )\n",
    "      _, preds = torch.max(outputs, dim=1)\n",
    "\n",
    "      emo_preds = max_emotion(zero,one,two,three,four,five).to(device)\n",
    "\n",
    "      probs = F.softmax(outputs, dim=1)\n",
    "\n",
    "      review.extend(texts)\n",
    "      predictions.extend(preds)\n",
    "      emo_predictions.extend(emo_preds)\n",
    "\n",
    "      prediction_probs.extend(probs)\n",
    "      real_values.extend(cat1)\n",
    "      emo_real_values.extend(cat3)\n",
    "\n",
    "\n",
    "  predictions = torch.stack(predictions).cpu()\n",
    "  emo_predictions = torch.stack(emo_predictions).cpu()\n",
    "  prediction_probs = torch.stack(prediction_probs).cpu()\n",
    "  real_values = torch.stack(real_values).cpu()\n",
    "  emo_real_values = torch.stack(emo_real_values).cpu()\n",
    "\n",
    "  return review, predictions, emo_predictions, prediction_probs, real_values, emo_real_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "executionInfo": {
     "elapsed": 1726,
     "status": "ok",
     "timestamp": 1695647600392,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "kZhYtki1AYx8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_review_texts, y_pred, emo_pred , y_pred_probs, y_test, emo_y_test = get_predictions(\n",
    "  model,\n",
    "  test_data_loader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1695647600393,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "mO1FHn0GAbCU",
    "tags": []
   },
   "outputs": [],
   "source": [
     "from sklearn.metrics import classification_report\n",
     "from sklearn.metrics import confusion_matrix\n",
     "# NOTE(review): target_names must follow label-id order 0..5; this assumes the\n",
     "# keys of cat1_map are stored in that order — verify where cat1_map is built.\n",
     "class_names = list(cat1_map.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PnBSQFJuh_As"
   },
   "source": [
     "## Cat1 classification report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1695647600393,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "J9-01snpAcM8",
    "outputId": "04e35297-28ce-4570-e595-abf2b1baf25f",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                      precision    recall  f1-score   support\n",
      "\n",
      "        pet supplies       0.91      0.83      0.87      1576\n",
      "              beauty       0.81      0.76      0.78      2027\n",
      "          toys games       0.81      0.84      0.83      1533\n",
      "grocery gourmet food       0.74      0.79      0.76       811\n",
      "health personal care       0.74      0.78      0.76      2936\n",
      "       baby products       0.66      0.68      0.67       630\n",
      "\n",
      "            accuracy                           0.79      9513\n",
      "           macro avg       0.78      0.78      0.78      9513\n",
      "        weighted avg       0.79      0.79      0.79      9513\n",
      "\n"
     ]
    }
   ],
   "source": [
     "# Per-class precision/recall/F1 of the Cat1 predictions on the test set.\n",
     "print(classification_report(y_test, y_pred, target_names=class_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Bjg-0Nj-iCPU"
   },
   "source": [
     "## Cat3 classification report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1695647600394,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "DHauMLMCsqPx",
    "outputId": "0b0997ef-7349-4c83-a76b-dcb9eee5f066",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                    precision    recall  f1-score   support\n",
      "\n",
      "                              toys       0.00      0.00      0.00       244\n",
      "                     styling tools       0.00      0.00      0.00       213\n",
      "                        miniatures       0.00      0.00      0.00         0\n",
      "                live indoor plants       0.00      0.00      0.00        10\n",
      "                          playsets       0.00      0.00      0.00        26\n",
      "                           figures       0.00      0.00      0.00        79\n",
      "         feeding watering supplies       0.00      0.00      0.00        87\n",
      "              shaving hair removal       0.00      0.00      0.00       280\n",
      "                       board games       0.05      0.02      0.03       124\n",
      "                         cleansers       0.00      0.00      0.00        92\n",
      "                     breastfeeding       0.00      0.00      0.00        32\n",
      "                baby gyms playmats       0.00      0.00      0.00         7\n",
      "     activity centers entertainers       0.00      0.00      0.00        11\n",
      "        packaged meals side dishes       0.00      0.00      0.00         4\n",
      "                            swings       0.00      0.00      0.00        20\n",
      "                    gates doorways       0.00      0.00      0.00        40\n",
      "                   health supplies       0.00      0.00      0.00       179\n",
      "                    jigsaw puzzles       0.00      0.00      0.00        26\n",
      "                  aquarium heaters       0.00      0.00      0.00         3\n",
      "                sports supplements       0.00      0.00      0.00        40\n",
      "                 maternity pillows       0.00      0.00      0.00        14\n",
      "                       music sound       0.00      0.00      0.00        12\n",
      "                 suckers lollipops       0.00      0.00      0.00        10\n",
      "                               sun       0.00      0.00      0.00        22\n",
      "              litter housebreaking       0.00      0.00      0.00       118\n",
      "              weight loss products       0.00      0.00      0.00        23\n",
      "              vitamins supplements       0.00      0.00      0.00       665\n",
      "                 lip care products       0.00      0.00      0.00        44\n",
      "                              face       0.00      0.00      0.00       262\n",
      "                          shampoos       0.00      0.00      0.00        53\n",
      "                              food       0.00      0.00      0.00        83\n",
      "                disposable diapers       0.00      0.00      0.00        88\n",
      "                  toaster pastries       0.00      0.00      0.00         2\n",
      "                trading card games       0.00      0.00      0.00        10\n",
      "                  doll accessories       0.00      0.00      0.00        25\n",
      "                               tea       0.00      0.00      0.00       130\n",
      "                     pumps filters       0.00      0.00      0.00        58\n",
      "                         foot care       0.00      0.00      0.00        58\n",
      "           cooking baking supplies       0.00      0.00      0.00       131\n",
      "                             balls       0.00      0.00      0.00         5\n",
      "                           women s       0.00      0.00      0.00       286\n",
      "                         cat flaps       0.11      0.11      0.11        18\n",
      "                         car seats       0.00      0.00      0.00         1\n",
      "                             herbs       0.00      0.00      0.00        45\n",
      "                          eye care       0.00      0.00      0.00        21\n",
      "              alternative medicine       0.00      0.00      0.00        74\n",
      "                 harnesses leashes       0.00      0.00      0.00         1\n",
      "                      oral hygiene       0.00      0.00      0.00       120\n",
      "                       jelly beans       0.00      0.00      0.00         3\n",
      "            training behavior aids       0.00      0.00      0.00        64\n",
      "                 daily living aids       0.00      0.00      0.00        59\n",
      "                backpacks carriers       0.00      0.00      0.00        18\n",
      "                massage relaxation       0.00      0.00      0.00        71\n",
      "                     play vehicles       0.00      0.00      0.00        22\n",
      "               household batteries       0.00      0.00      0.00       103\n",
      "                            treats       0.00      0.00      0.00       112\n",
      "                         safer sex       0.00      0.00      0.00        32\n",
      "                              eyes       0.00      0.00      0.00        71\n",
      "                           bedding       0.00      0.00      0.00        30\n",
      "                  styling products       0.00      0.00      0.00        69\n",
      "                        nail tools       0.00      0.00      0.00        17\n",
      "                   gym sets swings       0.00      0.00      0.00         7\n",
      "                     nursery d cor       0.00      0.00      0.00        14\n",
      "                              lips       0.00      0.00      0.00        40\n",
      "                           mirrors       0.00      0.00      0.00        24\n",
      "                           seafood       0.00      0.00      0.00         5\n",
      "                            braces       0.00      0.00      0.00       151\n",
      "                          diabetes       0.00      0.00      0.00        13\n",
      "                   health monitors       0.00      0.00      0.00        82\n",
      "             plug play video games       0.00      0.00      0.00         1\n",
      "              makeup brushes tools       0.00      0.00      0.00        20\n",
      "                household cleaning       0.00      0.00      0.00        85\n",
      "                   scooters wagons       0.00      0.00      0.00        51\n",
      "                hair loss products       0.00      0.00      0.00        30\n",
      "                   animals figures       0.00      0.00      0.00       139\n",
      "                      pretend play       0.00      0.00      0.00       108\n",
      "                           allergy       0.00      0.00      0.00        26\n",
      "             nutrition bars drinks       0.00      0.00      0.00       169\n",
      "                           science       0.00      0.00      0.00        69\n",
      "                   stacking blocks       0.00      0.00      0.00        23\n",
      "                bathing tubs seats       0.00      0.00      0.00         8\n",
      "                              body       0.00      0.00      0.00       130\n",
      "                          grooming       0.00      0.00      0.00       103\n",
      "                           laundry       0.00      0.00      0.00        15\n",
      "               bathing accessories       0.00      0.00      0.00        11\n",
      "                            coffee       0.00      0.00      0.00        31\n",
      "                  beds accessories       0.00      0.00      0.00        24\n",
      "                        card games       0.00      0.00      0.00        36\n",
      "                    cleaning tools       0.00      0.00      0.00        80\n",
      "                      baby formula       0.00      0.00      0.00         1\n",
      "                    pillows stools       0.00      0.00      0.00         6\n",
      "                     floor puzzles       0.00      0.00      0.00        17\n",
      "                            sports       0.00      0.00      0.00        10\n",
      "              cabinet locks straps       0.00      0.00      0.00         2\n",
      "                        craft kits       0.00      0.00      0.00        41\n",
      "        deodorants antiperspirants       0.00      0.00      0.00        27\n",
      "                         bath toys       0.00      0.00      0.00        12\n",
      "                   packaged breads       0.00      0.00      0.00         0\n",
      "                            tandem       0.00      0.00      0.00         9\n",
      "                      hard candies       0.00      0.00      0.00        18\n",
      "                             dolls       0.00      0.00      0.00        29\n",
      "                   electronic toys       0.00      0.00      0.00        28\n",
      "                           collars       0.00      0.00      0.00       103\n",
      "                          crackers       0.00      0.00      0.00        13\n",
      "                             men s       0.00      0.00      0.00       183\n",
      "                  digestion nausea       0.00      0.00      0.00        23\n",
      "                           puppets       0.00      0.00      0.00         7\n",
      "                     coconut water       0.00      0.00      0.00         8\n",
      "                       hands nails       0.00      0.00      0.00        76\n",
      "                  breakfast bakery       0.00      0.00      0.00         0\n",
      "             dollhouse accessories       0.00      0.00      0.00         5\n",
      "                     gummy candies       0.00      0.00      0.00         9\n",
      "          carriers travel products       0.00      0.00      0.00        40\n",
      "    family planning contraceptives       0.00      0.00      0.00        15\n",
      "                    beds furniture       0.00      0.00      0.00       139\n",
      "                     shape sorters       0.00      0.00      0.00        12\n",
      "              diaper pails refills       0.00      0.00      0.00         4\n",
      "          highchairs booster seats       0.00      0.00      0.00        20\n",
      "              stuffed animals toys       0.00      0.00      0.00        18\n",
      "                    push pull toys       0.00      0.00      0.00        15\n",
      "                             nails       0.00      0.00      0.00       207\n",
      "                       accessories       0.00      0.00      0.00        37\n",
      "                            sauces       0.00      0.00      0.00        25\n",
      "                        tile games       0.00      0.00      0.00         2\n",
      "                   electronic pets       0.00      0.00      0.00         6\n",
      "                        dollhouses       0.00      0.00      0.00         5\n",
      "                              jams       0.00      0.00      0.00         8\n",
      "                           statues       0.00      0.00      0.00         1\n",
      "         drawing painting supplies       0.00      0.00      0.00        23\n",
      "                   houses habitats       0.00      0.00      0.00        12\n",
      "          grooming healthcare kits       0.00      0.00      0.00         7\n",
      "                         skin care       0.00      0.00      0.00        13\n",
      "                           cookies       0.00      0.00      0.00        16\n",
      "               systems accessories       0.00      0.00      0.00        12\n",
      "                            blocks       0.00      0.00      0.00        13\n",
      "                     paper plastic       0.00      0.00      0.00        25\n",
      "           blackboards whiteboards       0.00      0.00      0.00         1\n",
      "                            easels       0.00      0.00      0.00         7\n",
      "                           puzzles       0.00      0.00      0.00         3\n",
      "           mobility aids equipment       0.00      0.00      0.00        47\n",
      "                       diaper bags       0.00      0.00      0.00         9\n",
      "                  adult toys games       0.00      0.00      0.00        49\n",
      "           rocking spring ride ons       0.00      0.00      0.00         0\n",
      "                     solid feeding       0.00      0.00      0.00        28\n",
      "                            houses       0.00      0.00      0.00        58\n",
      "                canned jarred food       0.00      0.00      0.00        14\n",
      "                           joggers       0.00      0.00      0.00         0\n",
      "                        hair color       0.00      0.00      0.00        80\n",
      "                              bars       0.00      0.00      0.00         2\n",
      "                    pain relievers       0.00      0.00      0.00       114\n",
      "                    women s health       0.00      0.00      0.00        18\n",
      "                           cereals       0.00      0.00      0.00        21\n",
      "                 spices seasonings       0.00      0.00      0.00        22\n",
      "            plush backpacks purses       0.00      0.00      0.00         1\n",
      "                             doors       0.00      0.00      0.00        20\n",
      "                  game accessories       0.00      0.00      0.00         0\n",
      "                         furniture       0.00      0.00      0.00        43\n",
      "                   facial steamers       0.00      0.00      0.00         2\n",
      "                      incontinence       0.05      0.02      0.03        49\n",
      "                  water treatments       0.00      0.00      0.00        11\n",
      "                          ear care       0.00      0.00      0.00         8\n",
      "                          licorice       0.00      0.00      0.00         8\n",
      "                  vehicle playsets       0.00      0.00      0.00        19\n",
      "                 sand water tables       0.00      0.00      0.00         0\n",
      "                edge corner guards       0.00      0.00      0.00         7\n",
      "                 smoking cessation       0.00      0.00      0.00         8\n",
      "                       seat covers       0.00      0.00      0.00         0\n",
      "                    bottle feeding       0.00      0.00      0.00        18\n",
      "              pill cases splitters       0.00      0.00      0.00        17\n",
      "                          habitats       0.00      0.00      0.00         9\n",
      "                      conditioners       0.00      0.00      0.00        49\n",
      "                       marble runs       0.00      0.00      0.00         3\n",
      "                     building sets       0.00      0.00      0.00        55\n",
      "                       snack gifts       0.00      0.00      0.00         3\n",
      "                      spices gifts       0.00      0.00      0.00         0\n",
      "                           popcorn       0.33      0.06      0.10        17\n",
      "             hair scalp treatments       0.00      0.00      0.00        18\n",
      "                   aquarium lights       0.00      0.00      0.00         8\n",
      "                         first aid       0.00      0.00      0.00        70\n",
      "                             water       0.00      0.00      0.00         1\n",
      "                  stress reduction       0.00      0.00      0.00        21\n",
      "             breakfast cereal bars       0.00      0.00      0.00        16\n",
      "                     personal care       0.00      0.00      0.00         3\n",
      "          play trains railway sets       0.00      0.00      0.00        33\n",
      "                           walkers       0.00      0.00      0.00        13\n",
      "                         aquariums       0.00      0.00      0.00         9\n",
      "                     dessert gifts       0.00      0.00      0.00         0\n",
      "                         toy balls       0.00      0.00      0.00         8\n",
      "             activity play centers       0.00      0.00      0.00         2\n",
      "                             fruit       0.00      0.00      0.00         0\n",
      "                     battling tops       0.00      0.00      0.00         1\n",
      "                            juices       0.00      0.00      0.00        18\n",
      "                     sleep snoring       0.00      0.81      0.01        16\n",
      "                   pools water fun       0.00      0.00      0.00        14\n",
      "                 die cast vehicles       0.00      0.00      0.00         1\n",
      "                           rockets       0.00      0.00      0.00         5\n",
      "               inflatable bouncers       0.00      0.00      0.00         0\n",
      "                       dishwashing       0.00      0.00      0.00        13\n",
      "                     feminine care       0.00      0.00      0.00        13\n",
      "        finger boards finger bikes       0.00      0.00      0.00         0\n",
      "                   aquarium stands       0.00      0.00      0.00         0\n",
      "             crib toys attachments       0.00      0.00      0.00         0\n",
      "            hair perms texturizers       0.00      0.00      0.00         0\n",
      "             stacking nesting toys       0.00      0.00      0.00        10\n",
      "                     radio control       0.00      0.00      0.00         8\n",
      "                         hot cocoa       0.00      0.00      0.00        12\n",
      "                          monitors       0.00      0.00      0.00         8\n",
      "                         test kits       0.00      0.00      0.00         5\n",
      "                carriers strollers       0.00      0.00      0.00         8\n",
      "                      chips crisps       0.00      0.00      0.00         4\n",
      "                           pudding       0.00      0.00      0.00         0\n",
      "        changing table pads covers       0.00      0.00      0.00         9\n",
      "              mathematics counting       0.00      0.00      0.00         5\n",
      "                   gardening tools       0.00      0.00      0.00         1\n",
      "                trains accessories       0.00      0.00      0.00        24\n",
      "                     energy drinks       0.00      0.00      0.00         4\n",
      "             music players karaoke       0.00      0.00      0.00         4\n",
      "         model building kits tools       0.00      0.00      0.00        31\n",
      "         drawing sketching tablets       0.00      0.00      0.00         3\n",
      "                              bath       0.00      0.00      0.00        13\n",
      "personal video players accessories       0.00      0.00      0.00         1\n",
      "                       health care       0.00      0.00      0.00        32\n",
      "                    outdoor safety       0.00      0.00      0.00         0\n",
      "             therapeutic skin care       0.00      0.00      0.00        14\n",
      "                         party mix       0.00      0.00      0.00         3\n",
      "                          teethers       0.00      0.00      0.00        15\n",
      "                      thermometers       0.00      0.00      0.00         6\n",
      "                     cloth diapers       0.00      0.00      0.00         2\n",
      "              bathroom aids safety       0.00      0.00      0.00        32\n",
      "                play tents tunnels       0.00      0.00      0.00         8\n",
      "             pacifiers accessories       0.00      0.00      0.00         4\n",
      "            educational repellents       0.00      0.00      0.00        14\n",
      "                    travel systems       0.00      0.00      0.00         0\n",
      "                     potties seats       0.00      0.00      0.00        10\n",
      "                         maternity       0.00      0.00      0.00         5\n",
      "               fresh baked cookies       0.00      0.00      0.00         5\n",
      "                         gift sets       0.00      0.00      0.00         0\n",
      "              shopping cart covers       0.00      0.00      0.00         0\n",
      "                beanbags foot bags       0.00      0.00      0.00         1\n",
      "                              oils       0.00      0.00      0.00         1\n",
      "                         keepsakes       0.00      0.00      0.00         3\n",
      "              diaper changing kits       0.00      0.00      0.00         1\n",
      "         ice cream frozen desserts       0.00      0.00      0.00         0\n",
      "                             taffy       0.00      0.00      0.00         6\n",
      "            magic kits accessories       0.00      0.00      0.00         9\n",
      "               apparel accessories       0.00      0.00      0.00        32\n",
      "                    aquarium d cor       0.00      0.00      0.00        21\n",
      "                         dvd games       0.00      0.00      0.00         0\n",
      "                action toy figures       0.00      0.00      0.00         3\n",
      "                            albums       0.00      0.00      0.00         2\n",
      "                       assortments       0.00      0.00      0.00         0\n",
      "                       lightweight       0.00      0.00      0.00         2\n",
      "occupational physical therapy aids       0.00      0.00      0.00        13\n",
      "              powdered drink mixes       0.00      0.00      0.00         4\n",
      "                    kitchen safety       0.00      0.00      0.00         9\n",
      "            car seat stroller toys       0.00      0.00      0.00         8\n",
      "                           rattles       0.00      0.00      0.00         9\n",
      "                    air fresheners       0.00      0.00      0.00         8\n",
      "                 electrical safety       0.00      0.00      0.00         7\n",
      "                    makeup remover       0.00      0.00      0.00         9\n",
      "                 printing stamping       0.00      0.00      0.00         0\n",
      "                    pegged puzzles       0.00      0.00      0.00        22\n",
      "                 rails rail guards       0.00      0.00      0.00        10\n",
      "                       candy gifts       0.00      0.00      0.00         3\n",
      "                              sets       0.00      0.00      0.00         4\n",
      "                     hair relaxers       0.00      0.00      0.00         2\n",
      "                    walkie talkies       0.00      0.00      0.00         0\n",
      "                      baking mixes       0.00      0.00      0.00         5\n",
      "                  dessert toppings       0.00      0.00      0.00         2\n",
      "                             fudge       0.00      0.00      0.00         1\n",
      "                   bathroom safety       0.00      0.00      0.00         4\n",
      "                             jerky       0.00      0.00      0.00         0\n",
      "                         geography       0.00      0.00      0.00        14\n",
      "                           hobbies       0.00      0.00      0.00        20\n",
      "                        cough cold       0.00      0.00      0.00        20\n",
      "                       bubble bath       0.00      0.00      0.00         0\n",
      "                cameras camcorders       0.00      0.00      0.00         1\n",
      "                 sugars sweeteners       0.00      0.00      0.00         3\n",
      "                            cheese       0.00      0.00      0.00         2\n",
      "                          body art       0.00      0.00      0.00         2\n",
      "             scaled model vehicles       0.00      0.00      0.00         5\n",
      "       standard playing card decks       0.01      0.78      0.02         9\n",
      "                     wipes holders       0.00      0.00      0.00        16\n",
      "            scrubs body treatments       0.00      0.00      0.00        10\n",
      "                          stickers       0.00      0.00      0.00         2\n",
      "          gag toys practical jokes       0.00      0.00      0.00        20\n",
      "                       dried beans       0.00      0.00      0.00         9\n",
      "                          playards       0.00      0.00      0.00        13\n",
      "                   coin collecting       0.00      0.00      0.00         1\n",
      "                            breads       0.00      0.00      0.00         0\n",
      "            granola trail mix bars       0.00      0.00      0.00         0\n",
      "                  novelty gag toys       0.00      0.00      0.00        12\n",
      "                 sleep positioners       0.00      0.00      0.00         1\n",
      "                        stimulants       0.00      0.00      0.00         3\n",
      "                      fresh fruits       0.00      0.00      0.00         5\n",
      "                     brain teasers       0.00      0.00      0.00        10\n",
      "            portable changing pads       0.00      0.00      0.00         0\n",
      "            basic life skills toys       0.00      0.00      0.00         6\n",
      "                         trail mix       0.00      0.00      0.00         0\n",
      "             ball pits accessories       0.00      0.00      0.00         2\n",
      "                          pretzels       0.00      0.00      0.00         2\n",
      "                        clay dough       0.00      0.00      0.00         8\n",
      "                       teddy bears       0.00      0.00      0.00         4\n",
      "                     puffed snacks       0.00      0.00      0.00         4\n",
      "                    training pants       0.00      0.00      0.00         3\n",
      "                             cages       0.00      0.00      0.00         0\n",
      "                         tea gifts       0.00      0.00      0.00         4\n",
      "                          cleaners       0.00      0.00      0.00         8\n",
      "                            toffee       0.00      0.00      0.00         0\n",
      "                   reading writing       0.00      0.00      0.00         0\n",
      "                        bags cases       0.00      0.00      0.00         7\n",
      "          molding sculpting sticks       0.00      0.00      0.00         0\n",
      "                 jerky dried meats       0.00      0.00      0.00         4\n",
      "                             mints       0.00      0.00      0.00         8\n",
      "                  milk substitutes       0.00      0.00      0.00         0\n",
      "                       makeup sets       0.00      0.00      0.00         0\n",
      "                       gummy candy       0.00      0.00      0.00         6\n",
      "                              milk       0.00      0.00      0.00         5\n",
      "                  sexual enhancers       0.00      0.00      0.00         1\n",
      "             sandboxes accessories       0.00      0.00      0.00         2\n",
      "                        nuts seeds       0.00      0.00      0.00         2\n",
      "                     fruit leather       0.00      0.00      0.00         6\n",
      "                          lighters       0.00      0.00      0.00        21\n",
      "                        rice cakes       0.00      0.00      0.00         0\n",
      "                   game room games       0.00      0.00      0.00         4\n",
      "                             chips       0.00      0.00      0.00         1\n",
      "                 cages accessories       0.00      0.00      0.00        10\n",
      "                 fresh cut flowers       0.00      0.00      0.00         3\n",
      "                blasters foam play       0.00      0.00      0.00         7\n",
      "                        condiments       0.00      0.00      0.00         2\n",
      "                             cakes       0.00      0.00      0.00         4\n",
      "                    sun protection       0.00      0.00      0.00         1\n",
      "                    aquarium hoods       0.00      0.00      0.00         1\n",
      "                         chocolate       0.00      0.00      0.00         3\n",
      "                             bacon       0.00      0.00      0.00         0\n",
      "               kites wind spinners       0.00      0.00      0.00         1\n",
      "                       soft drinks       0.00      0.00      0.00         3\n",
      "                        dance mats       0.00      0.00      0.00         0\n",
      "             chocolate assortments       0.00      0.00      0.00         3\n",
      "                            salsas       0.00      0.00      0.00         2\n",
      "               pogo sticks hoppers       0.00      0.00      0.00         4\n",
      "               hair coloring tools       0.00      0.00      0.00         2\n",
      "                puzzle accessories       0.00      0.00      0.00         0\n",
      "                       money banks       0.00      0.00      0.00         1\n",
      "                   soaps cleansers       0.00      0.00      0.00         6\n",
      "                    chocolate bars       0.00      0.00      0.00        22\n",
      "                    stacking games       0.00      0.00      0.00        10\n",
      "                  fresh vegetables       0.00      0.00      0.00         2\n",
      "                             halva       0.00      0.00      0.00         0\n",
      "                non slip bath mats       0.00      0.00      0.00         0\n",
      "                          standard       0.00      0.00      0.00         3\n",
      "                   teaching clocks       0.00      0.00      0.00         2\n",
      "           diaper stackers caddies       0.00      0.00      0.00         4\n",
      "                   cocktail mixers       0.00      0.00      0.00         0\n",
      "                 sugar substitutes       0.00      0.00      0.00         1\n",
      "                 crackers biscuits       0.00      0.00      0.00         0\n",
      "                      wind up toys       0.00      0.00      0.00         3\n",
      "                             games       0.00      0.00      0.00         0\n",
      "                            butter       0.00      0.00      0.00         0\n",
      "                     seafood gifts       0.00      0.00      0.00         3\n",
      "                         d puzzles       0.00      0.00      0.00         4\n",
      "                      coffee gifts       0.00      0.00      0.00         0\n",
      "                    rabbit hutches       0.00      1.00      0.00         3\n",
      "                     pasta noodles       0.00      0.00      0.00         6\n",
      "                               gum       0.00      0.00      0.00         7\n",
      "                      pizza crusts       0.00      0.00      0.00         0\n",
      "                extracts flavoring       0.00      0.00      0.00         2\n",
      "                        baby seats       0.00      0.00      0.00         1\n",
      "                 fitness equipment       0.00      0.00      0.00         0\n",
      "                      marble games       0.00      0.00      0.00         0\n",
      "                  dice gaming dice       0.00      0.00      0.00         4\n",
      "                chocolate truffles       0.00      0.00      0.00         2\n",
      "                          carriers       0.00      0.00      0.00         2\n",
      "                       viewfinders       0.00      0.00      0.00         3\n",
      "                    handheld games       0.00      0.00      0.00         0\n",
      "                          caramels       0.00      0.00      0.00         6\n",
      "                     baking powder       0.00      0.00      0.00         0\n",
      "                              dips       0.00      0.00      0.00         2\n",
      "                    beauty fashion       0.00      0.00      0.00         1\n",
      "                           chicken       0.00      0.00      0.00         0\n",
      "               odor stain removers       0.00      0.00      0.00         0\n",
      "                      sauces gifts       0.00      0.00      0.00         2\n",
      "           hammering pounding toys       0.00      0.00      0.00         0\n",
      "                       flash cards       0.00      0.00      0.00         9\n",
      "   indoor climbers play structures       0.00      0.00      0.00         0\n",
      "                 washcloths towels       0.00      0.00      0.00         0\n",
      "                      nut clusters       0.00      0.00      0.00         2\n",
      "          shampoo conditioner sets       0.01      1.00      0.01        11\n",
      "                      cheese gifts       0.00      0.00      0.00         1\n",
      "                         toy banks       0.00      0.00      0.00         1\n",
      "                       chewing gum       0.00      0.00      0.00         4\n",
      "                       floor games       0.00      0.00      0.00         3\n",
      "                      marshmallows       0.00      0.00      0.00         0\n",
      "                chocolate pretzels       0.00      0.00      0.00         0\n",
      "              prisms kaleidoscopes       0.00      0.00      0.00         4\n",
      "                 automatic feeders       0.00      0.00      0.00         1\n",
      "                             tests       0.00      0.00      0.00         3\n",
      "                       step stools       0.00      0.00      0.00         3\n",
      "                     sex furniture       0.00      0.00      0.00         1\n",
      "                      flours meals       0.00      0.00      0.00         2\n",
      "                      cotton swabs       0.00      0.00      0.00         3\n",
      "          shampoo plus conditioner       0.00      0.00      0.00         5\n",
      "                     plush puppets       0.00      0.00      0.00         4\n",
      "                          pastries       0.00      0.00      0.00         0\n",
      "                  sensual delights       0.00      0.00      0.00         1\n",
      "           thermometer accessories       0.00      0.00      0.00         1\n",
      "                        pork rinds       0.00      0.00      0.00         1\n",
      "                            cereal       0.00      0.00      0.00         1\n",
      "                           dinners       0.00      0.00      0.00         0\n",
      "               dried fruit raisins       0.00      0.00      0.00         1\n",
      "                         beverages       0.00      0.00      0.00         0\n",
      "                      travel games       0.00      0.00      0.00         0\n",
      "                       fruit gifts       0.00      0.00      0.00         0\n",
      "             aquarium starter kits       0.00      0.00      0.00         2\n",
      "                       electronics       0.00      0.00      0.00         1\n",
      "                     toy gift sets       0.00      0.00      0.00         0\n",
      "                   foie gras p t s       0.00      0.00      0.00         0\n",
      "                        children s       0.00      0.00      0.00         1\n",
      "                     aprons smocks       0.00      0.00      0.00         3\n",
      "                            fruits       0.00      0.00      0.00         1\n",
      "          cloth diaper accessories       0.00      0.00      0.00         0\n",
      "                            syrups       0.00      0.00      0.00         0\n",
      "                 temporary tattoos       0.00      0.00      0.00         0\n",
      "                         slot cars       0.00      0.00      0.00         0\n",
      "                          stuffing       0.00      0.00      0.00         1\n",
      "                     sports drinks       0.00      0.00      0.00         1\n",
      "                             prams       0.00      0.00      0.00         0\n",
      "                       breadsticks       0.00      0.00      0.00         2\n",
      "                      exotic meats       0.00      0.00      0.00         0\n",
      "                    wild game fowl       0.00      0.00      0.00         0\n",
      "          bondage gear accessories       0.00      0.00      0.00         5\n",
      "                       breadcrumbs       0.00      0.00      0.00         0\n",
      "                     food coloring       0.00      0.00      0.00         2\n",
      "                     plush pillows       0.00      0.00      0.00         0\n",
      "                             beads       0.00      0.00      0.00         0\n",
      "                         novelties       0.00      0.00      0.00         2\n",
      "              jams preserves gifts       0.00      0.00      0.00         0\n",
      "                          sausages       0.00      0.00      0.00         0\n",
      "                        hair nails       0.00      0.00      0.00         0\n",
      "                   chocolate gifts       0.00      0.00      0.00         3\n",
      "                pastry decorations       0.00      0.00      0.00         0\n",
      "             die cast toy vehicles       0.00      0.00      0.00         1\n",
      "                           raisins       0.00      0.00      0.00         0\n",
      "           chocolate covered fruit       0.00      0.00      0.00         3\n",
      "                  slime putty toys       0.00      0.00      0.00         0\n",
      "                            yo yos       0.00      0.00      0.00         0\n",
      "                         tortillas       0.00      0.00      0.00         0\n",
      "                         memorials       0.00      0.00      0.00         0\n",
      "                     nesting dolls       0.00      0.00      0.00         0\n",
      "                              pork       0.00      0.00      0.00         1\n",
      "                              beef       0.00      0.00      0.00         0\n",
      "                  puzzle play mats       0.00      0.00      0.00         2\n",
      "                  game collections       0.00      0.00      0.00         3\n",
      "         kickball playground balls       0.00      0.00      0.00         0\n",
      "                        fish bowls       0.00      0.00      0.00         0\n",
      "             novelty spinning tops       0.00      0.00      0.00         1\n",
      "                             p t s       0.00      0.00      0.00         0\n",
      "                     spinning tops       0.00      0.00      0.00         0\n",
      "                        meat gifts       0.00      0.00      0.00         2\n",
      "                      slumber bags       0.00      0.00      0.00         0\n",
      "                      granola bars       0.00      0.00      0.00         0\n",
      "                              eggs       0.00      0.00      0.00         1\n",
      "                      aromatherapy       0.00      0.00      0.00         0\n",
      "                       dried fruit       0.00      0.00      0.00         1\n",
      "                       flying toys       0.00      0.00      0.00         2\n",
      "                           shampoo       0.00      0.00      0.00         1\n",
      "                  coatings batters       0.00      0.00      0.00         1\n",
      "                       hydrometers       0.00      0.00      0.00         1\n",
      "                              lamb       0.00      0.00      0.00         1\n",
      "                   exercise wheels       0.00      0.00      0.00         1\n",
      "            chocolate covered nuts       0.00      0.00      0.00         2\n",
      "                    breeding tanks       0.00      0.00      0.00         1\n",
      "\n",
      "                         micro avg       0.00      0.00      0.00      9513\n",
      "                         macro avg       0.00      0.01      0.00      9513\n",
      "                      weighted avg       0.00      0.00      0.00      9513\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "class_names = list(cat3_map.keys())\n",
    "\n",
    "print(classification_report(emo_y_test, emo_pred, labels=np.arange(0,len(class_names),1), target_names=class_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GhTpYmfdiF-P"
   },
   "source": [
    "Reverse Native Injection Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1626,
     "status": "ok",
     "timestamp": 1695647602015,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "pOutPTF6qJY8",
    "outputId": "9849d9c0-c3d2-4e61-d3e9-db9b8827f946",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.031746031746031744"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "zero_emb,one_emb,two_emb,three_emb,four_emb,five_emb = one_emb,two_emb,three_emb,four_emb,five_emb, zero_emb\n",
    "\n",
    "\n",
    "test_acc, _ = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "test_acc.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oEdEBFPQiasC"
   },
   "source": [
    "Paraphrased Native Injection Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12431,
     "status": "ok",
     "timestamp": 1695647676127,
     "user": {
      "displayName": "m m",
      "userId": "12156804663229931259"
     },
     "user_tz": -360
    },
    "id": "uEuhYVF-gMqC",
    "outputId": "ff1c54fd-94a2-4bd6-cdba-5f901d60f735",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7780931357090297"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "zero_emb = make_emb(label_cop_def[0][3])\n",
    "one_emb = make_emb(label_cop_def[1][3])\n",
    "two_emb = make_emb(label_cop_def[2][3])\n",
    "three_emb = make_emb(label_cop_def[3][3])\n",
    "four_emb = make_emb(label_cop_def[4][3])\n",
    "five_emb = make_emb(label_cop_def[5][3])\n",
    "\n",
    "\n",
    "test_acc, _ = eval_model(\n",
    "  model,\n",
    "  test_data_loader,\n",
    "  loss_fn,\n",
    "  device,\n",
    "  len(df_test)\n",
    ")\n",
    "\n",
    "test_acc.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "02c5d9c00c8146319f3bdbd2efb850e9": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "1c904dd6194b45ad80b5350554be7783": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "26c2b2a6356f4c20828c4aa654b06187": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "371030e0df764b32b41dece67af8774a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_7a097b6b084f4fcdb80162b2eaf0318a",
       "IPY_MODEL_f949ddf7ee5748328238adcb4e6be853",
       "IPY_MODEL_8d1b4a1ce3f44d698c1d7873917091b5"
      ],
      "layout": "IPY_MODEL_1c904dd6194b45ad80b5350554be7783"
     }
    },
    "7a097b6b084f4fcdb80162b2eaf0318a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_fee8eb44bf6a452ab5b83eb74c76af41",
      "placeholder": "​",
      "style": "IPY_MODEL_bf31fedc33724b2f98b8679ddc4f1f03",
      "value": "Downloading model.safetensors: 100%"
     }
    },
    "8d1b4a1ce3f44d698c1d7873917091b5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bb66847397264ee3989babdc4ce0fa19",
      "placeholder": "​",
      "style": "IPY_MODEL_26c2b2a6356f4c20828c4aa654b06187",
      "value": " 440M/440M [00:02&lt;00:00, 165MB/s]"
     }
    },
    "bb66847397264ee3989babdc4ce0fa19": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bbd49f0e4a91459fa34d6d5cef8771df": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bf31fedc33724b2f98b8679ddc4f1f03": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "f949ddf7ee5748328238adcb4e6be853": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bbd49f0e4a91459fa34d6d5cef8771df",
      "max": 440449768,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_02c5d9c00c8146319f3bdbd2efb850e9",
      "value": 440449768
     }
    },
    "fee8eb44bf6a452ab5b83eb74c76af41": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
