{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b2479ab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle5\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-a6f8dd9b-6af8-5e75-8654-84fb2b7b8f6d\"\n",
    "\n",
    "from model.ours2 import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bce7f995",
   "metadata": {},
   "outputs": [],
   "source": [
    "options = {}\n",
    "\n",
    "#basic setting\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/nell/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/nell/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "options['model_dir'] = './outputs_nell995-best/'\n",
    "options['output_dir'] = './outputs_nell995-best/'\n",
    "\n",
    "#agent setting\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 1\n",
    "options['train_entity_embeddings'] = 0\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 40\n",
    "options['gnn_layer'] = 1\n",
    "\n",
    "#hyperparameters\n",
    "options['test_rollouts'] = 40\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 64\n",
    "options['eval_batch_size'] = 32\n",
    "options['beta'] = 0.12\n",
    "options['Lambda'] = 0.12\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 0.0001\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "options['total_iterations'] = 2000*(64/options['batch_size'])\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9e8e25fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/experiments/model/ours2.py:333: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2764, rewards: 0.1364\n",
      "Iteration: 20, Train loss: -0.2215, rewards: 0.2054\n",
      "Iteration: 30, Train loss: -0.2444, rewards: 0.2976\n",
      "Iteration: 40, Train loss: -0.2671, rewards: 0.3142\n",
      "Iteration: 50, Train loss: -0.2551, rewards: 0.3362\n",
      "Iteration: 60, Train loss: -0.2231, rewards: 0.3626\n",
      "Iteration: 70, Train loss: -0.2356, rewards: 0.3973\n",
      "Iteration: 80, Train loss: -0.1697, rewards: 0.4205\n",
      "Iteration: 90, Train loss: -0.1550, rewards: 0.4505\n",
      "Iteration: 100, Train loss: -0.2383, rewards: 0.4636\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/experiments/model/ours2.py:635: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.5562, Hits@3: 0.6243, Hits@10: 0.6611, MRR: 0.5962\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.2203, rewards: 0.4409\n",
      "Iteration: 120, Train loss: -0.1875, rewards: 0.4639\n",
      "Iteration: 130, Train loss: -0.1865, rewards: 0.4543\n",
      "Iteration: 140, Train loss: -0.1564, rewards: 0.4845\n",
      "Iteration: 150, Train loss: -0.2154, rewards: 0.4780\n",
      "Iteration: 160, Train loss: -0.1591, rewards: 0.5198\n",
      "Iteration: 170, Train loss: -0.1308, rewards: 0.4790\n",
      "Iteration: 180, Train loss: -0.1768, rewards: 0.4822\n",
      "Iteration: 190, Train loss: -0.1334, rewards: 0.5132\n",
      "Iteration: 200, Train loss: -0.1712, rewards: 0.5198\n",
      "Eval:\n",
      "Hits@1: 0.5709, Hits@3: 0.6390, Hits@10: 0.6611, MRR: 0.6069\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.1252, rewards: 0.5065\n",
      "Iteration: 220, Train loss: -0.1841, rewards: 0.5040\n",
      "Iteration: 230, Train loss: -0.1933, rewards: 0.4885\n",
      "Iteration: 240, Train loss: -0.1848, rewards: 0.4985\n",
      "Iteration: 250, Train loss: -0.0968, rewards: 0.4992\n",
      "Iteration: 260, Train loss: -0.1103, rewards: 0.5388\n",
      "Iteration: 270, Train loss: -0.1201, rewards: 0.5025\n",
      "Iteration: 280, Train loss: -0.1307, rewards: 0.5563\n",
      "Iteration: 290, Train loss: -0.1271, rewards: 0.5130\n",
      "Iteration: 300, Train loss: -0.0963, rewards: 0.5365\n",
      "Eval:\n",
      "Hits@1: 0.5930, Hits@3: 0.6464, Hits@10: 0.6630, MRR: 0.6205\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.1164, rewards: 0.5116\n",
      "Iteration: 320, Train loss: -0.1512, rewards: 0.5150\n",
      "Iteration: 330, Train loss: -0.1664, rewards: 0.5114\n",
      "Iteration: 340, Train loss: -0.1065, rewards: 0.5212\n",
      "Iteration: 350, Train loss: -0.1235, rewards: 0.5377\n",
      "Iteration: 360, Train loss: -0.1073, rewards: 0.4952\n",
      "Iteration: 370, Train loss: -0.0997, rewards: 0.5110\n",
      "Iteration: 380, Train loss: -0.0715, rewards: 0.5555\n",
      "Iteration: 390, Train loss: -0.1074, rewards: 0.5135\n",
      "Iteration: 400, Train loss: -0.1313, rewards: 0.5257\n",
      "Eval:\n",
      "Hits@1: 0.5875, Hits@3: 0.6501, Hits@10: 0.6685, MRR: 0.6200\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.1311, rewards: 0.5645\n",
      "Iteration: 420, Train loss: -0.0899, rewards: 0.5175\n",
      "Iteration: 430, Train loss: -0.0893, rewards: 0.5238\n",
      "Iteration: 440, Train loss: -0.0880, rewards: 0.5354\n",
      "Iteration: 450, Train loss: -0.1455, rewards: 0.5026\n",
      "Iteration: 460, Train loss: -0.1577, rewards: 0.5472\n",
      "Iteration: 470, Train loss: -0.1513, rewards: 0.5895\n",
      "Iteration: 480, Train loss: -0.0791, rewards: 0.5469\n",
      "Iteration: 490, Train loss: -0.0230, rewards: 0.5762\n",
      "Iteration: 500, Train loss: -0.0692, rewards: 0.5154\n",
      "Eval:\n",
      "Hits@1: 0.5764, Hits@3: 0.6317, Hits@10: 0.6648, MRR: 0.6085\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.1235, rewards: 0.5396\n",
      "Iteration: 520, Train loss: -0.1071, rewards: 0.5115\n",
      "Iteration: 530, Train loss: -0.1219, rewards: 0.5663\n",
      "Iteration: 540, Train loss: -0.0353, rewards: 0.5242\n",
      "Iteration: 550, Train loss: -0.1085, rewards: 0.5073\n",
      "Iteration: 560, Train loss: -0.1188, rewards: 0.5656\n",
      "Iteration: 570, Train loss: -0.0644, rewards: 0.5323\n",
      "Iteration: 580, Train loss: -0.0709, rewards: 0.5323\n",
      "Iteration: 590, Train loss: -0.1689, rewards: 0.5566\n",
      "Iteration: 600, Train loss: -0.1351, rewards: 0.5481\n",
      "Eval:\n",
      "Hits@1: 0.5783, Hits@3: 0.6427, Hits@10: 0.6648, MRR: 0.6113\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.0665, rewards: 0.5618\n",
      "Iteration: 620, Train loss: -0.1589, rewards: 0.5559\n",
      "Iteration: 630, Train loss: -0.1248, rewards: 0.5750\n",
      "Iteration: 640, Train loss: -0.1355, rewards: 0.5187\n",
      "Iteration: 650, Train loss: -0.0768, rewards: 0.5377\n",
      "Iteration: 660, Train loss: -0.0934, rewards: 0.5663\n",
      "Iteration: 670, Train loss: -0.1278, rewards: 0.5712\n",
      "Iteration: 680, Train loss: -0.1171, rewards: 0.5641\n",
      "Iteration: 690, Train loss: -0.0773, rewards: 0.5499\n",
      "Iteration: 700, Train loss: -0.0651, rewards: 0.5590\n",
      "Eval:\n",
      "Hits@1: 0.5764, Hits@3: 0.6409, Hits@10: 0.6740, MRR: 0.6123\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.0624, rewards: 0.5820\n",
      "Iteration: 720, Train loss: -0.0544, rewards: 0.5607\n",
      "Iteration: 730, Train loss: -0.1019, rewards: 0.5743\n",
      "Iteration: 740, Train loss: -0.0974, rewards: 0.5744\n",
      "Iteration: 750, Train loss: -0.1423, rewards: 0.5680\n",
      "Iteration: 760, Train loss: -0.0649, rewards: 0.5643\n",
      "Iteration: 770, Train loss: -0.0741, rewards: 0.5409\n",
      "Iteration: 780, Train loss: -0.0929, rewards: 0.5289\n",
      "Iteration: 790, Train loss: -0.1354, rewards: 0.5506\n",
      "Iteration: 800, Train loss: -0.1128, rewards: 0.5531\n",
      "Eval:\n",
      "Hits@1: 0.5801, Hits@3: 0.6372, Hits@10: 0.6703, MRR: 0.6143\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.1018, rewards: 0.5532\n",
      "Iteration: 820, Train loss: -0.0725, rewards: 0.5472\n",
      "Iteration: 830, Train loss: -0.0983, rewards: 0.5892\n",
      "Iteration: 840, Train loss: -0.0485, rewards: 0.5591\n",
      "Iteration: 850, Train loss: -0.1043, rewards: 0.5429\n",
      "Iteration: 860, Train loss: -0.0662, rewards: 0.5587\n",
      "Iteration: 870, Train loss: -0.1600, rewards: 0.5717\n",
      "Iteration: 880, Train loss: -0.0668, rewards: 0.5439\n",
      "Iteration: 890, Train loss: -0.0673, rewards: 0.5704\n",
      "Iteration: 900, Train loss: -0.0670, rewards: 0.5887\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6372, Hits@10: 0.6703, MRR: 0.6172\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.0398, rewards: 0.5568\n",
      "Iteration: 920, Train loss: -0.0660, rewards: 0.5831\n",
      "Iteration: 930, Train loss: -0.0021, rewards: 0.5255\n",
      "Iteration: 940, Train loss: -0.0962, rewards: 0.5752\n",
      "Iteration: 950, Train loss: -0.0834, rewards: 0.5752\n",
      "Iteration: 960, Train loss: -0.0665, rewards: 0.5748\n",
      "Iteration: 970, Train loss: -0.1242, rewards: 0.5852\n",
      "Iteration: 980, Train loss: -0.0475, rewards: 0.5598\n",
      "Iteration: 990, Train loss: -0.0422, rewards: 0.5768\n",
      "Iteration: 1000, Train loss: -0.0571, rewards: 0.6030\n",
      "Eval:\n",
      "Hits@1: 0.5930, Hits@3: 0.6446, Hits@10: 0.6667, MRR: 0.6211\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.0166, rewards: 0.5655\n",
      "Iteration: 1020, Train loss: -0.1058, rewards: 0.5798\n",
      "Iteration: 1030, Train loss: -0.1040, rewards: 0.5798\n",
      "Iteration: 1040, Train loss: -0.0720, rewards: 0.5472\n",
      "Iteration: 1050, Train loss: -0.0611, rewards: 0.5399\n",
      "Iteration: 1060, Train loss: -0.1266, rewards: 0.5841\n",
      "Iteration: 1070, Train loss: -0.0658, rewards: 0.5970\n",
      "Iteration: 1080, Train loss: -0.1176, rewards: 0.5583\n",
      "Iteration: 1090, Train loss: -0.1056, rewards: 0.5701\n",
      "Iteration: 1100, Train loss: -0.0920, rewards: 0.5780\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6427, Hits@10: 0.6740, MRR: 0.6177\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.0382, rewards: 0.5750\n",
      "Iteration: 1120, Train loss: -0.0989, rewards: 0.5974\n",
      "Iteration: 1130, Train loss: -0.0959, rewards: 0.5794\n",
      "Iteration: 1140, Train loss: -0.0727, rewards: 0.5870\n",
      "Iteration: 1150, Train loss: -0.0464, rewards: 0.5689\n",
      "Iteration: 1160, Train loss: -0.1202, rewards: 0.5574\n",
      "Iteration: 1170, Train loss: -0.0214, rewards: 0.5369\n",
      "Iteration: 1180, Train loss: -0.0947, rewards: 0.5285\n",
      "Iteration: 1190, Train loss: -0.0374, rewards: 0.5783\n",
      "Iteration: 1200, Train loss: -0.0812, rewards: 0.5578\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6409, Hits@10: 0.6685, MRR: 0.6192\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.0668, rewards: 0.5551\n",
      "Iteration: 1220, Train loss: -0.0119, rewards: 0.5895\n",
      "Iteration: 1230, Train loss: -0.0668, rewards: 0.5684\n",
      "Iteration: 1240, Train loss: -0.0892, rewards: 0.5795\n",
      "Iteration: 1250, Train loss: -0.0842, rewards: 0.6140\n",
      "Iteration: 1260, Train loss: -0.0484, rewards: 0.5706\n",
      "Iteration: 1270, Train loss: -0.0405, rewards: 0.5874\n",
      "Iteration: 1280, Train loss: -0.0815, rewards: 0.5541\n",
      "Iteration: 1290, Train loss: -0.0456, rewards: 0.6188\n",
      "Iteration: 1300, Train loss: -0.0755, rewards: 0.5648\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6538, Hits@10: 0.6832, MRR: 0.6248\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.0786, rewards: 0.5957\n",
      "Iteration: 1320, Train loss: -0.0923, rewards: 0.5802\n",
      "Iteration: 1330, Train loss: -0.0346, rewards: 0.6050\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1340, Train loss: -0.0830, rewards: 0.5946\n",
      "Iteration: 1350, Train loss: -0.0708, rewards: 0.5880\n",
      "Iteration: 1360, Train loss: -0.0593, rewards: 0.5761\n",
      "Iteration: 1370, Train loss: -0.0455, rewards: 0.6061\n",
      "Iteration: 1380, Train loss: -0.0645, rewards: 0.6023\n",
      "Iteration: 1390, Train loss: -0.0213, rewards: 0.5798\n",
      "Iteration: 1400, Train loss: -0.0518, rewards: 0.5733\n",
      "Eval:\n",
      "Hits@1: 0.5985, Hits@3: 0.6538, Hits@10: 0.6740, MRR: 0.6282\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.1035, rewards: 0.6045\n",
      "Iteration: 1420, Train loss: -0.0666, rewards: 0.5967\n",
      "Iteration: 1430, Train loss: -0.0343, rewards: 0.5609\n",
      "Iteration: 1440, Train loss: -0.0697, rewards: 0.5705\n",
      "Iteration: 1450, Train loss: -0.0086, rewards: 0.5611\n",
      "Iteration: 1460, Train loss: -0.0434, rewards: 0.5889\n",
      "Iteration: 1470, Train loss: -0.0577, rewards: 0.5911\n",
      "Iteration: 1480, Train loss: -0.0772, rewards: 0.6036\n",
      "Iteration: 1490, Train loss: -0.0437, rewards: 0.5709\n",
      "Iteration: 1500, Train loss: -0.0836, rewards: 0.5764\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6464, Hits@10: 0.6759, MRR: 0.6191\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.0757, rewards: 0.5516\n",
      "Iteration: 1520, Train loss: -0.0337, rewards: 0.6183\n",
      "Iteration: 1530, Train loss: -0.0603, rewards: 0.6123\n",
      "Iteration: 1540, Train loss: -0.0870, rewards: 0.5627\n",
      "Iteration: 1550, Train loss: -0.1400, rewards: 0.5665\n",
      "Iteration: 1560, Train loss: -0.0562, rewards: 0.6116\n",
      "Iteration: 1570, Train loss: -0.0615, rewards: 0.5586\n",
      "Iteration: 1580, Train loss: -0.1021, rewards: 0.6068\n",
      "Iteration: 1590, Train loss: -0.0969, rewards: 0.5927\n",
      "Iteration: 1600, Train loss: -0.0571, rewards: 0.5826\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6501, Hits@10: 0.6796, MRR: 0.6244\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.1140, rewards: 0.5945\n",
      "Iteration: 1620, Train loss: -0.0832, rewards: 0.6098\n",
      "Iteration: 1630, Train loss: -0.0341, rewards: 0.5852\n",
      "Iteration: 1640, Train loss: -0.0375, rewards: 0.5813\n",
      "Iteration: 1650, Train loss: -0.0499, rewards: 0.5837\n",
      "Iteration: 1660, Train loss: -0.0007, rewards: 0.5789\n",
      "Iteration: 1670, Train loss: -0.1082, rewards: 0.6162\n",
      "Iteration: 1680, Train loss: -0.0809, rewards: 0.5770\n",
      "Iteration: 1690, Train loss: -0.1092, rewards: 0.5641\n",
      "Iteration: 1700, Train loss: -0.0989, rewards: 0.5890\n",
      "Eval:\n",
      "Hits@1: 0.5967, Hits@3: 0.6446, Hits@10: 0.6777, MRR: 0.6263\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.0672, rewards: 0.5830\n",
      "Iteration: 1720, Train loss: -0.0569, rewards: 0.6310\n",
      "Iteration: 1730, Train loss: -0.0585, rewards: 0.5902\n",
      "Iteration: 1740, Train loss: -0.0547, rewards: 0.5984\n",
      "Iteration: 1750, Train loss: -0.0793, rewards: 0.6044\n",
      "Iteration: 1760, Train loss: -0.1102, rewards: 0.6111\n",
      "Iteration: 1770, Train loss: -0.0995, rewards: 0.5834\n",
      "Iteration: 1780, Train loss: -0.0662, rewards: 0.6096\n",
      "Iteration: 1790, Train loss: -0.0470, rewards: 0.6070\n",
      "Iteration: 1800, Train loss: -0.0634, rewards: 0.6112\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6464, Hits@10: 0.6796, MRR: 0.6234\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.0547, rewards: 0.5688\n",
      "Iteration: 1820, Train loss: -0.1214, rewards: 0.5945\n",
      "Iteration: 1830, Train loss: -0.1263, rewards: 0.5963\n",
      "Iteration: 1840, Train loss: -0.0059, rewards: 0.6102\n",
      "Iteration: 1850, Train loss: -0.1050, rewards: 0.5972\n",
      "Iteration: 1860, Train loss: -0.0679, rewards: 0.5809\n",
      "Iteration: 1870, Train loss: -0.0732, rewards: 0.5908\n",
      "Iteration: 1880, Train loss: -0.0956, rewards: 0.5921\n",
      "Iteration: 1890, Train loss: -0.1009, rewards: 0.5900\n",
      "Iteration: 1900, Train loss: -0.0926, rewards: 0.5989\n",
      "Eval:\n",
      "Hits@1: 0.6022, Hits@3: 0.6519, Hits@10: 0.6759, MRR: 0.6302\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.0928, rewards: 0.5950\n",
      "Iteration: 1920, Train loss: -0.0325, rewards: 0.5725\n",
      "Iteration: 1930, Train loss: -0.0479, rewards: 0.6142\n",
      "Iteration: 1940, Train loss: -0.1033, rewards: 0.5973\n",
      "Iteration: 1950, Train loss: 0.0187, rewards: 0.5909\n",
      "Iteration: 1960, Train loss: -0.0699, rewards: 0.5989\n",
      "Iteration: 1970, Train loss: -0.0369, rewards: 0.6379\n",
      "Iteration: 1980, Train loss: -0.0569, rewards: 0.6282\n",
      "Iteration: 1990, Train loss: -0.0908, rewards: 0.5949\n",
      "Iteration: 2000, Train loss: -0.0451, rewards: 0.5894\n",
      "Eval:\n",
      "Hits@1: 0.5930, Hits@3: 0.6501, Hits@10: 0.6796, MRR: 0.6244\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(options)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e48d79b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.6906, Hits@3: 0.7867, Hits@10: 0.8304, MRR: 0.7432\n"
     ]
    }
   ],
   "source": [
    "trainer.agent.load_state_dict(torch.load(options['model_dir'] + 'agent.ckpt'))\n",
    "trainer.agent.eval()\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "test_results = trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a8c25925",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.6906, Hits@3: 0.7867, Hits@10: 0.8304, MRR: 0.7432\n"
     ]
    }
   ],
   "source": [
    "print(test_results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
