{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce63a777",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library. json is imported explicitly because the configuration\n",
    "# cell calls json.load and should not rely on the wildcard import for it.\n",
    "import json\n",
    "import os\n",
    "\n",
    "import pickle5\n",
    "\n",
    "# Restrict CUDA to a single MIG slice. The variable is assigned before\n",
    "# model.ours2 is imported, so it is already set when the model code loads.\n",
    "# NOTE(review): this UUID is machine-specific; replace it when running on\n",
    "# different hardware.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-12e7197d-4e01-5bc8-aa76-2be6e3a55125\"\n",
    "\n",
    "# NOTE(review): the wildcard import hides which names the notebook takes from\n",
    "# model.ours2; prefer explicit imports once the required names are known.\n",
    "from model.ours2 import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "398c7676",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_vocab(path):\n",
    "    \"\"\"Parse one JSON vocabulary file and close the handle once it is read.\"\"\"\n",
    "    with open(path) as vocab_file:\n",
    "        return json.load(vocab_file)\n",
    "\n",
    "\n",
    "# Run configuration for the WN18RR experiment, collected in a single dict.\n",
    "options = {}\n",
    "\n",
    "# basic setting\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "# os.path.join copes with vocab_dir being written with or without a trailing slash.\n",
    "options['relation_vocab'] = _load_vocab(os.path.join(options['vocab_dir'], 'relation_vocab.json'))\n",
    "options['entity_vocab'] = _load_vocab(os.path.join(options['vocab_dir'], 'entity_vocab.json'))\n",
    "# Model and output directories deliberately point at the same folder.\n",
    "options['model_dir'] = './outputs_WN18RR-best/'\n",
    "options['output_dir'] = './outputs_WN18RR-best/'\n",
    "\n",
    "# agent setting\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 1\n",
    "options['train_entity_embeddings'] = 0\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 40\n",
    "options['gnn_layer'] = 1\n",
    "\n",
    "# hyperparameters\n",
    "options['test_rollouts'] = 40\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 32\n",
    "options['eval_batch_size'] = 32\n",
    "options['beta'] = 0.05\n",
    "options['Lambda'] = 0.05\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 0.0001\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "# Scale the iteration budget so iterations * batch_size stays at 2000 * 64\n",
    "# whatever the batch size; int() because this is a count, not a float.\n",
    "options['total_iterations'] = int(2000 * 64 / options['batch_size'])\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d9b9fda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/experiments/model/ours2.py:333: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.1781, rewards: 0.1203\n",
      "Iteration: 20, Train loss: -0.2795, rewards: 0.2356\n",
      "Iteration: 30, Train loss: -0.2227, rewards: 0.2859\n",
      "Iteration: 40, Train loss: -0.3522, rewards: 0.3130\n",
      "Iteration: 50, Train loss: -0.3698, rewards: 0.2830\n",
      "Iteration: 60, Train loss: -0.4164, rewards: 0.3203\n",
      "Iteration: 70, Train loss: -0.4718, rewards: 0.3759\n",
      "Iteration: 80, Train loss: -0.3527, rewards: 0.3353\n",
      "Iteration: 90, Train loss: -0.3036, rewards: 0.3575\n",
      "Iteration: 100, Train loss: -0.3045, rewards: 0.3642\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/experiments/model/ours2.py:635: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.3810, Hits@3: 0.4403, Hits@10: 0.4908, MRR: 0.4175\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.3253, rewards: 0.3231\n",
      "Iteration: 120, Train loss: -0.3451, rewards: 0.4078\n",
      "Iteration: 130, Train loss: -0.3685, rewards: 0.3817\n",
      "Iteration: 140, Train loss: -0.3671, rewards: 0.3917\n",
      "Iteration: 150, Train loss: -0.3156, rewards: 0.3631\n",
      "Iteration: 160, Train loss: -0.4034, rewards: 0.3825\n",
      "Iteration: 170, Train loss: -0.4553, rewards: 0.3686\n",
      "Iteration: 180, Train loss: -0.4574, rewards: 0.3944\n",
      "Iteration: 190, Train loss: -0.4464, rewards: 0.3477\n",
      "Iteration: 200, Train loss: -0.4058, rewards: 0.3900\n",
      "Eval:\n",
      "Hits@1: 0.3820, Hits@3: 0.4512, Hits@10: 0.5040, MRR: 0.4241\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.4062, rewards: 0.3881\n",
      "Iteration: 220, Train loss: -0.3546, rewards: 0.3875\n",
      "Iteration: 230, Train loss: -0.3925, rewards: 0.4345\n",
      "Iteration: 240, Train loss: -0.3580, rewards: 0.3872\n",
      "Iteration: 250, Train loss: -0.3991, rewards: 0.3792\n",
      "Iteration: 260, Train loss: -0.4225, rewards: 0.3613\n",
      "Iteration: 270, Train loss: -0.4372, rewards: 0.3683\n",
      "Iteration: 280, Train loss: -0.4344, rewards: 0.3611\n",
      "Iteration: 290, Train loss: -0.4836, rewards: 0.3631\n",
      "Iteration: 300, Train loss: -0.4783, rewards: 0.3852\n",
      "Eval:\n",
      "Hits@1: 0.3958, Hits@3: 0.4529, Hits@10: 0.5053, MRR: 0.4324\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.4495, rewards: 0.4205\n",
      "Iteration: 320, Train loss: -0.4311, rewards: 0.4103\n",
      "Iteration: 330, Train loss: -0.4533, rewards: 0.4406\n",
      "Iteration: 340, Train loss: -0.3737, rewards: 0.4081\n",
      "Iteration: 350, Train loss: -0.4074, rewards: 0.4498\n",
      "Iteration: 360, Train loss: -0.4594, rewards: 0.3583\n",
      "Iteration: 370, Train loss: -0.5129, rewards: 0.3837\n",
      "Iteration: 380, Train loss: -0.4620, rewards: 0.3956\n",
      "Iteration: 390, Train loss: -0.4589, rewards: 0.3670\n",
      "Iteration: 400, Train loss: -0.4118, rewards: 0.4288\n",
      "Eval:\n",
      "Hits@1: 0.3771, Hits@3: 0.4502, Hits@10: 0.5030, MRR: 0.4220\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.4640, rewards: 0.3895\n",
      "Iteration: 420, Train loss: -0.4741, rewards: 0.4342\n",
      "Iteration: 430, Train loss: -0.4631, rewards: 0.3602\n",
      "Iteration: 440, Train loss: -0.4746, rewards: 0.4200\n",
      "Iteration: 450, Train loss: -0.4012, rewards: 0.4408\n",
      "Iteration: 460, Train loss: -0.4620, rewards: 0.3866\n",
      "Iteration: 470, Train loss: -0.4575, rewards: 0.3927\n",
      "Iteration: 480, Train loss: -0.4301, rewards: 0.3812\n",
      "Iteration: 490, Train loss: -0.4529, rewards: 0.4330\n",
      "Iteration: 500, Train loss: -0.4695, rewards: 0.4119\n",
      "Eval:\n",
      "Hits@1: 0.3958, Hits@3: 0.4641, Hits@10: 0.5138, MRR: 0.4361\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.4748, rewards: 0.4011\n",
      "Iteration: 520, Train loss: -0.4405, rewards: 0.4800\n",
      "Iteration: 530, Train loss: -0.4473, rewards: 0.3905\n",
      "Iteration: 540, Train loss: -0.4668, rewards: 0.4650\n",
      "Iteration: 550, Train loss: -0.4527, rewards: 0.3967\n",
      "Iteration: 560, Train loss: -0.4602, rewards: 0.4547\n",
      "Iteration: 570, Train loss: -0.4872, rewards: 0.4261\n",
      "Iteration: 580, Train loss: -0.4904, rewards: 0.4159\n",
      "Iteration: 590, Train loss: -0.3879, rewards: 0.3948\n",
      "Iteration: 600, Train loss: -0.3624, rewards: 0.4502\n",
      "Eval:\n",
      "Hits@1: 0.4047, Hits@3: 0.4618, Hits@10: 0.5043, MRR: 0.4389\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.3698, rewards: 0.3564\n",
      "Iteration: 620, Train loss: -0.3997, rewards: 0.4261\n",
      "Iteration: 630, Train loss: -0.4478, rewards: 0.4047\n",
      "Iteration: 640, Train loss: -0.4194, rewards: 0.3681\n",
      "Iteration: 650, Train loss: -0.4171, rewards: 0.4053\n",
      "Iteration: 660, Train loss: -0.4844, rewards: 0.4019\n",
      "Iteration: 670, Train loss: -0.3845, rewards: 0.4420\n",
      "Iteration: 680, Train loss: -0.3961, rewards: 0.3672\n",
      "Iteration: 690, Train loss: -0.3845, rewards: 0.4247\n",
      "Iteration: 700, Train loss: -0.4456, rewards: 0.3747\n",
      "Eval:\n",
      "Hits@1: 0.3912, Hits@3: 0.4562, Hits@10: 0.5049, MRR: 0.4301\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.3797, rewards: 0.3912\n",
      "Iteration: 720, Train loss: -0.4570, rewards: 0.4553\n",
      "Iteration: 730, Train loss: -0.4391, rewards: 0.4642\n",
      "Iteration: 740, Train loss: -0.3804, rewards: 0.4184\n",
      "Iteration: 750, Train loss: -0.3997, rewards: 0.4030\n",
      "Iteration: 760, Train loss: -0.4369, rewards: 0.3916\n",
      "Iteration: 770, Train loss: -0.4038, rewards: 0.4072\n",
      "Iteration: 780, Train loss: -0.4424, rewards: 0.3853\n",
      "Iteration: 790, Train loss: -0.4819, rewards: 0.4625\n",
      "Iteration: 800, Train loss: -0.4505, rewards: 0.3625\n",
      "Eval:\n",
      "Hits@1: 0.3929, Hits@3: 0.4674, Hits@10: 0.5056, MRR: 0.4340\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.4578, rewards: 0.4047\n",
      "Iteration: 820, Train loss: -0.4781, rewards: 0.3847\n",
      "Iteration: 830, Train loss: -0.4315, rewards: 0.4102\n",
      "Iteration: 840, Train loss: -0.4593, rewards: 0.3905\n",
      "Iteration: 850, Train loss: -0.4478, rewards: 0.4358\n",
      "Iteration: 860, Train loss: -0.4196, rewards: 0.4050\n",
      "Iteration: 870, Train loss: -0.4745, rewards: 0.4756\n",
      "Iteration: 880, Train loss: -0.4440, rewards: 0.4081\n",
      "Iteration: 890, Train loss: -0.4516, rewards: 0.3837\n",
      "Iteration: 900, Train loss: -0.3982, rewards: 0.4233\n",
      "Eval:\n",
      "Hits@1: 0.4249, Hits@3: 0.4674, Hits@10: 0.5073, MRR: 0.4522\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.4237, rewards: 0.4181\n",
      "Iteration: 920, Train loss: -0.4892, rewards: 0.4036\n",
      "Iteration: 930, Train loss: -0.4432, rewards: 0.4119\n",
      "Iteration: 940, Train loss: -0.4404, rewards: 0.3839\n",
      "Iteration: 950, Train loss: -0.4457, rewards: 0.4325\n",
      "Iteration: 960, Train loss: -0.3960, rewards: 0.3836\n",
      "Iteration: 970, Train loss: -0.4234, rewards: 0.3881\n",
      "Iteration: 980, Train loss: -0.4679, rewards: 0.4203\n",
      "Iteration: 990, Train loss: -0.3785, rewards: 0.3675\n",
      "Iteration: 1000, Train loss: -0.3146, rewards: 0.4172\n",
      "Eval:\n",
      "Hits@1: 0.4242, Hits@3: 0.4703, Hits@10: 0.5122, MRR: 0.4529\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.3585, rewards: 0.3642\n",
      "Iteration: 1020, Train loss: -0.3827, rewards: 0.3856\n",
      "Iteration: 1030, Train loss: -0.3961, rewards: 0.4241\n",
      "Iteration: 1040, Train loss: -0.4174, rewards: 0.4148\n",
      "Iteration: 1050, Train loss: -0.3203, rewards: 0.3911\n",
      "Iteration: 1060, Train loss: -0.3916, rewards: 0.4373\n",
      "Iteration: 1070, Train loss: -0.3598, rewards: 0.3875\n",
      "Iteration: 1080, Train loss: -0.3362, rewards: 0.3805\n",
      "Iteration: 1090, Train loss: -0.3027, rewards: 0.3575\n",
      "Iteration: 1100, Train loss: -0.3258, rewards: 0.3794\n",
      "Eval:\n",
      "Hits@1: 0.4255, Hits@3: 0.4763, Hits@10: 0.5194, MRR: 0.4566\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.4033, rewards: 0.3966\n",
      "Iteration: 1120, Train loss: -0.3911, rewards: 0.4278\n",
      "Iteration: 1130, Train loss: -0.3602, rewards: 0.4030\n",
      "Iteration: 1140, Train loss: -0.3698, rewards: 0.3922\n",
      "Iteration: 1150, Train loss: -0.3787, rewards: 0.4263\n",
      "Iteration: 1160, Train loss: -0.3592, rewards: 0.4163\n",
      "Iteration: 1170, Train loss: -0.3868, rewards: 0.4083\n",
      "Iteration: 1180, Train loss: -0.3713, rewards: 0.4081\n",
      "Iteration: 1190, Train loss: -0.3656, rewards: 0.4180\n",
      "Iteration: 1200, Train loss: -0.3698, rewards: 0.3647\n",
      "Eval:\n",
      "Hits@1: 0.4028, Hits@3: 0.4634, Hits@10: 0.4984, MRR: 0.4377\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.3715, rewards: 0.4092\n",
      "Iteration: 1220, Train loss: -0.3605, rewards: 0.4277\n",
      "Iteration: 1230, Train loss: -0.3565, rewards: 0.4134\n",
      "Iteration: 1240, Train loss: -0.3655, rewards: 0.3852\n",
      "Iteration: 1250, Train loss: -0.3840, rewards: 0.4414\n",
      "Iteration: 1260, Train loss: -0.4027, rewards: 0.3625\n",
      "Iteration: 1270, Train loss: -0.4413, rewards: 0.4283\n",
      "Iteration: 1280, Train loss: -0.4163, rewards: 0.4522\n",
      "Iteration: 1290, Train loss: -0.4057, rewards: 0.4028\n",
      "Iteration: 1300, Train loss: -0.3573, rewards: 0.3944\n",
      "Eval:\n",
      "Hits@1: 0.3523, Hits@3: 0.4601, Hits@10: 0.5102, MRR: 0.4112\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.3542, rewards: 0.4284\n",
      "Iteration: 1320, Train loss: -0.3559, rewards: 0.4234\n",
      "Iteration: 1330, Train loss: -0.3751, rewards: 0.4111\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1340, Train loss: -0.3724, rewards: 0.3845\n",
      "Iteration: 1350, Train loss: -0.3874, rewards: 0.4141\n",
      "Iteration: 1360, Train loss: -0.3205, rewards: 0.4363\n",
      "Iteration: 1370, Train loss: -0.3582, rewards: 0.4064\n",
      "Iteration: 1380, Train loss: -0.3341, rewards: 0.4523\n",
      "Iteration: 1390, Train loss: -0.4709, rewards: 0.4645\n",
      "Iteration: 1400, Train loss: -0.4124, rewards: 0.4166\n",
      "Eval:\n",
      "Hits@1: 0.4153, Hits@3: 0.4651, Hits@10: 0.5089, MRR: 0.4458\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.4215, rewards: 0.4637\n",
      "Iteration: 1420, Train loss: -0.3900, rewards: 0.4111\n",
      "Iteration: 1430, Train loss: -0.3821, rewards: 0.4539\n",
      "Iteration: 1440, Train loss: -0.4185, rewards: 0.4030\n",
      "Iteration: 1450, Train loss: -0.4730, rewards: 0.4591\n",
      "Iteration: 1460, Train loss: -0.4835, rewards: 0.4141\n",
      "Iteration: 1470, Train loss: -0.3932, rewards: 0.4300\n",
      "Iteration: 1480, Train loss: -0.3576, rewards: 0.4227\n",
      "Iteration: 1490, Train loss: -0.4052, rewards: 0.4691\n",
      "Iteration: 1500, Train loss: -0.3571, rewards: 0.3834\n",
      "Eval:\n",
      "Hits@1: 0.4113, Hits@3: 0.4690, Hits@10: 0.5129, MRR: 0.4467\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.3439, rewards: 0.4263\n",
      "Iteration: 1520, Train loss: -0.3818, rewards: 0.4294\n",
      "Iteration: 1530, Train loss: -0.3994, rewards: 0.4136\n",
      "Iteration: 1540, Train loss: -0.3742, rewards: 0.3936\n",
      "Iteration: 1550, Train loss: -0.3998, rewards: 0.4041\n",
      "Iteration: 1560, Train loss: -0.3212, rewards: 0.4425\n",
      "Iteration: 1570, Train loss: -0.3605, rewards: 0.4309\n",
      "Iteration: 1580, Train loss: -0.3956, rewards: 0.3889\n",
      "Iteration: 1590, Train loss: -0.3522, rewards: 0.3936\n",
      "Iteration: 1600, Train loss: -0.4173, rewards: 0.4419\n",
      "Eval:\n",
      "Hits@1: 0.3972, Hits@3: 0.4684, Hits@10: 0.5125, MRR: 0.4393\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.4179, rewards: 0.4086\n",
      "Iteration: 1620, Train loss: -0.4268, rewards: 0.4125\n",
      "Iteration: 1630, Train loss: -0.4048, rewards: 0.4064\n",
      "Iteration: 1640, Train loss: -0.4227, rewards: 0.4114\n",
      "Iteration: 1650, Train loss: -0.3681, rewards: 0.4300\n",
      "Iteration: 1660, Train loss: -0.3364, rewards: 0.3820\n",
      "Iteration: 1670, Train loss: -0.3467, rewards: 0.4213\n",
      "Iteration: 1680, Train loss: -0.3875, rewards: 0.4477\n",
      "Iteration: 1690, Train loss: -0.4060, rewards: 0.3905\n",
      "Iteration: 1700, Train loss: -0.3283, rewards: 0.4297\n",
      "Eval:\n",
      "Hits@1: 0.4222, Hits@3: 0.4730, Hits@10: 0.5158, MRR: 0.4540\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.3815, rewards: 0.4016\n",
      "Iteration: 1720, Train loss: -0.3664, rewards: 0.4098\n",
      "Iteration: 1730, Train loss: -0.4007, rewards: 0.4453\n",
      "Iteration: 1740, Train loss: -0.3338, rewards: 0.4320\n",
      "Iteration: 1750, Train loss: -0.3603, rewards: 0.4155\n",
      "Iteration: 1760, Train loss: -0.4053, rewards: 0.4537\n",
      "Iteration: 1770, Train loss: -0.3466, rewards: 0.4734\n",
      "Iteration: 1780, Train loss: -0.3152, rewards: 0.3805\n",
      "Iteration: 1790, Train loss: -0.3189, rewards: 0.4070\n",
      "Iteration: 1800, Train loss: -0.3135, rewards: 0.4203\n",
      "Eval:\n",
      "Hits@1: 0.4166, Hits@3: 0.4710, Hits@10: 0.5175, MRR: 0.4510\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.3913, rewards: 0.4595\n",
      "Iteration: 1820, Train loss: -0.3299, rewards: 0.4072\n",
      "Iteration: 1830, Train loss: -0.3332, rewards: 0.4306\n",
      "Iteration: 1840, Train loss: -0.3792, rewards: 0.4014\n",
      "Iteration: 1850, Train loss: -0.3149, rewards: 0.4375\n",
      "Iteration: 1860, Train loss: -0.3318, rewards: 0.4250\n",
      "Iteration: 1870, Train loss: -0.3509, rewards: 0.4319\n",
      "Iteration: 1880, Train loss: -0.2974, rewards: 0.3583\n",
      "Iteration: 1890, Train loss: -0.2811, rewards: 0.4527\n",
      "Iteration: 1900, Train loss: -0.3338, rewards: 0.4173\n",
      "Eval:\n",
      "Hits@1: 0.4245, Hits@3: 0.4693, Hits@10: 0.5059, MRR: 0.4523\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.3401, rewards: 0.4587\n",
      "Iteration: 1920, Train loss: -0.3734, rewards: 0.4194\n",
      "Iteration: 1930, Train loss: -0.3171, rewards: 0.4186\n",
      "Iteration: 1940, Train loss: -0.3636, rewards: 0.4291\n",
      "Iteration: 1950, Train loss: -0.3252, rewards: 0.4334\n",
      "Iteration: 1960, Train loss: -0.3454, rewards: 0.4069\n",
      "Iteration: 1970, Train loss: -0.3281, rewards: 0.4466\n",
      "Iteration: 1980, Train loss: -0.4141, rewards: 0.3905\n",
      "Iteration: 1990, Train loss: -0.4074, rewards: 0.4497\n",
      "Iteration: 2000, Train loss: -0.3933, rewards: 0.4456\n",
      "Eval:\n",
      "Hits@1: 0.4239, Hits@3: 0.4717, Hits@10: 0.5145, MRR: 0.4537\n",
      "------------------------------------------------------------\n",
      "Iteration: 2010, Train loss: -0.2949, rewards: 0.4528\n",
      "Iteration: 2020, Train loss: -0.3478, rewards: 0.4377\n",
      "Iteration: 2030, Train loss: -0.3732, rewards: 0.4858\n",
      "Iteration: 2040, Train loss: -0.2800, rewards: 0.4270\n",
      "Iteration: 2050, Train loss: -0.3122, rewards: 0.4286\n",
      "Iteration: 2060, Train loss: -0.3525, rewards: 0.4230\n",
      "Iteration: 2070, Train loss: -0.3585, rewards: 0.4139\n",
      "Iteration: 2080, Train loss: -0.3428, rewards: 0.4308\n",
      "Iteration: 2090, Train loss: -0.3862, rewards: 0.4095\n",
      "Iteration: 2100, Train loss: -0.3583, rewards: 0.4428\n",
      "Eval:\n",
      "Hits@1: 0.4225, Hits@3: 0.4730, Hits@10: 0.5115, MRR: 0.4535\n",
      "------------------------------------------------------------\n",
      "Iteration: 2110, Train loss: -0.3140, rewards: 0.4273\n",
      "Iteration: 2120, Train loss: -0.2972, rewards: 0.3561\n",
      "Iteration: 2130, Train loss: -0.3220, rewards: 0.3887\n",
      "Iteration: 2140, Train loss: -0.3375, rewards: 0.4056\n",
      "Iteration: 2150, Train loss: -0.3245, rewards: 0.4297\n",
      "Iteration: 2160, Train loss: -0.4037, rewards: 0.4175\n",
      "Iteration: 2170, Train loss: -0.4181, rewards: 0.4411\n",
      "Iteration: 2180, Train loss: -0.3647, rewards: 0.4186\n",
      "Iteration: 2190, Train loss: -0.2792, rewards: 0.4356\n",
      "Iteration: 2200, Train loss: -0.4065, rewards: 0.4288\n",
      "Eval:\n",
      "Hits@1: 0.3724, Hits@3: 0.4684, Hits@10: 0.5112, MRR: 0.4265\n",
      "------------------------------------------------------------\n",
      "Iteration: 2210, Train loss: -0.3231, rewards: 0.4306\n",
      "Iteration: 2220, Train loss: -0.3268, rewards: 0.4455\n",
      "Iteration: 2230, Train loss: -0.3025, rewards: 0.3633\n",
      "Iteration: 2240, Train loss: -0.3690, rewards: 0.4450\n",
      "Iteration: 2250, Train loss: -0.3459, rewards: 0.4091\n",
      "Iteration: 2260, Train loss: -0.3259, rewards: 0.4278\n",
      "Iteration: 2270, Train loss: -0.3136, rewards: 0.4058\n",
      "Iteration: 2280, Train loss: -0.3644, rewards: 0.4598\n",
      "Iteration: 2290, Train loss: -0.3258, rewards: 0.4281\n",
      "Iteration: 2300, Train loss: -0.2778, rewards: 0.3481\n",
      "Eval:\n",
      "Hits@1: 0.4094, Hits@3: 0.4667, Hits@10: 0.5158, MRR: 0.4457\n",
      "------------------------------------------------------------\n",
      "Iteration: 2310, Train loss: -0.3960, rewards: 0.4152\n",
      "Iteration: 2320, Train loss: -0.3227, rewards: 0.4256\n",
      "Iteration: 2330, Train loss: -0.3081, rewards: 0.3881\n",
      "Iteration: 2340, Train loss: -0.2939, rewards: 0.4130\n",
      "Iteration: 2350, Train loss: -0.3535, rewards: 0.4098\n",
      "Iteration: 2360, Train loss: -0.3593, rewards: 0.4266\n",
      "Iteration: 2370, Train loss: -0.3570, rewards: 0.4378\n",
      "Iteration: 2380, Train loss: -0.3712, rewards: 0.4672\n",
      "Iteration: 2390, Train loss: -0.3328, rewards: 0.3970\n",
      "Iteration: 2400, Train loss: -0.3154, rewards: 0.3992\n",
      "Eval:\n",
      "Hits@1: 0.4268, Hits@3: 0.4730, Hits@10: 0.5145, MRR: 0.4556\n",
      "------------------------------------------------------------\n",
      "Iteration: 2410, Train loss: -0.3124, rewards: 0.4470\n",
      "Iteration: 2420, Train loss: -0.3411, rewards: 0.4436\n",
      "Iteration: 2430, Train loss: -0.3488, rewards: 0.4559\n",
      "Iteration: 2440, Train loss: -0.2706, rewards: 0.4319\n",
      "Iteration: 2450, Train loss: -0.2809, rewards: 0.4259\n",
      "Iteration: 2460, Train loss: -0.2308, rewards: 0.4194\n",
      "Iteration: 2470, Train loss: -0.3227, rewards: 0.4189\n",
      "Iteration: 2480, Train loss: -0.3302, rewards: 0.4191\n",
      "Iteration: 2490, Train loss: -0.2687, rewards: 0.4053\n",
      "Iteration: 2500, Train loss: -0.3372, rewards: 0.4078\n",
      "Eval:\n",
      "Hits@1: 0.4301, Hits@3: 0.4776, Hits@10: 0.5181, MRR: 0.4599\n",
      "------------------------------------------------------------\n",
      "Iteration: 2510, Train loss: -0.3240, rewards: 0.4566\n",
      "Iteration: 2520, Train loss: -0.3274, rewards: 0.3986\n",
      "Iteration: 2530, Train loss: -0.3728, rewards: 0.4264\n",
      "Iteration: 2540, Train loss: -0.3524, rewards: 0.4495\n",
      "Iteration: 2550, Train loss: -0.3787, rewards: 0.4062\n",
      "Iteration: 2560, Train loss: -0.3535, rewards: 0.3766\n",
      "Iteration: 2570, Train loss: -0.3014, rewards: 0.4017\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 2580, Train loss: -0.4160, rewards: 0.3944\n",
      "Iteration: 2590, Train loss: -0.3599, rewards: 0.4544\n",
      "Iteration: 2600, Train loss: -0.3525, rewards: 0.4450\n",
      "Eval:\n",
      "Hits@1: 0.4318, Hits@3: 0.4782, Hits@10: 0.5241, MRR: 0.4617\n",
      "------------------------------------------------------------\n",
      "Iteration: 2610, Train loss: -0.3740, rewards: 0.4339\n",
      "Iteration: 2620, Train loss: -0.3655, rewards: 0.4020\n",
      "Iteration: 2630, Train loss: -0.3553, rewards: 0.4059\n",
      "Iteration: 2640, Train loss: -0.3067, rewards: 0.4587\n",
      "Iteration: 2650, Train loss: -0.2956, rewards: 0.3894\n",
      "Iteration: 2660, Train loss: -0.3523, rewards: 0.4422\n",
      "Iteration: 2670, Train loss: -0.3770, rewards: 0.4142\n",
      "Iteration: 2680, Train loss: -0.3538, rewards: 0.3950\n",
      "Iteration: 2690, Train loss: -0.3355, rewards: 0.4000\n",
      "Iteration: 2700, Train loss: -0.4086, rewards: 0.4378\n",
      "Eval:\n",
      "Hits@1: 0.4222, Hits@3: 0.4740, Hits@10: 0.5152, MRR: 0.4538\n",
      "------------------------------------------------------------\n",
      "Iteration: 2710, Train loss: -0.3314, rewards: 0.4122\n",
      "Iteration: 2720, Train loss: -0.3632, rewards: 0.4281\n",
      "Iteration: 2730, Train loss: -0.3689, rewards: 0.4370\n",
      "Iteration: 2740, Train loss: -0.3369, rewards: 0.4572\n",
      "Iteration: 2750, Train loss: -0.3625, rewards: 0.4270\n",
      "Iteration: 2760, Train loss: -0.3124, rewards: 0.3948\n",
      "Iteration: 2770, Train loss: -0.3521, rewards: 0.4297\n",
      "Iteration: 2780, Train loss: -0.3266, rewards: 0.4275\n",
      "Iteration: 2790, Train loss: -0.3478, rewards: 0.4172\n",
      "Iteration: 2800, Train loss: -0.3916, rewards: 0.4283\n",
      "Eval:\n",
      "Hits@1: 0.3774, Hits@3: 0.4697, Hits@10: 0.5175, MRR: 0.4280\n",
      "------------------------------------------------------------\n",
      "Iteration: 2810, Train loss: -0.3411, rewards: 0.4133\n",
      "Iteration: 2820, Train loss: -0.3384, rewards: 0.3650\n",
      "Iteration: 2830, Train loss: -0.3225, rewards: 0.4437\n",
      "Iteration: 2840, Train loss: -0.3590, rewards: 0.4005\n",
      "Iteration: 2850, Train loss: -0.3826, rewards: 0.4258\n",
      "Iteration: 2860, Train loss: -0.3740, rewards: 0.4547\n",
      "Iteration: 2870, Train loss: -0.3452, rewards: 0.4005\n",
      "Iteration: 2880, Train loss: -0.3641, rewards: 0.4223\n",
      "Iteration: 2890, Train loss: -0.3962, rewards: 0.3906\n",
      "Iteration: 2900, Train loss: -0.3769, rewards: 0.4069\n",
      "Eval:\n",
      "Hits@1: 0.4295, Hits@3: 0.4792, Hits@10: 0.5185, MRR: 0.4581\n",
      "------------------------------------------------------------\n",
      "Iteration: 2910, Train loss: -0.3848, rewards: 0.4300\n",
      "Iteration: 2920, Train loss: -0.3574, rewards: 0.3900\n",
      "Iteration: 2930, Train loss: -0.3772, rewards: 0.4020\n",
      "Iteration: 2940, Train loss: -0.4014, rewards: 0.3673\n",
      "Iteration: 2950, Train loss: -0.4038, rewards: 0.4220\n",
      "Iteration: 2960, Train loss: -0.4393, rewards: 0.4322\n",
      "Iteration: 2970, Train loss: -0.3736, rewards: 0.4328\n",
      "Iteration: 2980, Train loss: -0.3755, rewards: 0.4123\n",
      "Iteration: 2990, Train loss: -0.3600, rewards: 0.3962\n",
      "Iteration: 3000, Train loss: -0.2903, rewards: 0.4564\n",
      "Eval:\n",
      "Hits@1: 0.3863, Hits@3: 0.4697, Hits@10: 0.5089, MRR: 0.4326\n",
      "------------------------------------------------------------\n",
      "Iteration: 3010, Train loss: -0.3390, rewards: 0.4358\n",
      "Iteration: 3020, Train loss: -0.3492, rewards: 0.3891\n",
      "Iteration: 3030, Train loss: -0.3593, rewards: 0.4339\n",
      "Iteration: 3040, Train loss: -0.3455, rewards: 0.4828\n",
      "Iteration: 3050, Train loss: -0.3371, rewards: 0.4334\n",
      "Iteration: 3060, Train loss: -0.3269, rewards: 0.4270\n",
      "Iteration: 3070, Train loss: -0.3399, rewards: 0.4436\n",
      "Iteration: 3080, Train loss: -0.3660, rewards: 0.4202\n",
      "Iteration: 3090, Train loss: -0.4015, rewards: 0.4017\n",
      "Iteration: 3100, Train loss: -0.4210, rewards: 0.3917\n",
      "Eval:\n",
      "Hits@1: 0.4136, Hits@3: 0.4779, Hits@10: 0.5250, MRR: 0.4512\n",
      "------------------------------------------------------------\n",
      "Iteration: 3110, Train loss: -0.3317, rewards: 0.3805\n",
      "Iteration: 3120, Train loss: -0.3930, rewards: 0.4519\n",
      "Iteration: 3130, Train loss: -0.4268, rewards: 0.4369\n",
      "Iteration: 3140, Train loss: -0.3865, rewards: 0.4113\n",
      "Iteration: 3150, Train loss: -0.4269, rewards: 0.4487\n",
      "Iteration: 3160, Train loss: -0.4524, rewards: 0.4109\n",
      "Iteration: 3170, Train loss: -0.3720, rewards: 0.3895\n",
      "Iteration: 3180, Train loss: -0.3901, rewards: 0.3973\n",
      "Iteration: 3190, Train loss: -0.3887, rewards: 0.4653\n",
      "Iteration: 3200, Train loss: -0.3642, rewards: 0.4433\n",
      "Eval:\n",
      "Hits@1: 0.4272, Hits@3: 0.4825, Hits@10: 0.5241, MRR: 0.4599\n",
      "------------------------------------------------------------\n",
      "Iteration: 3210, Train loss: -0.3775, rewards: 0.4311\n",
      "Iteration: 3220, Train loss: -0.3812, rewards: 0.4375\n",
      "Iteration: 3230, Train loss: -0.3755, rewards: 0.3925\n",
      "Iteration: 3240, Train loss: -0.4341, rewards: 0.4106\n",
      "Iteration: 3250, Train loss: -0.4097, rewards: 0.4148\n",
      "Iteration: 3260, Train loss: -0.4341, rewards: 0.4286\n",
      "Iteration: 3270, Train loss: -0.3509, rewards: 0.4209\n",
      "Iteration: 3280, Train loss: -0.4225, rewards: 0.4325\n",
      "Iteration: 3290, Train loss: -0.4074, rewards: 0.4384\n",
      "Iteration: 3300, Train loss: -0.3560, rewards: 0.4070\n",
      "Eval:\n",
      "Hits@1: 0.4328, Hits@3: 0.4753, Hits@10: 0.5165, MRR: 0.4599\n",
      "------------------------------------------------------------\n",
      "Iteration: 3310, Train loss: -0.3704, rewards: 0.4525\n",
      "Iteration: 3320, Train loss: -0.3559, rewards: 0.4141\n",
      "Iteration: 3330, Train loss: -0.3219, rewards: 0.4550\n",
      "Iteration: 3340, Train loss: -0.2901, rewards: 0.3916\n",
      "Iteration: 3350, Train loss: -0.4015, rewards: 0.4122\n",
      "Iteration: 3360, Train loss: -0.3631, rewards: 0.3964\n",
      "Iteration: 3370, Train loss: -0.3627, rewards: 0.4519\n",
      "Iteration: 3380, Train loss: -0.3755, rewards: 0.4392\n",
      "Iteration: 3390, Train loss: -0.4077, rewards: 0.4083\n",
      "Iteration: 3400, Train loss: -0.3498, rewards: 0.3770\n",
      "Eval:\n",
      "Hits@1: 0.4338, Hits@3: 0.4743, Hits@10: 0.5152, MRR: 0.4597\n",
      "------------------------------------------------------------\n",
      "Iteration: 3410, Train loss: -0.4162, rewards: 0.4364\n",
      "Iteration: 3420, Train loss: -0.3429, rewards: 0.4091\n",
      "Iteration: 3430, Train loss: -0.4016, rewards: 0.4163\n",
      "Iteration: 3440, Train loss: -0.3834, rewards: 0.3959\n",
      "Iteration: 3450, Train loss: -0.4036, rewards: 0.3823\n",
      "Iteration: 3460, Train loss: -0.3598, rewards: 0.3953\n",
      "Iteration: 3470, Train loss: -0.3309, rewards: 0.4364\n",
      "Iteration: 3480, Train loss: -0.3265, rewards: 0.4203\n",
      "Iteration: 3490, Train loss: -0.3048, rewards: 0.3970\n",
      "Iteration: 3500, Train loss: -0.2868, rewards: 0.4811\n",
      "Eval:\n",
      "Hits@1: 0.4249, Hits@3: 0.4766, Hits@10: 0.5171, MRR: 0.4561\n",
      "------------------------------------------------------------\n",
      "Iteration: 3510, Train loss: -0.3379, rewards: 0.4467\n",
      "Iteration: 3520, Train loss: -0.3097, rewards: 0.4088\n",
      "Iteration: 3530, Train loss: -0.3680, rewards: 0.3981\n",
      "Iteration: 3540, Train loss: -0.3319, rewards: 0.4347\n",
      "Iteration: 3550, Train loss: -0.3676, rewards: 0.4597\n",
      "Iteration: 3560, Train loss: -0.3657, rewards: 0.4391\n",
      "Iteration: 3570, Train loss: -0.3811, rewards: 0.4516\n",
      "Iteration: 3580, Train loss: -0.2924, rewards: 0.3897\n",
      "Iteration: 3590, Train loss: -0.3677, rewards: 0.4080\n",
      "Iteration: 3600, Train loss: -0.3014, rewards: 0.3967\n",
      "Eval:\n",
      "Hits@1: 0.4344, Hits@3: 0.4763, Hits@10: 0.5175, MRR: 0.4611\n",
      "------------------------------------------------------------\n",
      "Iteration: 3610, Train loss: -0.3175, rewards: 0.4314\n",
      "Iteration: 3620, Train loss: -0.3479, rewards: 0.4881\n",
      "Iteration: 3630, Train loss: -0.4041, rewards: 0.4289\n",
      "Iteration: 3640, Train loss: -0.4077, rewards: 0.3916\n",
      "Iteration: 3650, Train loss: -0.3889, rewards: 0.4561\n",
      "Iteration: 3660, Train loss: -0.4037, rewards: 0.4173\n",
      "Iteration: 3670, Train loss: -0.4080, rewards: 0.4436\n",
      "Iteration: 3680, Train loss: -0.2958, rewards: 0.4333\n",
      "Iteration: 3690, Train loss: -0.3307, rewards: 0.4148\n",
      "Iteration: 3700, Train loss: -0.2893, rewards: 0.4167\n",
      "Eval:\n",
      "Hits@1: 0.4377, Hits@3: 0.4789, Hits@10: 0.5175, MRR: 0.4637\n",
      "------------------------------------------------------------\n",
      "Iteration: 3710, Train loss: -0.3567, rewards: 0.4102\n",
      "Iteration: 3720, Train loss: -0.3504, rewards: 0.4328\n",
      "Iteration: 3730, Train loss: -0.3582, rewards: 0.4414\n",
      "Iteration: 3740, Train loss: -0.3953, rewards: 0.4739\n",
      "Iteration: 3750, Train loss: -0.3667, rewards: 0.3775\n",
      "Iteration: 3760, Train loss: -0.2939, rewards: 0.3694\n",
      "Iteration: 3770, Train loss: -0.3986, rewards: 0.3956\n",
      "Iteration: 3780, Train loss: -0.3245, rewards: 0.4245\n",
      "Iteration: 3790, Train loss: -0.3333, rewards: 0.4323\n",
      "Iteration: 3800, Train loss: -0.3378, rewards: 0.4111\n",
      "Eval:\n",
      "Hits@1: 0.3991, Hits@3: 0.4726, Hits@10: 0.5168, MRR: 0.4421\n",
      "------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 3810, Train loss: -0.4028, rewards: 0.4345\n",
      "Iteration: 3820, Train loss: -0.4449, rewards: 0.4191\n",
      "Iteration: 3830, Train loss: -0.3642, rewards: 0.4197\n",
      "Iteration: 3840, Train loss: -0.3564, rewards: 0.4062\n",
      "Iteration: 3850, Train loss: -0.3936, rewards: 0.4122\n",
      "Iteration: 3860, Train loss: -0.3405, rewards: 0.4144\n",
      "Iteration: 3870, Train loss: -0.3397, rewards: 0.4541\n",
      "Iteration: 3880, Train loss: -0.3381, rewards: 0.3844\n",
      "Iteration: 3890, Train loss: -0.3414, rewards: 0.3891\n",
      "Iteration: 3900, Train loss: -0.3375, rewards: 0.4288\n",
      "Eval:\n",
      "Hits@1: 0.4311, Hits@3: 0.4730, Hits@10: 0.5148, MRR: 0.4583\n",
      "------------------------------------------------------------\n",
      "Iteration: 3910, Train loss: -0.3951, rewards: 0.4230\n",
      "Iteration: 3920, Train loss: -0.3614, rewards: 0.3639\n",
      "Iteration: 3930, Train loss: -0.3961, rewards: 0.3792\n",
      "Iteration: 3940, Train loss: -0.3429, rewards: 0.3728\n",
      "Iteration: 3950, Train loss: -0.3553, rewards: 0.4517\n",
      "Iteration: 3960, Train loss: -0.3948, rewards: 0.4278\n",
      "Iteration: 3970, Train loss: -0.3672, rewards: 0.3980\n",
      "Iteration: 3980, Train loss: -0.4366, rewards: 0.4227\n",
      "Iteration: 3990, Train loss: -0.4295, rewards: 0.4225\n",
      "Iteration: 4000, Train loss: -0.4248, rewards: 0.4444\n",
      "Eval:\n",
      "Hits@1: 0.4387, Hits@3: 0.4815, Hits@10: 0.5270, MRR: 0.4665\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(options)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d6a50465",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.4445, Hits@3: 0.4939, Hits@10: 0.5325, MRR: 0.4743\n"
     ]
    }
   ],
   "source": [
    "# Final evaluation: restore the saved agent checkpoint and score it with\n",
    "# trainer.test(beam=True).\n",
    "\n",
    "# Build the path with os.path.join so it does not depend on model_dir ending\n",
    "# in a slash. map_location remaps the stored tensors onto the configured\n",
    "# device, so a checkpoint written on GPU still loads when\n",
    "# options['use_cuda'] is False.\n",
    "checkpoint_path = os.path.join(options['model_dir'], 'agent.ckpt')\n",
    "trainer.agent.load_state_dict(\n",
    "    torch.load(checkpoint_path, map_location=options['device'])\n",
    ")\n",
    "trainer.agent.eval()\n",
    "\n",
    "# Evaluation-only overrides of the values set in the options cell\n",
    "# (test_rollouts 40 -> 100, max_num_actions 40 -> 100, eval_batch_size 32 -> 8).\n",
    "# NOTE(review): these only take effect if Trainer reads this same dict at\n",
    "# test time rather than copying the values in __init__ -- confirm in model.ours2.\n",
    "options['test_rollouts'] = 100\n",
    "options['max_num_actions'] = 100\n",
    "options['eval_batch_size'] = 8\n",
    "\n",
    "# Point the evaluation environment at trainer.test_test_environment\n",
    "# (presumably the held-out test split -- confirm in model.ours2).\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "test_results = trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cb5b6cee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.4445, Hits@3: 0.4939, Hits@10: 0.5325, MRR: 0.4743\n"
     ]
    }
   ],
   "source": [
    "# Show the value returned by trainer.test(...) in the previous cell; the\n",
    "# recorded output matches the metrics line that call already printed.\n",
    "print(test_results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
