{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "module_path\n",
    "sys.path.append(module_path)\n",
    "\n",
    "from utils import load_word_embedding, Vocab_Dataset, predict_mse_with_Attn\n",
    "from ournet_config import Config\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processing dataset - vocab_category_u : vocab_1\n",
      "processing dataset - vocab_category_i : vocab_1\n"
     ]
    }
   ],
   "source": [
    "config = Config()\n",
    "word_emb, word_dict = load_word_embedding('../data/preprocessing/glove.6B.50d.txt')\n",
    "\n",
    "test_dataset =  Vocab_Dataset('../prompt/dataset_vocab/roberta_AM_test.pkl', word_dict, config)\n",
    "test_dlr = DataLoader(test_dataset, batch_size=463)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### inference for visualizing Attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ournet_AM_example\n",
    "_model = torch.load('../model/model_save/Ournet_AM_example.pt', map_location=torch.device('cpu'))\n",
    "model = _model.to(torch.device('cpu'))\n",
    "\n",
    "def test(dataloader, model):\n",
    "\n",
    "    uid, iid, ratings, predict, score_u, score_i, first_u_gattn, first_i_gattn = predict_mse_with_Attn(model, dataloader)\n",
    "    return uid, iid, ratings, predict, score_u, score_i, first_u_gattn, first_i_gattn\n",
    "\n",
    "li = test(test_dlr, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_pickle('../prompt/dataset_vocab/roberta_AM_test.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "      <th>review</th>\n",
       "      <th>rating</th>\n",
       "      <th>user_reviews_concat</th>\n",
       "      <th>item_reviews_concat</th>\n",
       "      <th>user_vocab_1</th>\n",
       "      <th>user_vocab_2</th>\n",
       "      <th>item_vocab_1</th>\n",
       "      <th>item_vocab_2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2726</td>\n",
       "      <td>1686</td>\n",
       "      <td>i just put on the wipers and after two weeks a...</td>\n",
       "      <td>3</td>\n",
       "      <td>i've had several tire pressure gauges over the...</td>\n",
       "      <td>these are the best they say they have lots of ...</td>\n",
       "      <td>[price, quality, warranty, cost, reader, money...</td>\n",
       "      <td>[quality, convenience, best, product, warranty...</td>\n",
       "      <td>[price, quality, design, size, description, co...</td>\n",
       "      <td>[simplicity, convenience, warranty, safety, pr...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2727</td>\n",
       "      <td>782</td>\n",
       "      <td>of course these were purchased to top the lynx...</td>\n",
       "      <td>5</td>\n",
       "      <td>our travel trailer is small two adults and 2 c...</td>\n",
       "      <td>this is a good durable product for you travel ...</td>\n",
       "      <td>[price, size, cost, liner, warranty, design, q...</td>\n",
       "      <td>[product, best, experience, environment, conve...</td>\n",
       "      <td>[durability, quality, price, size, safety, wei...</td>\n",
       "      <td>[durability, versatility, safety, trailer, sta...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>100</td>\n",
       "      <td>990</td>\n",
       "      <td>the hitch pin was received as indicated and no...</td>\n",
       "      <td>5</td>\n",
       "      <td>the plastic and finish is cheap i got adhesive...</td>\n",
       "      <td>i see in some of the reviews that a few people...</td>\n",
       "      <td>[price, quality, packaging, finish, warranty, ...</td>\n",
       "      <td>[quality, product, best, experience, price, co...</td>\n",
       "      <td>[quality, price, description, size, review, co...</td>\n",
       "      <td>[durability, price, warranty, beauty, simplici...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1768</td>\n",
       "      <td>1277</td>\n",
       "      <td>works great has plenty of power to fill my tir...</td>\n",
       "      <td>5</td>\n",
       "      <td>this pan is big and it catches all the drips s...</td>\n",
       "      <td>a lot of people are saying that this unit blow...</td>\n",
       "      <td>[size, picture, construction, lid, cleanup, sm...</td>\n",
       "      <td>[product, environment, taste, convenience, qua...</td>\n",
       "      <td>[price, quality, size, description, purchase, ...</td>\n",
       "      <td>[convenience, battery, simplicity, price, safe...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1239</td>\n",
       "      <td>916</td>\n",
       "      <td>these grab handles fit the bill perfectly i mo...</td>\n",
       "      <td>5</td>\n",
       "      <td>of course you have the kn name honestly it's n...</td>\n",
       "      <td>snap to install and seem sturdy as well well w...</td>\n",
       "      <td>[price, durability, texture, name, simplicity,...</td>\n",
       "      <td>[product, experience, environment, quality, be...</td>\n",
       "      <td>[ability, quality, size, price, installation, ...</td>\n",
       "      <td>[simplicity, durability, installation, warrant...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userID  itemID                                             review  rating  \\\n",
       "0    2726    1686  i just put on the wipers and after two weeks a...       3   \n",
       "1    2727     782  of course these were purchased to top the lynx...       5   \n",
       "2     100     990  the hitch pin was received as indicated and no...       5   \n",
       "3    1768    1277  works great has plenty of power to fill my tir...       5   \n",
       "4    1239     916  these grab handles fit the bill perfectly i mo...       5   \n",
       "\n",
       "                                 user_reviews_concat  \\\n",
       "0  i've had several tire pressure gauges over the...   \n",
       "1  our travel trailer is small two adults and 2 c...   \n",
       "2  the plastic and finish is cheap i got adhesive...   \n",
       "3  this pan is big and it catches all the drips s...   \n",
       "4  of course you have the kn name honestly it's n...   \n",
       "\n",
       "                                 item_reviews_concat  \\\n",
       "0  these are the best they say they have lots of ...   \n",
       "1  this is a good durable product for you travel ...   \n",
       "2  i see in some of the reviews that a few people...   \n",
       "3  a lot of people are saying that this unit blow...   \n",
       "4  snap to install and seem sturdy as well well w...   \n",
       "\n",
       "                                        user_vocab_1  \\\n",
       "0  [price, quality, warranty, cost, reader, money...   \n",
       "1  [price, size, cost, liner, warranty, design, q...   \n",
       "2  [price, quality, packaging, finish, warranty, ...   \n",
       "3  [size, picture, construction, lid, cleanup, sm...   \n",
       "4  [price, durability, texture, name, simplicity,...   \n",
       "\n",
       "                                        user_vocab_2  \\\n",
       "0  [quality, convenience, best, product, warranty...   \n",
       "1  [product, best, experience, environment, conve...   \n",
       "2  [quality, product, best, experience, price, co...   \n",
       "3  [product, environment, taste, convenience, qua...   \n",
       "4  [product, experience, environment, quality, be...   \n",
       "\n",
       "                                        item_vocab_1  \\\n",
       "0  [price, quality, design, size, description, co...   \n",
       "1  [durability, quality, price, size, safety, wei...   \n",
       "2  [quality, price, description, size, review, co...   \n",
       "3  [price, quality, size, description, purchase, ...   \n",
       "4  [ability, quality, size, price, installation, ...   \n",
       "\n",
       "                                        item_vocab_2  \n",
       "0  [simplicity, convenience, warranty, safety, pr...  \n",
       "1  [durability, versatility, safety, trailer, sta...  \n",
       "2  [durability, price, warranty, beauty, simplici...  \n",
       "3  [convenience, battery, simplicity, price, safe...  \n",
       "4  [simplicity, durability, installation, warrant...  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_view = pd.concat([df,pd.DataFrame(li[3].squeeze())], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index([             'userID',              'itemID',              'review',\n",
       "                    'rating', 'user_reviews_concat', 'item_reviews_concat',\n",
       "              'user_vocab_1',        'user_vocab_2',        'item_vocab_1',\n",
       "              'item_vocab_2',                     0],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_view.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_colwidth', 500)\n",
    "df_view =df_view[['userID', 'itemID', 'user_vocab_1', 'item_vocab_1', 'rating', 0]]\n",
    "df_view = df_view.rename(columns={0:'predicting'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_view['user_vocab_1'] = df_view['user_vocab_1'].apply(lambda x : x[:10])\n",
    "df_view['item_vocab_1'] = df_view['item_vocab_1'].apply(lambda x : x[:10])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userID</th>\n",
       "      <th>itemID</th>\n",
       "      <th>user_vocab_1</th>\n",
       "      <th>item_vocab_1</th>\n",
       "      <th>rating</th>\n",
       "      <th>predicting</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2727</td>\n",
       "      <td>782</td>\n",
       "      <td>[price, size, cost, liner, warranty, design, quality, lining, simplicity, detail]</td>\n",
       "      <td>[durability, quality, price, size, safety, weight, versatility, value, longevity, condition]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.867499</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1768</td>\n",
       "      <td>1277</td>\n",
       "      <td>[size, picture, construction, lid, cleanup, smell, warranty, recipe, design, storage]</td>\n",
       "      <td>[price, quality, size, description, purchase, cost, pricing, safety, rating, ability]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.849397</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1239</td>\n",
       "      <td>916</td>\n",
       "      <td>[price, durability, texture, name, simplicity, detail, packaging, design, thickness, warranty]</td>\n",
       "      <td>[ability, quality, size, price, installation, ing, ness, durability, design, cost]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.915260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>976</td>\n",
       "      <td>17</td>\n",
       "      <td>[price, packaging, simplicity, quality, link, documentation, warranty, design, cost, following]</td>\n",
       "      <td>[price, size, safety, quality, age, weight, condition, cost, durability, value]</td>\n",
       "      <td>4</td>\n",
       "      <td>4.864094</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>525</td>\n",
       "      <td>1467</td>\n",
       "      <td>[texture, detail, process, consistency, smell, tip, color, simplicity, finish, picture]</td>\n",
       "      <td>[price, quality, description, review, use, importance, purchase, safety, cost, taste]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.925333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>446</th>\n",
       "      <td>2477</td>\n",
       "      <td>173</td>\n",
       "      <td>[price, packaging, pricing, quality, cost, rating, simplicity, size, warranty, volume]</td>\n",
       "      <td>[quality, price, taste, cost, size, value, purchase, pricing, name, ness]</td>\n",
       "      <td>5</td>\n",
       "      <td>5.372795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>449</th>\n",
       "      <td>1407</td>\n",
       "      <td>551</td>\n",
       "      <td>[price, safety, length, size, weight, warranty, design, quality, cost, durability]</td>\n",
       "      <td>[size, quality, versatility, price, description, ability, safety, simplicity, value, purchase]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.935380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>452</th>\n",
       "      <td>2362</td>\n",
       "      <td>1542</td>\n",
       "      <td>[price, balance, size, speed, handle, weight, design, durability, grip, following]</td>\n",
       "      <td>[price, quality, description, value, purchase, review, ability, size, importance, pricing]</td>\n",
       "      <td>5</td>\n",
       "      <td>5.093006</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>455</th>\n",
       "      <td>974</td>\n",
       "      <td>932</td>\n",
       "      <td>[price, cost, info, simplicity, pricing, money, link, information, quality, packaging]</td>\n",
       "      <td>[quality, size, weight, price, safety, use, durability, handling, cost, design]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.854034</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>459</th>\n",
       "      <td>536</td>\n",
       "      <td>977</td>\n",
       "      <td>[price, simplicity, packaging, design, color, branding, picture, camouflage, name, styling]</td>\n",
       "      <td>[quality, price, condition, cost, size, value, safety, purchase, use, weight]</td>\n",
       "      <td>5</td>\n",
       "      <td>4.825294</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>121 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     userID  itemID  \\\n",
       "1      2727     782   \n",
       "3      1768    1277   \n",
       "4      1239     916   \n",
       "21      976      17   \n",
       "24      525    1467   \n",
       "..      ...     ...   \n",
       "446    2477     173   \n",
       "449    1407     551   \n",
       "452    2362    1542   \n",
       "455     974     932   \n",
       "459     536     977   \n",
       "\n",
       "                                                                                        user_vocab_1  \\\n",
       "1                  [price, size, cost, liner, warranty, design, quality, lining, simplicity, detail]   \n",
       "3              [size, picture, construction, lid, cleanup, smell, warranty, recipe, design, storage]   \n",
       "4     [price, durability, texture, name, simplicity, detail, packaging, design, thickness, warranty]   \n",
       "21   [price, packaging, simplicity, quality, link, documentation, warranty, design, cost, following]   \n",
       "24           [texture, detail, process, consistency, smell, tip, color, simplicity, finish, picture]   \n",
       "..                                                                                               ...   \n",
       "446           [price, packaging, pricing, quality, cost, rating, simplicity, size, warranty, volume]   \n",
       "449               [price, safety, length, size, weight, warranty, design, quality, cost, durability]   \n",
       "452               [price, balance, size, speed, handle, weight, design, durability, grip, following]   \n",
       "455           [price, cost, info, simplicity, pricing, money, link, information, quality, packaging]   \n",
       "459      [price, simplicity, packaging, design, color, branding, picture, camouflage, name, styling]   \n",
       "\n",
       "                                                                                       item_vocab_1  \\\n",
       "1      [durability, quality, price, size, safety, weight, versatility, value, longevity, condition]   \n",
       "3             [price, quality, size, description, purchase, cost, pricing, safety, rating, ability]   \n",
       "4                [ability, quality, size, price, installation, ing, ness, durability, design, cost]   \n",
       "21                  [price, size, safety, quality, age, weight, condition, cost, durability, value]   \n",
       "24            [price, quality, description, review, use, importance, purchase, safety, cost, taste]   \n",
       "..                                                                                              ...   \n",
       "446                       [quality, price, taste, cost, size, value, purchase, pricing, name, ness]   \n",
       "449  [size, quality, versatility, price, description, ability, safety, simplicity, value, purchase]   \n",
       "452      [price, quality, description, value, purchase, review, ability, size, importance, pricing]   \n",
       "455                 [quality, size, weight, price, safety, use, durability, handling, cost, design]   \n",
       "459                   [quality, price, condition, cost, size, value, safety, purchase, use, weight]   \n",
       "\n",
       "     rating  predicting  \n",
       "1         5    4.867499  \n",
       "3         5    4.849397  \n",
       "4         5    4.915260  \n",
       "21        4    4.864094  \n",
       "24        5    4.925333  \n",
       "..      ...         ...  \n",
       "446       5    5.372795  \n",
       "449       5    4.935380  \n",
       "452       5    5.093006  \n",
       "455       5    4.854034  \n",
       "459       5    4.825294  \n",
       "\n",
       "[121 rows x 6 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_view[df_view['predicting'] > 4.8]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0077, 0.0295, 0.0040, 0.0042, 0.0029, 0.0186, 0.0267, 0.0031, 0.0064,\n",
      "        0.0036])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0997, 0.1019, 0.0993, 0.0994, 0.0992, 0.1008, 0.1016, 0.0992, 0.0996,\n",
       "        0.0993])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print((li[4].squeeze() * li[6])[1])\n",
    "torch.softmax(li[4].squeeze() * li[6], dim=1)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0097, 0.0281, 0.0061, 0.0291, 0.0116, 0.0094, 0.0032, 0.0026, 0.0026,\n",
      "        0.0053])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0999, 0.1017, 0.0995, 0.1018, 0.1001, 0.0999, 0.0992, 0.0992, 0.0992,\n",
       "        0.0995])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print((li[5].squeeze() * li[7])[1])\n",
    "torch.softmax(li[5].squeeze() * li[7], dim=1)[1]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
