{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import json\n",
    "import os\n",
    "\n",
    "import pickle5\n",
    "\n",
    "# Pin the sweep to a single GPU. CUDA reads this variable when it is initialised,\n",
    "# so it must be set before `import torch` / `from model.ours import *` touch CUDA.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "# gc, json and torch are used by later cells; import them explicitly instead of\n",
    "# relying on them leaking out of `from model.ours import *`.\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0f115645",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5, dataset = 'WN18RR'):\n",
    "    \"\"\"Build the options dict passed to `Trainer` for one hyper-parameter configuration.\n",
    "\n",
    "    Args:\n",
    "        LSTM_layers: stored as options['LSTM_layers'].\n",
    "        batch_size: stored as options['batch_size']; also sets\n",
    "            total_iterations = 2000 * 64 / batch_size so that, assuming one batch\n",
    "            per iteration, every configuration trains on the same number of examples.\n",
    "        beta, Lambda, learning_rate: stored under the keys of the same name.\n",
    "        dataset: dataset folder under ../MINERVA/datasets/data_preprocessed/; also\n",
    "            names the output directory. Defaults to 'WN18RR'.\n",
    "\n",
    "    Returns:\n",
    "        dict of options, including the parsed relation and entity vocabularies.\n",
    "    \"\"\"\n",
    "    # Imported locally so calling this function does not depend on\n",
    "    # `from model.ours import *` having been run first.\n",
    "    import json\n",
    "    import os\n",
    "\n",
    "    data_input_dir = f'../MINERVA/datasets/data_preprocessed/{dataset}/'\n",
    "    vocab_dir = data_input_dir + 'vocab/'\n",
    "    output_dir = f'./outputs_{dataset}-tune/'\n",
    "\n",
    "    options = {}\n",
    "\n",
    "    #basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = vocab_dir\n",
    "    options['data_input_dir'] = data_input_dir\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    # `with` closes the vocab files immediately instead of leaving the handles\n",
    "    # open until garbage collection.\n",
    "    with open(os.path.join(vocab_dir, 'relation_vocab.json')) as f:\n",
    "        options['relation_vocab'] = json.load(f)\n",
    "    with open(os.path.join(vocab_dir, 'entity_vocab.json')) as f:\n",
    "        options['entity_vocab'] = json.load(f)\n",
    "    options['model_dir'] = output_dir\n",
    "    options['output_dir'] = output_dir\n",
    "\n",
    "    #agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 0\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 40\n",
    "\n",
    "    #hyperparameters\n",
    "    options['test_rollouts'] = 40\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 32\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 100\n",
    "    # Cast to int: 2000*(64/batch_size) is a float (e.g. 4000.0), which breaks\n",
    "    # range() and integer formatting of the iteration budget.\n",
    "    options['total_iterations'] = int(2000 * 64 / batch_size)\n",
    "    options['pool'] = 'max'\n",
    "\n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d0332ee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/Ours/model/ours.py:333: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2352, rewards: 0.2286\n",
      "Iteration: 20, Train loss: -0.2213, rewards: 0.3075\n",
      "Iteration: 30, Train loss: -0.2205, rewards: 0.2995\n",
      "Iteration: 40, Train loss: -0.2066, rewards: 0.3436\n",
      "Iteration: 50, Train loss: -0.2347, rewards: 0.3289\n",
      "Iteration: 60, Train loss: -0.2174, rewards: 0.3039\n",
      "Iteration: 70, Train loss: -0.2938, rewards: 0.3417\n",
      "Iteration: 80, Train loss: -0.2647, rewards: 0.3088\n",
      "Iteration: 90, Train loss: -0.3751, rewards: 0.3191\n",
      "Iteration: 100, Train loss: -0.4263, rewards: 0.3286\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/Ours/model/ours.py:635: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.3494, Hits@3: 0.4272, Hits@10: 0.4852, MRR: 0.3961\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.3682, rewards: 0.3652\n",
      "Iteration: 120, Train loss: -0.2816, rewards: 0.3492\n",
      "Iteration: 130, Train loss: -0.4262, rewards: 0.3702\n",
      "Iteration: 140, Train loss: -0.3576, rewards: 0.4095\n",
      "Iteration: 150, Train loss: -0.4033, rewards: 0.3928\n",
      "Iteration: 160, Train loss: -0.4529, rewards: 0.3855\n",
      "Iteration: 170, Train loss: -0.4177, rewards: 0.3539\n",
      "Iteration: 180, Train loss: -0.4087, rewards: 0.3950\n",
      "Iteration: 190, Train loss: -0.5372, rewards: 0.3700\n",
      "Iteration: 200, Train loss: -0.4842, rewards: 0.3852\n",
      "Eval:\n",
      "Hits@1: 0.3606, Hits@3: 0.4291, Hits@10: 0.4819, MRR: 0.4015\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.4232, rewards: 0.3716\n",
      "Iteration: 220, Train loss: -0.4862, rewards: 0.3833\n",
      "Iteration: 230, Train loss: -0.4696, rewards: 0.3842\n",
      "Iteration: 240, Train loss: -0.4850, rewards: 0.3525\n",
      "Iteration: 250, Train loss: -0.5442, rewards: 0.4189\n",
      "Iteration: 260, Train loss: -0.4899, rewards: 0.3797\n",
      "Iteration: 270, Train loss: -0.5087, rewards: 0.3619\n",
      "Iteration: 280, Train loss: -0.5139, rewards: 0.3847\n",
      "Iteration: 290, Train loss: -0.5621, rewards: 0.4111\n",
      "Iteration: 300, Train loss: -0.5294, rewards: 0.3817\n",
      "Eval:\n",
      "Hits@1: 0.3655, Hits@3: 0.4347, Hits@10: 0.4980, MRR: 0.4087\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.5172, rewards: 0.3614\n",
      "Iteration: 320, Train loss: -0.4884, rewards: 0.3427\n",
      "Iteration: 330, Train loss: -0.5227, rewards: 0.4147\n",
      "Iteration: 340, Train loss: -0.5013, rewards: 0.3992\n",
      "Iteration: 350, Train loss: -0.5892, rewards: 0.4134\n",
      "Iteration: 360, Train loss: -0.5045, rewards: 0.3902\n",
      "Iteration: 370, Train loss: -0.5623, rewards: 0.4037\n",
      "Iteration: 380, Train loss: -0.5136, rewards: 0.4128\n",
      "Iteration: 390, Train loss: -0.5185, rewards: 0.3073\n",
      "Iteration: 400, Train loss: -0.4731, rewards: 0.3872\n",
      "Eval:\n",
      "Hits@1: 0.3645, Hits@3: 0.4423, Hits@10: 0.4964, MRR: 0.4109\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.5333, rewards: 0.3517\n",
      "Iteration: 420, Train loss: -0.5431, rewards: 0.3592\n",
      "Iteration: 430, Train loss: -0.5583, rewards: 0.3914\n",
      "Iteration: 440, Train loss: -0.4357, rewards: 0.4078\n",
      "Iteration: 450, Train loss: -0.3779, rewards: 0.3692\n",
      "Iteration: 460, Train loss: -0.4288, rewards: 0.3939\n",
      "Iteration: 470, Train loss: -0.4839, rewards: 0.3553\n",
      "Iteration: 480, Train loss: -0.5509, rewards: 0.3997\n",
      "Iteration: 490, Train loss: -0.4765, rewards: 0.4208\n",
      "Iteration: 500, Train loss: -0.5049, rewards: 0.3917\n",
      "Eval:\n",
      "Hits@1: 0.3579, Hits@3: 0.4354, Hits@10: 0.4937, MRR: 0.4048\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.4968, rewards: 0.3698\n",
      "Iteration: 520, Train loss: -0.4336, rewards: 0.4002\n",
      "Iteration: 530, Train loss: -0.3737, rewards: 0.3984\n",
      "Iteration: 540, Train loss: -0.2576, rewards: 0.4441\n",
      "Iteration: 550, Train loss: -0.2979, rewards: 0.4348\n",
      "Iteration: 560, Train loss: -0.3925, rewards: 0.3967\n",
      "Iteration: 570, Train loss: -0.4086, rewards: 0.4347\n",
      "Iteration: 580, Train loss: -0.4664, rewards: 0.4233\n",
      "Iteration: 590, Train loss: -0.4605, rewards: 0.3542\n",
      "Iteration: 600, Train loss: -0.5024, rewards: 0.3566\n",
      "Eval:\n",
      "Hits@1: 0.3916, Hits@3: 0.4479, Hits@10: 0.4970, MRR: 0.4275\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.5573, rewards: 0.3872\n",
      "Iteration: 620, Train loss: -0.6019, rewards: 0.3873\n",
      "Iteration: 630, Train loss: -0.4745, rewards: 0.4186\n",
      "Iteration: 640, Train loss: -0.4856, rewards: 0.3758\n",
      "Iteration: 650, Train loss: -0.4726, rewards: 0.4239\n",
      "Iteration: 660, Train loss: -0.4983, rewards: 0.3816\n",
      "Iteration: 670, Train loss: -0.4133, rewards: 0.3633\n",
      "Iteration: 680, Train loss: -0.5086, rewards: 0.3692\n",
      "Iteration: 690, Train loss: -0.4763, rewards: 0.4128\n",
      "Iteration: 700, Train loss: -0.5007, rewards: 0.4120\n",
      "Eval:\n",
      "Hits@1: 0.3659, Hits@3: 0.4407, Hits@10: 0.4931, MRR: 0.4109\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.5278, rewards: 0.4070\n",
      "Iteration: 720, Train loss: -0.5042, rewards: 0.3791\n",
      "Iteration: 730, Train loss: -0.5266, rewards: 0.4242\n",
      "Iteration: 740, Train loss: -0.5076, rewards: 0.4059\n",
      "Iteration: 750, Train loss: -0.4496, rewards: 0.3814\n",
      "Iteration: 760, Train loss: -0.4842, rewards: 0.4102\n",
      "Iteration: 770, Train loss: -0.4545, rewards: 0.3887\n",
      "Iteration: 780, Train loss: -0.4897, rewards: 0.3866\n",
      "Iteration: 790, Train loss: -0.4592, rewards: 0.4195\n",
      "Iteration: 800, Train loss: -0.5127, rewards: 0.4223\n",
      "Eval:\n",
      "Hits@1: 0.3649, Hits@3: 0.4338, Hits@10: 0.4927, MRR: 0.4078\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.4590, rewards: 0.3827\n",
      "Iteration: 820, Train loss: -0.4130, rewards: 0.3463\n",
      "Iteration: 830, Train loss: -0.5154, rewards: 0.3970\n",
      "Iteration: 840, Train loss: -0.4916, rewards: 0.4081\n",
      "Iteration: 850, Train loss: -0.4978, rewards: 0.3931\n",
      "Iteration: 860, Train loss: -0.4922, rewards: 0.4119\n",
      "Iteration: 870, Train loss: -0.5308, rewards: 0.3736\n",
      "Iteration: 880, Train loss: -0.5252, rewards: 0.4048\n",
      "Iteration: 890, Train loss: -0.4811, rewards: 0.4148\n",
      "Iteration: 900, Train loss: -0.5112, rewards: 0.3762\n",
      "Eval:\n",
      "Hits@1: 0.3550, Hits@3: 0.4281, Hits@10: 0.4885, MRR: 0.4000\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.5533, rewards: 0.4181\n",
      "Iteration: 920, Train loss: -0.5267, rewards: 0.3706\n",
      "Iteration: 930, Train loss: -0.4704, rewards: 0.4295\n",
      "Iteration: 940, Train loss: -0.5029, rewards: 0.4238\n",
      "Iteration: 950, Train loss: -0.5894, rewards: 0.3916\n",
      "Iteration: 960, Train loss: -0.5273, rewards: 0.3761\n",
      "Iteration: 970, Train loss: -0.4362, rewards: 0.3941\n",
      "Iteration: 980, Train loss: -0.4933, rewards: 0.3583\n",
      "Iteration: 990, Train loss: -0.4465, rewards: 0.3945\n",
      "Iteration: 1000, Train loss: -0.4022, rewards: 0.4119\n",
      "Eval:\n",
      "Hits@1: 0.3652, Hits@3: 0.4364, Hits@10: 0.5053, MRR: 0.4114\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.4416, rewards: 0.3538\n",
      "Iteration: 1020, Train loss: -0.6111, rewards: 0.4069\n",
      "Iteration: 1030, Train loss: -0.5450, rewards: 0.3856\n",
      "Iteration: 1040, Train loss: -0.4886, rewards: 0.3892\n",
      "Iteration: 1050, Train loss: -0.5127, rewards: 0.3400\n",
      "Iteration: 1060, Train loss: -0.5039, rewards: 0.3825\n",
      "Iteration: 1070, Train loss: -0.5343, rewards: 0.4323\n",
      "Iteration: 1080, Train loss: -0.5528, rewards: 0.4200\n",
      "Iteration: 1090, Train loss: -0.4584, rewards: 0.3812\n",
      "Iteration: 1100, Train loss: -0.4624, rewards: 0.3858\n",
      "Eval:\n",
      "Hits@1: 0.3764, Hits@3: 0.4489, Hits@10: 0.5073, MRR: 0.4210\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.4619, rewards: 0.4334\n",
      "Iteration: 1120, Train loss: -0.4670, rewards: 0.4772\n",
      "Iteration: 1130, Train loss: -0.4251, rewards: 0.3917\n",
      "Iteration: 1140, Train loss: -0.4444, rewards: 0.4044\n",
      "Iteration: 1150, Train loss: -0.4852, rewards: 0.3928\n",
      "Iteration: 1160, Train loss: -0.4654, rewards: 0.3783\n",
      "Iteration: 1170, Train loss: -0.5114, rewards: 0.3830\n",
      "Iteration: 1180, Train loss: -0.5313, rewards: 0.3855\n",
      "Iteration: 1190, Train loss: -0.5022, rewards: 0.3517\n",
      "Iteration: 1200, Train loss: -0.5305, rewards: 0.3925\n",
      "Eval:\n",
      "Hits@1: 0.3672, Hits@3: 0.4347, Hits@10: 0.4895, MRR: 0.4088\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.5174, rewards: 0.3528\n",
      "Iteration: 1220, Train loss: -0.4620, rewards: 0.4139\n",
      "Iteration: 1230, Train loss: -0.4752, rewards: 0.4247\n",
      "Iteration: 1240, Train loss: -0.5078, rewards: 0.3994\n",
      "Iteration: 1250, Train loss: -0.4839, rewards: 0.3966\n",
      "Iteration: 1260, Train loss: -0.5402, rewards: 0.4113\n",
      "Iteration: 1270, Train loss: -0.4751, rewards: 0.4134\n",
      "Iteration: 1280, Train loss: -0.5248, rewards: 0.3675\n",
      "Iteration: 1290, Train loss: -0.5356, rewards: 0.4400\n",
      "Iteration: 1300, Train loss: -0.5424, rewards: 0.4370\n",
      "Eval:\n",
      "Hits@1: 0.3734, Hits@3: 0.4456, Hits@10: 0.4980, MRR: 0.4163\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.5897, rewards: 0.3755\n",
      "Iteration: 1320, Train loss: -0.5349, rewards: 0.3422\n",
      "Iteration: 1330, Train loss: -0.6046, rewards: 0.3645\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1340, Train loss: -0.5763, rewards: 0.4381\n",
      "Iteration: 1350, Train loss: -0.5082, rewards: 0.3933\n",
      "Iteration: 1360, Train loss: -0.4859, rewards: 0.3934\n",
      "Iteration: 1370, Train loss: -0.4900, rewards: 0.3600\n",
      "Iteration: 1380, Train loss: -0.5559, rewards: 0.3669\n",
      "Iteration: 1390, Train loss: -0.4953, rewards: 0.4077\n",
      "Iteration: 1400, Train loss: -0.4908, rewards: 0.4761\n",
      "Eval:\n",
      "Hits@1: 0.3902, Hits@3: 0.4502, Hits@10: 0.5049, MRR: 0.4281\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.4641, rewards: 0.3669\n",
      "Iteration: 1420, Train loss: -0.4711, rewards: 0.3423\n",
      "Iteration: 1430, Train loss: -0.5104, rewards: 0.3627\n",
      "Iteration: 1440, Train loss: -0.5037, rewards: 0.3861\n",
      "Iteration: 1450, Train loss: -0.4861, rewards: 0.4520\n",
      "Iteration: 1460, Train loss: -0.4966, rewards: 0.3958\n",
      "Iteration: 1470, Train loss: -0.4845, rewards: 0.3820\n",
      "Iteration: 1480, Train loss: -0.4815, rewards: 0.4098\n",
      "Iteration: 1490, Train loss: -0.5406, rewards: 0.3902\n",
      "Iteration: 1500, Train loss: -0.4455, rewards: 0.4452\n",
      "Eval:\n",
      "Hits@1: 0.3754, Hits@3: 0.4440, Hits@10: 0.4987, MRR: 0.4177\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.5088, rewards: 0.3748\n",
      "Iteration: 1520, Train loss: -0.5123, rewards: 0.4261\n",
      "Iteration: 1530, Train loss: -0.4813, rewards: 0.3391\n",
      "Iteration: 1540, Train loss: -0.5246, rewards: 0.4241\n",
      "Iteration: 1550, Train loss: -0.4800, rewards: 0.4175\n",
      "Iteration: 1560, Train loss: -0.5580, rewards: 0.3933\n",
      "Iteration: 1570, Train loss: -0.5015, rewards: 0.3663\n",
      "Iteration: 1580, Train loss: -0.5008, rewards: 0.3980\n",
      "Iteration: 1590, Train loss: -0.5064, rewards: 0.3792\n",
      "Iteration: 1600, Train loss: -0.4957, rewards: 0.3983\n",
      "Eval:\n",
      "Hits@1: 0.3691, Hits@3: 0.4351, Hits@10: 0.4967, MRR: 0.4115\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.4976, rewards: 0.4437\n",
      "Iteration: 1620, Train loss: -0.6207, rewards: 0.3948\n",
      "Iteration: 1630, Train loss: -0.5065, rewards: 0.4173\n",
      "Iteration: 1640, Train loss: -0.5362, rewards: 0.4341\n",
      "Iteration: 1650, Train loss: -0.5006, rewards: 0.3748\n",
      "Iteration: 1660, Train loss: -0.4955, rewards: 0.4205\n",
      "Iteration: 1670, Train loss: -0.5440, rewards: 0.3738\n",
      "Iteration: 1680, Train loss: -0.5367, rewards: 0.4148\n",
      "Iteration: 1690, Train loss: -0.4311, rewards: 0.3736\n",
      "Iteration: 1700, Train loss: -0.4649, rewards: 0.4412\n",
      "Eval:\n",
      "Hits@1: 0.3873, Hits@3: 0.4525, Hits@10: 0.5046, MRR: 0.4274\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.4402, rewards: 0.4352\n",
      "Iteration: 1720, Train loss: -0.4998, rewards: 0.3911\n",
      "Iteration: 1730, Train loss: -0.5648, rewards: 0.3864\n",
      "Iteration: 1740, Train loss: -0.5155, rewards: 0.4069\n",
      "Iteration: 1750, Train loss: -0.4451, rewards: 0.3847\n",
      "Iteration: 1760, Train loss: -0.4895, rewards: 0.4120\n",
      "Iteration: 1770, Train loss: -0.4561, rewards: 0.4541\n",
      "Iteration: 1780, Train loss: -0.4033, rewards: 0.4272\n",
      "Iteration: 1790, Train loss: -0.4599, rewards: 0.3837\n",
      "Iteration: 1800, Train loss: -0.4854, rewards: 0.4227\n",
      "Eval:\n",
      "Hits@1: 0.3767, Hits@3: 0.4535, Hits@10: 0.5138, MRR: 0.4233\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.5433, rewards: 0.4134\n",
      "Iteration: 1820, Train loss: -0.5205, rewards: 0.4323\n",
      "Iteration: 1830, Train loss: -0.4641, rewards: 0.3875\n",
      "Iteration: 1840, Train loss: -0.4740, rewards: 0.3817\n",
      "Iteration: 1850, Train loss: -0.4787, rewards: 0.4066\n",
      "Iteration: 1860, Train loss: -0.4911, rewards: 0.3903\n",
      "Iteration: 1870, Train loss: -0.4952, rewards: 0.3992\n",
      "Iteration: 1880, Train loss: -0.5466, rewards: 0.3556\n",
      "Iteration: 1890, Train loss: -0.5488, rewards: 0.4161\n",
      "Iteration: 1900, Train loss: -0.4677, rewards: 0.4539\n",
      "Eval:\n",
      "Hits@1: 0.3619, Hits@3: 0.4433, Hits@10: 0.5010, MRR: 0.4101\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.5002, rewards: 0.4309\n",
      "Iteration: 1920, Train loss: -0.5057, rewards: 0.3980\n",
      "Iteration: 1930, Train loss: -0.5466, rewards: 0.3583\n",
      "Iteration: 1940, Train loss: -0.5237, rewards: 0.3642\n",
      "Iteration: 1950, Train loss: -0.5488, rewards: 0.4200\n",
      "Iteration: 1960, Train loss: -0.5236, rewards: 0.3836\n",
      "Iteration: 1970, Train loss: -0.4766, rewards: 0.4048\n",
      "Iteration: 1980, Train loss: -0.5203, rewards: 0.4170\n",
      "Iteration: 1990, Train loss: -0.5224, rewards: 0.3794\n",
      "Iteration: 2000, Train loss: -0.4750, rewards: 0.4130\n",
      "Eval:\n",
      "Hits@1: 0.3761, Hits@3: 0.4463, Hits@10: 0.5023, MRR: 0.4181\n",
      "------------------------------------------------------------\n",
      "Iteration: 2010, Train loss: -0.5256, rewards: 0.3580\n",
      "Iteration: 2020, Train loss: -0.4783, rewards: 0.4277\n",
      "Iteration: 2030, Train loss: -0.4937, rewards: 0.4139\n",
      "Iteration: 2040, Train loss: -0.5024, rewards: 0.4133\n",
      "Iteration: 2050, Train loss: -0.5454, rewards: 0.3761\n",
      "Iteration: 2060, Train loss: -0.4918, rewards: 0.3822\n",
      "Iteration: 2070, Train loss: -0.5734, rewards: 0.4150\n",
      "Iteration: 2080, Train loss: -0.5659, rewards: 0.4181\n",
      "Iteration: 2090, Train loss: -0.5752, rewards: 0.3966\n",
      "Iteration: 2100, Train loss: -0.5540, rewards: 0.3981\n",
      "Eval:\n",
      "Hits@1: 0.3767, Hits@3: 0.4466, Hits@10: 0.5010, MRR: 0.4186\n",
      "------------------------------------------------------------\n",
      "Iteration: 2110, Train loss: -0.4907, rewards: 0.4248\n",
      "Iteration: 2120, Train loss: -0.5178, rewards: 0.4253\n",
      "Iteration: 2130, Train loss: -0.5166, rewards: 0.4062\n",
      "Iteration: 2140, Train loss: -0.4671, rewards: 0.4205\n",
      "Iteration: 2150, Train loss: -0.5782, rewards: 0.4058\n",
      "Iteration: 2160, Train loss: -0.5625, rewards: 0.3830\n",
      "Iteration: 2170, Train loss: -0.4725, rewards: 0.3928\n",
      "Iteration: 2180, Train loss: -0.4872, rewards: 0.3733\n",
      "Iteration: 2190, Train loss: -0.4691, rewards: 0.3900\n",
      "Iteration: 2200, Train loss: -0.5031, rewards: 0.4069\n",
      "Eval:\n",
      "Hits@1: 0.3691, Hits@3: 0.4417, Hits@10: 0.4980, MRR: 0.4132\n",
      "------------------------------------------------------------\n",
      "Iteration: 2210, Train loss: -0.5091, rewards: 0.3870\n",
      "Iteration: 2220, Train loss: -0.5078, rewards: 0.3933\n",
      "Iteration: 2230, Train loss: -0.5079, rewards: 0.4314\n",
      "Iteration: 2240, Train loss: -0.4618, rewards: 0.4264\n",
      "Iteration: 2250, Train loss: -0.4705, rewards: 0.4228\n",
      "Iteration: 2260, Train loss: -0.4230, rewards: 0.3858\n",
      "Iteration: 2270, Train loss: -0.5019, rewards: 0.4084\n",
      "Iteration: 2280, Train loss: -0.4646, rewards: 0.3555\n",
      "Iteration: 2290, Train loss: -0.4296, rewards: 0.3639\n",
      "Iteration: 2300, Train loss: -0.4436, rewards: 0.3859\n",
      "Eval:\n",
      "Hits@1: 0.3659, Hits@3: 0.4456, Hits@10: 0.5026, MRR: 0.4130\n",
      "------------------------------------------------------------\n",
      "Iteration: 2310, Train loss: -0.4323, rewards: 0.3916\n",
      "Iteration: 2320, Train loss: -0.4276, rewards: 0.4275\n",
      "Iteration: 2330, Train loss: -0.4569, rewards: 0.3861\n",
      "Iteration: 2340, Train loss: -0.4342, rewards: 0.4020\n",
      "Iteration: 2350, Train loss: -0.4960, rewards: 0.3641\n",
      "Iteration: 2360, Train loss: -0.5381, rewards: 0.3908\n",
      "Iteration: 2370, Train loss: -0.4928, rewards: 0.4155\n",
      "Iteration: 2380, Train loss: -0.4604, rewards: 0.4328\n",
      "Iteration: 2390, Train loss: -0.4381, rewards: 0.4048\n",
      "Iteration: 2400, Train loss: -0.5080, rewards: 0.3855\n",
      "Eval:\n",
      "Hits@1: 0.3998, Hits@3: 0.4591, Hits@10: 0.5069, MRR: 0.4365\n",
      "------------------------------------------------------------\n",
      "Iteration: 2410, Train loss: -0.4742, rewards: 0.4342\n",
      "Iteration: 2420, Train loss: -0.4692, rewards: 0.4041\n",
      "Iteration: 2430, Train loss: -0.4252, rewards: 0.4006\n",
      "Iteration: 2440, Train loss: -0.4279, rewards: 0.4356\n",
      "Iteration: 2450, Train loss: -0.4766, rewards: 0.4084\n",
      "Iteration: 2460, Train loss: -0.5101, rewards: 0.4645\n",
      "Iteration: 2470, Train loss: -0.5420, rewards: 0.4325\n",
      "Iteration: 2480, Train loss: -0.5144, rewards: 0.3738\n",
      "Iteration: 2490, Train loss: -0.4908, rewards: 0.3950\n",
      "Iteration: 2500, Train loss: -0.5153, rewards: 0.4528\n",
      "Eval:\n",
      "Hits@1: 0.3906, Hits@3: 0.4519, Hits@10: 0.5040, MRR: 0.4273\n",
      "------------------------------------------------------------\n",
      "Iteration: 2510, Train loss: -0.4961, rewards: 0.4070\n",
      "Iteration: 2520, Train loss: -0.5245, rewards: 0.3645\n",
      "Iteration: 2530, Train loss: -0.5382, rewards: 0.3639\n",
      "Iteration: 2540, Train loss: -0.5380, rewards: 0.3634\n",
      "Iteration: 2550, Train loss: -0.5528, rewards: 0.4103\n",
      "Iteration: 2560, Train loss: -0.4877, rewards: 0.3736\n",
      "Iteration: 2570, Train loss: -0.5634, rewards: 0.3975\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 2580, Train loss: -0.5368, rewards: 0.3777\n",
      "Iteration: 2590, Train loss: -0.5097, rewards: 0.3720\n",
      "Iteration: 2600, Train loss: -0.5399, rewards: 0.3967\n",
      "Eval:\n",
      "Hits@1: 0.3701, Hits@3: 0.4466, Hits@10: 0.5020, MRR: 0.4155\n",
      "------------------------------------------------------------\n",
      "Iteration: 2610, Train loss: -0.5514, rewards: 0.4016\n",
      "Iteration: 2620, Train loss: -0.4761, rewards: 0.4088\n",
      "Iteration: 2630, Train loss: -0.5301, rewards: 0.3973\n",
      "Iteration: 2640, Train loss: -0.4549, rewards: 0.4280\n",
      "Iteration: 2650, Train loss: -0.5111, rewards: 0.4173\n",
      "Iteration: 2660, Train loss: -0.5181, rewards: 0.4075\n",
      "Iteration: 2670, Train loss: -0.4526, rewards: 0.4458\n",
      "Iteration: 2680, Train loss: -0.5204, rewards: 0.4114\n",
      "Iteration: 2690, Train loss: -0.5579, rewards: 0.4177\n",
      "Iteration: 2700, Train loss: -0.5275, rewards: 0.4291\n",
      "Eval:\n",
      "Hits@1: 0.3837, Hits@3: 0.4492, Hits@10: 0.5066, MRR: 0.4247\n",
      "------------------------------------------------------------\n",
      "Iteration: 2710, Train loss: -0.5072, rewards: 0.4134\n",
      "Iteration: 2720, Train loss: -0.5552, rewards: 0.4570\n",
      "Iteration: 2730, Train loss: -0.4993, rewards: 0.4400\n",
      "Iteration: 2740, Train loss: -0.4855, rewards: 0.4103\n",
      "Iteration: 2750, Train loss: -0.4725, rewards: 0.4119\n",
      "Iteration: 2760, Train loss: -0.4538, rewards: 0.3723\n",
      "Iteration: 2770, Train loss: -0.4702, rewards: 0.3997\n",
      "Iteration: 2780, Train loss: -0.4686, rewards: 0.3742\n",
      "Iteration: 2790, Train loss: -0.4992, rewards: 0.4023\n",
      "Iteration: 2800, Train loss: -0.5395, rewards: 0.4500\n",
      "Eval:\n",
      "Hits@1: 0.3619, Hits@3: 0.4384, Hits@10: 0.4941, MRR: 0.4081\n",
      "------------------------------------------------------------\n",
      "Iteration: 2810, Train loss: -0.5412, rewards: 0.3630\n",
      "Iteration: 2820, Train loss: -0.5528, rewards: 0.3827\n",
      "Iteration: 2830, Train loss: -0.5404, rewards: 0.3887\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "\n",
    "from model.ours import *\n",
    "\n",
    "# Hyper-parameter grid; beta and Lambda are tied to the same value `bl`.\n",
    "LSTM_LAYERS = [1, 2]\n",
    "BETA_LAMBDA = [0.05, 0.08, 0.02, 0.1, 0.12, 0.15]\n",
    "BATCH_SIZES = [32, 64, 128]\n",
    "LEARNING_RATES = [5e-4, 1e-4, 1e-3, 5e-5, 2e-3]\n",
    "# Re-running this cell after an interruption would otherwise overwrite the saved\n",
    "# results with a fresh dict. With RESUME, finished configurations are kept and skipped;\n",
    "# set it to False to start the sweep from scratch.\n",
    "RESUME = True\n",
    "\n",
    "# Every configuration writes to the same output_dir, so read it from the defaults once.\n",
    "results_path = set_params()['output_dir'] + 'results_table.pk5'\n",
    "results = {}\n",
    "if RESUME and os.path.exists(results_path):\n",
    "    with open(results_path, 'rb') as f:\n",
    "        results = pickle5.load(f)\n",
    "\n",
    "# Same visiting order as nesting layer -> bl -> bs -> lr loops.\n",
    "for layer, bl, bs, lr in itertools.product(LSTM_LAYERS, BETA_LAMBDA, BATCH_SIZES, LEARNING_RATES):\n",
    "    name = f'{layer}-{bs}-{bl}-{bl}-{lr}'\n",
    "    if name in results:\n",
    "        print(f'{name}: already in results, skipping')\n",
    "        continue\n",
    "    params = set_params(layer, bs, bl, bl, lr)\n",
    "\n",
    "    # All configurations share model_dir; remove the previous run's checkpoint so the\n",
    "    # evaluation below can never silently load another configuration's weights.\n",
    "    ckpt_path = params['model_dir'] + 'agent.ckpt'\n",
    "    if os.path.exists(ckpt_path):\n",
    "        os.remove(ckpt_path)\n",
    "\n",
    "    trainer = Trainer(params)\n",
    "    trainer.train()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    if os.path.exists(ckpt_path):\n",
    "        trainer.agent.load_state_dict(torch.load(ckpt_path))\n",
    "        trainer.agent.eval()\n",
    "        # Swap in the held-out environment (presumably the test split) for final numbers.\n",
    "        trainer.test_environment = trainer.test_test_environment\n",
    "        test_metrics = trainer.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "        print(name)\n",
    "        print(test_metrics)\n",
    "        print('-------------')\n",
    "        results[name] = test_metrics\n",
    "\n",
    "        # Persist after every configuration so an interrupted sweep loses at most one run.\n",
    "        with open(results_path, 'wb') as f:\n",
    "            pickle5.dump(results, f)\n",
    "    else:\n",
    "        print(f'{name}: training wrote no checkpoint to {ckpt_path}; result not recorded')\n",
    "\n",
    "    del trainer\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
