{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle5\n",
    "# Restrict the kernel to a single GPU device by exporting CUDA_VISIBLE_DEVICES.\n",
    "# The value is a hard-coded device UUID, so it only makes sense on the machine\n",
    "# it was written for -- NOTE(review): presumably a MIG slice, and presumably it\n",
    "# has to be set before any CUDA-using library is imported; confirm before reuse.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-6abedaa4-16cd-51b2-9b2f-043073ed897a\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0f115645",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    \"\"\"Build the options dict for one training run on WN18RR.\n",
    "\n",
    "    Args:\n",
    "        LSTM_layers: value stored under options['LSTM_layers'].\n",
    "        batch_size: value stored under options['batch_size']; the iteration\n",
    "            budget is scaled inversely with it.\n",
    "        beta: value stored under options['beta'].\n",
    "        Lambda: value stored under options['Lambda'].\n",
    "        learning_rate: value stored under options['learning_rate'].\n",
    "\n",
    "    Returns:\n",
    "        dict holding the path, agent and hyperparameter settings, plus the\n",
    "        relation and entity vocabularies parsed from the JSON files under\n",
    "        options['vocab_dir'].\n",
    "\n",
    "    Raises:\n",
    "        OSError: if a vocabulary file cannot be opened.\n",
    "        ZeroDivisionError: if batch_size is 0.\n",
    "    \"\"\"\n",
    "    # Local import: the notebook's import cell does not import `json`, so this\n",
    "    # keeps the function usable on a fresh kernel.\n",
    "    import json\n",
    "\n",
    "    options = {}\n",
    "\n",
    "    def load_vocab(file_name):\n",
    "        # `with` closes the file handle as soon as the JSON has been parsed.\n",
    "        with open(options['vocab_dir'] + file_name) as vocab_file:\n",
    "            return json.load(vocab_file)\n",
    "\n",
    "    # basic setting: device, dataset locations and output locations\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    options['relation_vocab'] = load_vocab('/relation_vocab.json')\n",
    "    options['entity_vocab'] = load_vocab('/entity_vocab.json')\n",
    "    options['model_dir'] = './outputs_WN18RR-tune2/'\n",
    "    options['output_dir'] = './outputs_WN18RR-tune2/'\n",
    "\n",
    "    # agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 1\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 40\n",
    "    options['gnn_layer'] = 1\n",
    "\n",
    "    # hyperparameters\n",
    "    options['test_rollouts'] = 40\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 32\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 100\n",
    "    # 2000 iterations at batch size 64, scaled inversely with the batch size.\n",
    "    # True division yields a float, so truncate to keep the count an integer.\n",
    "    options['total_iterations'] = int(2000*(64/batch_size))\n",
    "    options['pool'] = 'max'\n",
    "\n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7d0332ee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# from model.ours2 import *\n",
    "\n",
    "# results = {}\n",
    "# for layer in [1, 2]:\n",
    "#     for bl in [0.05, 0.08, 0.02, 0.1, 0.12, 0.15]:\n",
    "#         for bs in [32, 64, 128]:\n",
    "#             for lr in [5e-4, 1e-4, 1e-3, 5e-5, 2e-3]:\n",
    "#                 params = set_params(layer, bs, bl, bl, lr)\n",
    "#                 name = f'{layer}-{bs}-{bl}-{bl}-{lr}'\n",
    "\n",
    "#                 trainer = Trainer(params)\n",
    "#                 trainer.train()\n",
    "#                 torch.cuda.empty_cache()\n",
    "\n",
    "#                 trainer.agent.load_state_dict(torch.load(params['model_dir'] + 'agent.ckpt'))\n",
    "#                 trainer.agent.eval()\n",
    "#                 trainer.test_environment = trainer.test_test_environment\n",
    "#                 tmp = trainer.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "#                 print(name)\n",
    "#                 print(tmp)\n",
    "#                 print('-------------')\n",
    "#                 results[name] = tmp\n",
    "\n",
    "#                 with open(params['output_dir'] + 'results_table.pk5', 'wb') as f:\n",
    "#                     pickle5.dump(results, f)\n",
    "                    \n",
    "#                 del trainer\n",
    "#                 gc.collect()\n",
    "#                 torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b8716a8a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/Ours/model/ours2.py:334: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2077, rewards: 0.1522\n",
      "Iteration: 20, Train loss: -0.3059, rewards: 0.2311\n",
      "Iteration: 30, Train loss: -0.2813, rewards: 0.2542\n",
      "Iteration: 40, Train loss: -0.3489, rewards: 0.2687\n",
      "Iteration: 50, Train loss: -0.3541, rewards: 0.3033\n",
      "Iteration: 60, Train loss: -0.3636, rewards: 0.3250\n",
      "Iteration: 70, Train loss: -0.3497, rewards: 0.2961\n",
      "Iteration: 80, Train loss: -0.3760, rewards: 0.3663\n",
      "Iteration: 90, Train loss: -0.4077, rewards: 0.3611\n",
      "Iteration: 100, Train loss: -0.4282, rewards: 0.3906\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/Ours/model/ours2.py:636: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.3705, Hits@3: 0.4483, Hits@10: 0.4977, MRR: 0.4167\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.3429, rewards: 0.3552\n",
      "Iteration: 120, Train loss: -0.3236, rewards: 0.3048\n",
      "Iteration: 130, Train loss: -0.3497, rewards: 0.3870\n",
      "Iteration: 140, Train loss: -0.3888, rewards: 0.3955\n",
      "Iteration: 150, Train loss: -0.2997, rewards: 0.3048\n",
      "Iteration: 160, Train loss: -0.3037, rewards: 0.3683\n",
      "Iteration: 170, Train loss: -0.3282, rewards: 0.3975\n",
      "Iteration: 180, Train loss: -0.3246, rewards: 0.3516\n",
      "Iteration: 190, Train loss: -0.3446, rewards: 0.4217\n",
      "Iteration: 200, Train loss: -0.3519, rewards: 0.4161\n",
      "Eval:\n",
      "Hits@1: 0.3968, Hits@3: 0.4522, Hits@10: 0.4993, MRR: 0.4313\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.3477, rewards: 0.3969\n",
      "Iteration: 220, Train loss: -0.3160, rewards: 0.3761\n",
      "Iteration: 230, Train loss: -0.3248, rewards: 0.3948\n",
      "Iteration: 240, Train loss: -0.3747, rewards: 0.4600\n",
      "Iteration: 250, Train loss: -0.3472, rewards: 0.3477\n",
      "Iteration: 260, Train loss: -0.3761, rewards: 0.4064\n",
      "Iteration: 270, Train loss: -0.3802, rewards: 0.4098\n",
      "Iteration: 280, Train loss: -0.3655, rewards: 0.3809\n",
      "Iteration: 290, Train loss: -0.4101, rewards: 0.4080\n",
      "Iteration: 300, Train loss: -0.3901, rewards: 0.4102\n",
      "Eval:\n",
      "Hits@1: 0.3929, Hits@3: 0.4509, Hits@10: 0.5003, MRR: 0.4283\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.3571, rewards: 0.3548\n",
      "Iteration: 320, Train loss: -0.3580, rewards: 0.4062\n",
      "Iteration: 330, Train loss: -0.2801, rewards: 0.3750\n",
      "Iteration: 340, Train loss: -0.4047, rewards: 0.4181\n",
      "Iteration: 350, Train loss: -0.3884, rewards: 0.4291\n",
      "Iteration: 360, Train loss: -0.3326, rewards: 0.4402\n",
      "Iteration: 370, Train loss: -0.3263, rewards: 0.4170\n",
      "Iteration: 380, Train loss: -0.3530, rewards: 0.3903\n",
      "Iteration: 390, Train loss: -0.2971, rewards: 0.3866\n",
      "Iteration: 400, Train loss: -0.2866, rewards: 0.4198\n",
      "Eval:\n",
      "Hits@1: 0.3754, Hits@3: 0.4595, Hits@10: 0.5079, MRR: 0.4229\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.3467, rewards: 0.4173\n",
      "Iteration: 420, Train loss: -0.3444, rewards: 0.3823\n",
      "Iteration: 430, Train loss: -0.3540, rewards: 0.3847\n",
      "Iteration: 440, Train loss: -0.3147, rewards: 0.4139\n",
      "Iteration: 450, Train loss: -0.2978, rewards: 0.4048\n",
      "Iteration: 460, Train loss: -0.3231, rewards: 0.4367\n",
      "Iteration: 470, Train loss: -0.3488, rewards: 0.4105\n",
      "Iteration: 480, Train loss: -0.3269, rewards: 0.4391\n",
      "Iteration: 490, Train loss: -0.3845, rewards: 0.4120\n",
      "Iteration: 500, Train loss: -0.3882, rewards: 0.4225\n",
      "Eval:\n",
      "Hits@1: 0.4038, Hits@3: 0.4608, Hits@10: 0.5096, MRR: 0.4389\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.3572, rewards: 0.4306\n",
      "Iteration: 520, Train loss: -0.3778, rewards: 0.3981\n",
      "Iteration: 530, Train loss: -0.3139, rewards: 0.3834\n",
      "Iteration: 540, Train loss: -0.2619, rewards: 0.3811\n",
      "Iteration: 550, Train loss: -0.2571, rewards: 0.4172\n",
      "Iteration: 560, Train loss: -0.3261, rewards: 0.3833\n",
      "Iteration: 570, Train loss: -0.3241, rewards: 0.4152\n",
      "Iteration: 580, Train loss: -0.3904, rewards: 0.3995\n",
      "Iteration: 590, Train loss: -0.3507, rewards: 0.3967\n",
      "Iteration: 600, Train loss: -0.3335, rewards: 0.4092\n",
      "Eval:\n",
      "Hits@1: 0.4047, Hits@3: 0.4677, Hits@10: 0.5129, MRR: 0.4417\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.3709, rewards: 0.3902\n",
      "Iteration: 620, Train loss: -0.3501, rewards: 0.3678\n",
      "Iteration: 630, Train loss: -0.3058, rewards: 0.4317\n",
      "Iteration: 640, Train loss: -0.2822, rewards: 0.3812\n",
      "Iteration: 650, Train loss: -0.3650, rewards: 0.4447\n",
      "Iteration: 660, Train loss: -0.3565, rewards: 0.4163\n",
      "Iteration: 670, Train loss: -0.3422, rewards: 0.4192\n",
      "Iteration: 680, Train loss: -0.3771, rewards: 0.4078\n",
      "Iteration: 690, Train loss: -0.3624, rewards: 0.4002\n",
      "Iteration: 700, Train loss: -0.4024, rewards: 0.4123\n",
      "Eval:\n",
      "Hits@1: 0.4163, Hits@3: 0.4631, Hits@10: 0.5105, MRR: 0.4460\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.4065, rewards: 0.4170\n",
      "Iteration: 720, Train loss: -0.3504, rewards: 0.4319\n",
      "Iteration: 730, Train loss: -0.4022, rewards: 0.3825\n",
      "Iteration: 740, Train loss: -0.4215, rewards: 0.3561\n",
      "Iteration: 750, Train loss: -0.3895, rewards: 0.3986\n",
      "Iteration: 760, Train loss: -0.4259, rewards: 0.4081\n",
      "Iteration: 770, Train loss: -0.3908, rewards: 0.4414\n",
      "Iteration: 780, Train loss: -0.3849, rewards: 0.4253\n",
      "Iteration: 790, Train loss: -0.2960, rewards: 0.4061\n",
      "Iteration: 800, Train loss: -0.3422, rewards: 0.3975\n",
      "Eval:\n",
      "Hits@1: 0.4133, Hits@3: 0.4687, Hits@10: 0.5069, MRR: 0.4456\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.3309, rewards: 0.4303\n",
      "Iteration: 820, Train loss: -0.3204, rewards: 0.4189\n",
      "Iteration: 830, Train loss: -0.3600, rewards: 0.4536\n",
      "Iteration: 840, Train loss: -0.3797, rewards: 0.4439\n",
      "Iteration: 850, Train loss: -0.3379, rewards: 0.4408\n",
      "Iteration: 860, Train loss: -0.3866, rewards: 0.4178\n",
      "Iteration: 870, Train loss: -0.3929, rewards: 0.3872\n",
      "Iteration: 880, Train loss: -0.4332, rewards: 0.4437\n",
      "Iteration: 890, Train loss: -0.3942, rewards: 0.4408\n",
      "Iteration: 900, Train loss: -0.3754, rewards: 0.4566\n",
      "Eval:\n",
      "Hits@1: 0.3830, Hits@3: 0.4707, Hits@10: 0.5171, MRR: 0.4326\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.3757, rewards: 0.4097\n",
      "Iteration: 920, Train loss: -0.3669, rewards: 0.4253\n",
      "Iteration: 930, Train loss: -0.3600, rewards: 0.3666\n",
      "Iteration: 940, Train loss: -0.4030, rewards: 0.4223\n",
      "Iteration: 950, Train loss: -0.4004, rewards: 0.4075\n",
      "Iteration: 960, Train loss: -0.3677, rewards: 0.4434\n",
      "Iteration: 970, Train loss: -0.3537, rewards: 0.4169\n",
      "Iteration: 980, Train loss: -0.4353, rewards: 0.4586\n",
      "Iteration: 990, Train loss: -0.3932, rewards: 0.4617\n",
      "Iteration: 1000, Train loss: -0.3399, rewards: 0.4100\n",
      "Eval:\n",
      "Hits@1: 0.3942, Hits@3: 0.4628, Hits@10: 0.5079, MRR: 0.4345\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.3448, rewards: 0.4233\n",
      "Iteration: 1020, Train loss: -0.3437, rewards: 0.4475\n",
      "Iteration: 1030, Train loss: -0.3683, rewards: 0.4214\n",
      "Iteration: 1040, Train loss: -0.3040, rewards: 0.4203\n",
      "Iteration: 1050, Train loss: -0.3834, rewards: 0.4189\n",
      "Iteration: 1060, Train loss: -0.2992, rewards: 0.4259\n",
      "Iteration: 1070, Train loss: -0.3466, rewards: 0.3870\n",
      "Iteration: 1080, Train loss: -0.3487, rewards: 0.3862\n",
      "Iteration: 1090, Train loss: -0.3163, rewards: 0.4602\n",
      "Iteration: 1100, Train loss: -0.3290, rewards: 0.3966\n",
      "Eval:\n",
      "Hits@1: 0.4169, Hits@3: 0.4644, Hits@10: 0.5053, MRR: 0.4459\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.3460, rewards: 0.4303\n",
      "Iteration: 1120, Train loss: -0.3973, rewards: 0.3869\n",
      "Iteration: 1130, Train loss: -0.3722, rewards: 0.4261\n",
      "Iteration: 1140, Train loss: -0.3714, rewards: 0.4198\n",
      "Iteration: 1150, Train loss: -0.3813, rewards: 0.4080\n",
      "Iteration: 1160, Train loss: -0.3502, rewards: 0.4642\n",
      "Iteration: 1170, Train loss: -0.3651, rewards: 0.4025\n",
      "Iteration: 1180, Train loss: -0.3513, rewards: 0.4019\n",
      "Iteration: 1190, Train loss: -0.3343, rewards: 0.4377\n",
      "Iteration: 1200, Train loss: -0.3752, rewards: 0.4156\n",
      "Eval:\n",
      "Hits@1: 0.4021, Hits@3: 0.4743, Hits@10: 0.5142, MRR: 0.4418\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.3889, rewards: 0.4300\n",
      "Iteration: 1220, Train loss: -0.3355, rewards: 0.4228\n",
      "Iteration: 1230, Train loss: -0.3997, rewards: 0.4023\n",
      "Iteration: 1240, Train loss: -0.4404, rewards: 0.4336\n",
      "Iteration: 1250, Train loss: -0.3806, rewards: 0.4292\n",
      "Iteration: 1260, Train loss: -0.3793, rewards: 0.4544\n",
      "Iteration: 1270, Train loss: -0.3600, rewards: 0.4050\n",
      "Iteration: 1280, Train loss: -0.3569, rewards: 0.4148\n",
      "Iteration: 1290, Train loss: -0.3603, rewards: 0.4642\n",
      "Iteration: 1300, Train loss: -0.3632, rewards: 0.4194\n",
      "Eval:\n",
      "Hits@1: 0.4196, Hits@3: 0.4720, Hits@10: 0.5162, MRR: 0.4511\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.3756, rewards: 0.3967\n",
      "Iteration: 1320, Train loss: -0.3154, rewards: 0.4100\n",
      "Iteration: 1330, Train loss: -0.3879, rewards: 0.4088\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1340, Train loss: -0.3504, rewards: 0.4027\n",
      "Iteration: 1350, Train loss: -0.3774, rewards: 0.4439\n",
      "Iteration: 1360, Train loss: -0.4004, rewards: 0.3822\n",
      "Iteration: 1370, Train loss: -0.3833, rewards: 0.4658\n",
      "Iteration: 1380, Train loss: -0.3843, rewards: 0.3895\n",
      "Iteration: 1390, Train loss: -0.3847, rewards: 0.3995\n",
      "Iteration: 1400, Train loss: -0.3755, rewards: 0.3989\n",
      "Eval:\n",
      "Hits@1: 0.4189, Hits@3: 0.4664, Hits@10: 0.5082, MRR: 0.4490\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.3875, rewards: 0.3952\n",
      "Iteration: 1420, Train loss: -0.3863, rewards: 0.4066\n",
      "Iteration: 1430, Train loss: -0.4218, rewards: 0.4669\n",
      "Iteration: 1440, Train loss: -0.3926, rewards: 0.4002\n",
      "Iteration: 1450, Train loss: -0.4417, rewards: 0.4230\n",
      "Iteration: 1460, Train loss: -0.4389, rewards: 0.4267\n",
      "Iteration: 1470, Train loss: -0.4163, rewards: 0.4267\n",
      "Iteration: 1480, Train loss: -0.3865, rewards: 0.4411\n",
      "Iteration: 1490, Train loss: -0.3506, rewards: 0.4378\n",
      "Iteration: 1500, Train loss: -0.3843, rewards: 0.3897\n",
      "Eval:\n",
      "Hits@1: 0.4084, Hits@3: 0.4740, Hits@10: 0.5142, MRR: 0.4457\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.3387, rewards: 0.4736\n",
      "Iteration: 1520, Train loss: -0.3247, rewards: 0.4333\n",
      "Iteration: 1530, Train loss: -0.3994, rewards: 0.4327\n",
      "Iteration: 1540, Train loss: -0.3479, rewards: 0.4239\n",
      "Iteration: 1550, Train loss: -0.3282, rewards: 0.4606\n",
      "Iteration: 1560, Train loss: -0.3579, rewards: 0.4402\n",
      "Iteration: 1570, Train loss: -0.3548, rewards: 0.4148\n",
      "Iteration: 1580, Train loss: -0.3741, rewards: 0.4381\n",
      "Iteration: 1590, Train loss: -0.3719, rewards: 0.4841\n",
      "Iteration: 1600, Train loss: -0.3256, rewards: 0.4144\n",
      "Eval:\n",
      "Hits@1: 0.4219, Hits@3: 0.4733, Hits@10: 0.5168, MRR: 0.4533\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.2792, rewards: 0.4189\n",
      "Iteration: 1620, Train loss: -0.3316, rewards: 0.4288\n",
      "Iteration: 1630, Train loss: -0.3377, rewards: 0.4420\n",
      "Iteration: 1640, Train loss: -0.2801, rewards: 0.4238\n",
      "Iteration: 1650, Train loss: -0.3127, rewards: 0.4491\n",
      "Iteration: 1660, Train loss: -0.3463, rewards: 0.4020\n",
      "Iteration: 1670, Train loss: -0.3312, rewards: 0.4748\n",
      "Iteration: 1680, Train loss: -0.3915, rewards: 0.4231\n",
      "Iteration: 1690, Train loss: -0.3923, rewards: 0.4008\n",
      "Iteration: 1700, Train loss: -0.3432, rewards: 0.4078\n",
      "Eval:\n",
      "Hits@1: 0.4252, Hits@3: 0.4743, Hits@10: 0.5152, MRR: 0.4550\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.3387, rewards: 0.4288\n",
      "Iteration: 1720, Train loss: -0.3707, rewards: 0.5016\n",
      "Iteration: 1730, Train loss: -0.3826, rewards: 0.4297\n",
      "Iteration: 1740, Train loss: -0.3340, rewards: 0.4428\n",
      "Iteration: 1750, Train loss: -0.2915, rewards: 0.4128\n",
      "Iteration: 1760, Train loss: -0.3747, rewards: 0.4102\n",
      "Iteration: 1770, Train loss: -0.4024, rewards: 0.4228\n",
      "Iteration: 1780, Train loss: -0.3639, rewards: 0.3978\n",
      "Iteration: 1790, Train loss: -0.3698, rewards: 0.4614\n",
      "Iteration: 1800, Train loss: -0.3863, rewards: 0.4453\n",
      "Eval:\n",
      "Hits@1: 0.4225, Hits@3: 0.4697, Hits@10: 0.5135, MRR: 0.4519\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.3675, rewards: 0.3814\n",
      "Iteration: 1820, Train loss: -0.3805, rewards: 0.4336\n",
      "Iteration: 1830, Train loss: -0.3845, rewards: 0.3958\n",
      "Iteration: 1840, Train loss: -0.3871, rewards: 0.3833\n",
      "Iteration: 1850, Train loss: -0.3866, rewards: 0.3961\n",
      "Iteration: 1860, Train loss: -0.3889, rewards: 0.3995\n",
      "Iteration: 1870, Train loss: -0.4088, rewards: 0.4148\n",
      "Iteration: 1880, Train loss: -0.4229, rewards: 0.4055\n",
      "Iteration: 1890, Train loss: -0.3903, rewards: 0.4459\n",
      "Iteration: 1900, Train loss: -0.3496, rewards: 0.4487\n",
      "Eval:\n",
      "Hits@1: 0.4146, Hits@3: 0.4700, Hits@10: 0.5162, MRR: 0.4489\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.3381, rewards: 0.3920\n",
      "Iteration: 1920, Train loss: -0.4139, rewards: 0.4030\n",
      "Iteration: 1930, Train loss: -0.3745, rewards: 0.3752\n",
      "Iteration: 1940, Train loss: -0.4216, rewards: 0.4603\n",
      "Iteration: 1950, Train loss: -0.3884, rewards: 0.4464\n",
      "Iteration: 1960, Train loss: -0.4108, rewards: 0.4136\n",
      "Iteration: 1970, Train loss: -0.3353, rewards: 0.4647\n",
      "Iteration: 1980, Train loss: -0.3734, rewards: 0.4380\n",
      "Iteration: 1990, Train loss: -0.3651, rewards: 0.4572\n",
      "Iteration: 2000, Train loss: -0.4005, rewards: 0.4244\n",
      "Eval:\n",
      "Hits@1: 0.4113, Hits@3: 0.4717, Hits@10: 0.5181, MRR: 0.4472\n",
      "------------------------------------------------------------\n",
      "Iteration: 2010, Train loss: -0.4442, rewards: 0.4341\n",
      "Iteration: 2020, Train loss: -0.4088, rewards: 0.4805\n",
      "Iteration: 2030, Train loss: -0.4479, rewards: 0.4209\n",
      "Iteration: 2040, Train loss: -0.4420, rewards: 0.4361\n",
      "Iteration: 2050, Train loss: -0.3889, rewards: 0.4367\n",
      "Iteration: 2060, Train loss: -0.3967, rewards: 0.4333\n",
      "Iteration: 2070, Train loss: -0.3720, rewards: 0.4358\n",
      "Iteration: 2080, Train loss: -0.3619, rewards: 0.4530\n",
      "Iteration: 2090, Train loss: -0.3925, rewards: 0.4247\n",
      "Iteration: 2100, Train loss: -0.4415, rewards: 0.3973\n",
      "Eval:\n",
      "Hits@1: 0.4202, Hits@3: 0.4690, Hits@10: 0.5171, MRR: 0.4518\n",
      "------------------------------------------------------------\n",
      "Iteration: 2110, Train loss: -0.3515, rewards: 0.4353\n",
      "Iteration: 2120, Train loss: -0.3573, rewards: 0.4883\n",
      "Iteration: 2130, Train loss: -0.3993, rewards: 0.4166\n",
      "Iteration: 2140, Train loss: -0.3323, rewards: 0.4355\n",
      "Iteration: 2150, Train loss: -0.3671, rewards: 0.4056\n",
      "Iteration: 2160, Train loss: -0.4294, rewards: 0.4188\n",
      "Iteration: 2170, Train loss: -0.4115, rewards: 0.4338\n",
      "Iteration: 2180, Train loss: -0.3911, rewards: 0.4402\n",
      "Iteration: 2190, Train loss: -0.4535, rewards: 0.4489\n",
      "Iteration: 2200, Train loss: -0.4785, rewards: 0.4052\n",
      "Eval:\n",
      "Hits@1: 0.4127, Hits@3: 0.4750, Hits@10: 0.5201, MRR: 0.4488\n",
      "------------------------------------------------------------\n",
      "Iteration: 2210, Train loss: -0.4064, rewards: 0.4302\n",
      "Iteration: 2220, Train loss: -0.3399, rewards: 0.4330\n",
      "Iteration: 2230, Train loss: -0.4349, rewards: 0.4316\n",
      "Iteration: 2240, Train loss: -0.4040, rewards: 0.4334\n",
      "Iteration: 2250, Train loss: -0.3480, rewards: 0.4481\n",
      "Iteration: 2260, Train loss: -0.4074, rewards: 0.4333\n",
      "Iteration: 2270, Train loss: -0.3516, rewards: 0.4400\n",
      "Iteration: 2280, Train loss: -0.3766, rewards: 0.4498\n",
      "Iteration: 2290, Train loss: -0.3405, rewards: 0.4152\n",
      "Iteration: 2300, Train loss: -0.3274, rewards: 0.4033\n",
      "Eval:\n",
      "Hits@1: 0.4087, Hits@3: 0.4713, Hits@10: 0.5158, MRR: 0.4461\n",
      "------------------------------------------------------------\n",
      "Iteration: 2310, Train loss: -0.3342, rewards: 0.4061\n",
      "Iteration: 2320, Train loss: -0.2829, rewards: 0.4234\n",
      "Iteration: 2330, Train loss: -0.3262, rewards: 0.4378\n",
      "Iteration: 2340, Train loss: -0.3401, rewards: 0.3941\n",
      "Iteration: 2350, Train loss: -0.3243, rewards: 0.4402\n",
      "Iteration: 2360, Train loss: -0.3760, rewards: 0.4248\n",
      "Iteration: 2370, Train loss: -0.3572, rewards: 0.4584\n",
      "Iteration: 2380, Train loss: -0.3806, rewards: 0.4291\n",
      "Iteration: 2390, Train loss: -0.3600, rewards: 0.4020\n",
      "Iteration: 2400, Train loss: -0.3743, rewards: 0.4520\n",
      "Eval:\n",
      "Hits@1: 0.4258, Hits@3: 0.4723, Hits@10: 0.5162, MRR: 0.4554\n",
      "------------------------------------------------------------\n",
      "Iteration: 2410, Train loss: -0.3563, rewards: 0.4159\n",
      "Iteration: 2420, Train loss: -0.3603, rewards: 0.4669\n",
      "Iteration: 2430, Train loss: -0.3585, rewards: 0.4181\n",
      "Iteration: 2440, Train loss: -0.3340, rewards: 0.4406\n",
      "Iteration: 2450, Train loss: -0.3597, rewards: 0.4241\n",
      "Iteration: 2460, Train loss: -0.4232, rewards: 0.4459\n",
      "Iteration: 2470, Train loss: -0.3378, rewards: 0.4502\n",
      "Iteration: 2480, Train loss: -0.3143, rewards: 0.4344\n",
      "Iteration: 2490, Train loss: -0.3096, rewards: 0.3858\n",
      "Iteration: 2500, Train loss: -0.3981, rewards: 0.4314\n",
      "Eval:\n",
      "Hits@1: 0.4235, Hits@3: 0.4713, Hits@10: 0.5158, MRR: 0.4536\n",
      "------------------------------------------------------------\n",
      "Iteration: 2510, Train loss: -0.4142, rewards: 0.4116\n",
      "Iteration: 2520, Train loss: -0.3687, rewards: 0.4348\n",
      "Iteration: 2530, Train loss: -0.3236, rewards: 0.3983\n",
      "Iteration: 2540, Train loss: -0.3347, rewards: 0.4653\n",
      "Iteration: 2550, Train loss: -0.3120, rewards: 0.4861\n",
      "Iteration: 2560, Train loss: -0.3569, rewards: 0.4294\n",
      "Iteration: 2570, Train loss: -0.4020, rewards: 0.4172\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 2580, Train loss: -0.3216, rewards: 0.4537\n",
      "Iteration: 2590, Train loss: -0.4117, rewards: 0.3866\n",
      "Iteration: 2600, Train loss: -0.4125, rewards: 0.4019\n",
      "Eval:\n",
      "Hits@1: 0.4127, Hits@3: 0.4759, Hits@10: 0.5208, MRR: 0.4494\n",
      "------------------------------------------------------------\n",
      "Iteration: 2610, Train loss: -0.3898, rewards: 0.3967\n",
      "Iteration: 2620, Train loss: -0.4115, rewards: 0.4456\n",
      "Iteration: 2630, Train loss: -0.3739, rewards: 0.4022\n",
      "Iteration: 2640, Train loss: -0.3335, rewards: 0.4519\n",
      "Iteration: 2650, Train loss: -0.3507, rewards: 0.4656\n",
      "Iteration: 2660, Train loss: -0.4118, rewards: 0.4203\n",
      "Iteration: 2670, Train loss: -0.3992, rewards: 0.4552\n",
      "Iteration: 2680, Train loss: -0.3457, rewards: 0.4516\n",
      "Iteration: 2690, Train loss: -0.3407, rewards: 0.4827\n",
      "Iteration: 2700, Train loss: -0.3045, rewards: 0.3889\n",
      "Eval:\n",
      "Hits@1: 0.4268, Hits@3: 0.4740, Hits@10: 0.5188, MRR: 0.4570\n",
      "------------------------------------------------------------\n",
      "Iteration: 2710, Train loss: -0.2851, rewards: 0.4722\n",
      "Iteration: 2720, Train loss: -0.3481, rewards: 0.4078\n",
      "Iteration: 2730, Train loss: -0.3482, rewards: 0.3922\n",
      "Iteration: 2740, Train loss: -0.3793, rewards: 0.4645\n",
      "Iteration: 2750, Train loss: -0.3794, rewards: 0.4017\n",
      "Iteration: 2760, Train loss: -0.3830, rewards: 0.4441\n",
      "Iteration: 2770, Train loss: -0.3232, rewards: 0.4644\n",
      "Iteration: 2780, Train loss: -0.3160, rewards: 0.4291\n",
      "Iteration: 2790, Train loss: -0.3194, rewards: 0.4012\n",
      "Iteration: 2800, Train loss: -0.3298, rewards: 0.4472\n",
      "Eval:\n",
      "Hits@1: 0.4212, Hits@3: 0.4687, Hits@10: 0.5181, MRR: 0.4528\n",
      "------------------------------------------------------------\n",
      "Iteration: 2810, Train loss: -0.3723, rewards: 0.3961\n",
      "Iteration: 2820, Train loss: -0.3533, rewards: 0.4144\n",
      "Iteration: 2830, Train loss: -0.3794, rewards: 0.4117\n",
      "Iteration: 2840, Train loss: -0.3717, rewards: 0.4402\n",
      "Iteration: 2850, Train loss: -0.3971, rewards: 0.4342\n",
      "Iteration: 2860, Train loss: -0.3611, rewards: 0.4434\n",
      "Iteration: 2870, Train loss: -0.3338, rewards: 0.4553\n",
      "Iteration: 2880, Train loss: -0.2995, rewards: 0.4133\n",
      "Iteration: 2890, Train loss: -0.3485, rewards: 0.4197\n",
      "Iteration: 2900, Train loss: -0.2437, rewards: 0.4331\n",
      "Eval:\n",
      "Hits@1: 0.4305, Hits@3: 0.4763, Hits@10: 0.5191, MRR: 0.4592\n",
      "------------------------------------------------------------\n",
      "Iteration: 2910, Train loss: -0.3347, rewards: 0.4575\n",
      "Iteration: 2920, Train loss: -0.3163, rewards: 0.4195\n",
      "Iteration: 2930, Train loss: -0.3281, rewards: 0.4059\n",
      "Iteration: 2940, Train loss: -0.3308, rewards: 0.4070\n",
      "Iteration: 2950, Train loss: -0.3976, rewards: 0.4117\n",
      "Iteration: 2960, Train loss: -0.3557, rewards: 0.4825\n",
      "Iteration: 2970, Train loss: -0.3661, rewards: 0.4267\n",
      "Iteration: 2980, Train loss: -0.3888, rewards: 0.4348\n",
      "Iteration: 2990, Train loss: -0.4073, rewards: 0.4692\n",
      "Iteration: 3000, Train loss: -0.3384, rewards: 0.4170\n",
      "Eval:\n",
      "Hits@1: 0.4272, Hits@3: 0.4710, Hits@10: 0.5119, MRR: 0.4552\n",
      "------------------------------------------------------------\n",
      "Iteration: 3010, Train loss: -0.3482, rewards: 0.4144\n",
      "Iteration: 3020, Train loss: -0.3700, rewards: 0.4628\n",
      "Iteration: 3030, Train loss: -0.3284, rewards: 0.4591\n",
      "Iteration: 3040, Train loss: -0.3109, rewards: 0.4161\n",
      "Iteration: 3050, Train loss: -0.3223, rewards: 0.4280\n",
      "Iteration: 3060, Train loss: -0.3473, rewards: 0.4461\n",
      "Iteration: 3070, Train loss: -0.3518, rewards: 0.4098\n",
      "Iteration: 3080, Train loss: -0.3314, rewards: 0.4198\n",
      "Iteration: 3090, Train loss: -0.3444, rewards: 0.4539\n",
      "Iteration: 3100, Train loss: -0.3161, rewards: 0.4753\n",
      "Eval:\n",
      "Hits@1: 0.4262, Hits@3: 0.4773, Hits@10: 0.5188, MRR: 0.4573\n",
      "------------------------------------------------------------\n",
      "Iteration: 3110, Train loss: -0.4396, rewards: 0.4586\n",
      "Iteration: 3120, Train loss: -0.3855, rewards: 0.4302\n",
      "Iteration: 3130, Train loss: -0.3237, rewards: 0.4748\n",
      "Iteration: 3140, Train loss: -0.3756, rewards: 0.4686\n",
      "Iteration: 3150, Train loss: -0.2997, rewards: 0.4631\n",
      "Iteration: 3160, Train loss: -0.3334, rewards: 0.4562\n",
      "Iteration: 3170, Train loss: -0.3594, rewards: 0.4705\n",
      "Iteration: 3180, Train loss: -0.3731, rewards: 0.3891\n",
      "Iteration: 3190, Train loss: -0.3878, rewards: 0.4236\n",
      "Iteration: 3200, Train loss: -0.4018, rewards: 0.4347\n",
      "Eval:\n",
      "Hits@1: 0.4262, Hits@3: 0.4723, Hits@10: 0.5109, MRR: 0.4549\n",
      "------------------------------------------------------------\n",
      "Iteration: 3210, Train loss: -0.4020, rewards: 0.3906\n",
      "Iteration: 3220, Train loss: -0.3796, rewards: 0.3756\n",
      "Iteration: 3230, Train loss: -0.3514, rewards: 0.4027\n",
      "Iteration: 3240, Train loss: -0.3628, rewards: 0.4394\n",
      "Iteration: 3250, Train loss: -0.2731, rewards: 0.4911\n",
      "Iteration: 3260, Train loss: -0.3137, rewards: 0.4230\n",
      "Iteration: 3270, Train loss: -0.3161, rewards: 0.4767\n",
      "Iteration: 3280, Train loss: -0.3536, rewards: 0.4188\n",
      "Iteration: 3290, Train loss: -0.3866, rewards: 0.4511\n",
      "Iteration: 3300, Train loss: -0.3670, rewards: 0.4294\n",
      "Eval:\n",
      "Hits@1: 0.4288, Hits@3: 0.4769, Hits@10: 0.5175, MRR: 0.4583\n",
      "------------------------------------------------------------\n",
      "Iteration: 3310, Train loss: -0.2653, rewards: 0.4170\n",
      "Iteration: 3320, Train loss: -0.3176, rewards: 0.4602\n",
      "Iteration: 3330, Train loss: -0.3302, rewards: 0.4819\n",
      "Iteration: 3340, Train loss: -0.3073, rewards: 0.4467\n",
      "Iteration: 3350, Train loss: -0.2937, rewards: 0.4405\n",
      "Iteration: 3360, Train loss: -0.3076, rewards: 0.4402\n",
      "Iteration: 3370, Train loss: -0.2976, rewards: 0.4266\n",
      "Iteration: 3380, Train loss: -0.3675, rewards: 0.4325\n",
      "Iteration: 3390, Train loss: -0.3529, rewards: 0.3970\n",
      "Iteration: 3400, Train loss: -0.3839, rewards: 0.4245\n",
      "Eval:\n",
      "Hits@1: 0.4034, Hits@3: 0.4726, Hits@10: 0.5112, MRR: 0.4424\n",
      "------------------------------------------------------------\n",
      "Iteration: 3410, Train loss: -0.3375, rewards: 0.4297\n",
      "Iteration: 3420, Train loss: -0.3516, rewards: 0.4150\n",
      "Iteration: 3430, Train loss: -0.3520, rewards: 0.4452\n",
      "Iteration: 3440, Train loss: -0.2942, rewards: 0.4016\n",
      "Iteration: 3450, Train loss: -0.2810, rewards: 0.4703\n",
      "Iteration: 3460, Train loss: -0.2776, rewards: 0.4516\n",
      "Iteration: 3470, Train loss: -0.3320, rewards: 0.4623\n",
      "Iteration: 3480, Train loss: -0.2596, rewards: 0.4309\n",
      "Iteration: 3490, Train loss: -0.2885, rewards: 0.4669\n",
      "Iteration: 3500, Train loss: -0.3397, rewards: 0.4467\n",
      "Eval:\n",
      "Hits@1: 0.4242, Hits@3: 0.4769, Hits@10: 0.5214, MRR: 0.4572\n",
      "------------------------------------------------------------\n",
      "Iteration: 3510, Train loss: -0.3846, rewards: 0.4339\n",
      "Iteration: 3520, Train loss: -0.3130, rewards: 0.3903\n",
      "Iteration: 3530, Train loss: -0.3533, rewards: 0.4645\n",
      "Iteration: 3540, Train loss: -0.3773, rewards: 0.4798\n",
      "Iteration: 3550, Train loss: -0.3526, rewards: 0.4650\n",
      "Iteration: 3560, Train loss: -0.3430, rewards: 0.5156\n",
      "Iteration: 3570, Train loss: -0.3132, rewards: 0.4223\n",
      "Iteration: 3580, Train loss: -0.3154, rewards: 0.4430\n",
      "Iteration: 3590, Train loss: -0.3915, rewards: 0.4300\n",
      "Iteration: 3600, Train loss: -0.3980, rewards: 0.4502\n",
      "Eval:\n",
      "Hits@1: 0.4249, Hits@3: 0.4802, Hits@10: 0.5171, MRR: 0.4574\n",
      "------------------------------------------------------------\n",
      "Iteration: 3610, Train loss: -0.3345, rewards: 0.3633\n",
      "Iteration: 3620, Train loss: -0.2953, rewards: 0.4817\n",
      "Iteration: 3630, Train loss: -0.3232, rewards: 0.3833\n",
      "Iteration: 3640, Train loss: -0.3611, rewards: 0.4059\n",
      "Iteration: 3650, Train loss: -0.3596, rewards: 0.4678\n",
      "Iteration: 3660, Train loss: -0.3484, rewards: 0.4227\n",
      "Iteration: 3670, Train loss: -0.3476, rewards: 0.4291\n",
      "Iteration: 3680, Train loss: -0.3204, rewards: 0.4620\n",
      "Iteration: 3690, Train loss: -0.3112, rewards: 0.4683\n",
      "Iteration: 3700, Train loss: -0.3483, rewards: 0.4459\n",
      "Eval:\n",
      "Hits@1: 0.4245, Hits@3: 0.4815, Hits@10: 0.5234, MRR: 0.4582\n",
      "------------------------------------------------------------\n",
      "Iteration: 3710, Train loss: -0.3748, rewards: 0.4772\n",
      "Iteration: 3720, Train loss: -0.3768, rewards: 0.4420\n",
      "Iteration: 3730, Train loss: -0.3863, rewards: 0.4506\n",
      "Iteration: 3740, Train loss: -0.3791, rewards: 0.4705\n",
      "Iteration: 3750, Train loss: -0.3702, rewards: 0.4302\n",
      "Iteration: 3760, Train loss: -0.3479, rewards: 0.4697\n",
      "Iteration: 3770, Train loss: -0.3336, rewards: 0.4647\n",
      "Iteration: 3780, Train loss: -0.3440, rewards: 0.4459\n",
      "Iteration: 3790, Train loss: -0.2898, rewards: 0.4519\n",
      "Iteration: 3800, Train loss: -0.2997, rewards: 0.4311\n",
      "Eval:\n",
      "Hits@1: 0.4262, Hits@3: 0.4806, Hits@10: 0.5231, MRR: 0.4581\n",
      "------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 3810, Train loss: -0.3157, rewards: 0.4527\n",
      "Iteration: 3820, Train loss: -0.2862, rewards: 0.4084\n",
      "Iteration: 3830, Train loss: -0.2955, rewards: 0.4602\n",
      "Iteration: 3840, Train loss: -0.3172, rewards: 0.4758\n",
      "Iteration: 3850, Train loss: -0.3194, rewards: 0.4903\n",
      "Iteration: 3860, Train loss: -0.3077, rewards: 0.4630\n",
      "Iteration: 3870, Train loss: -0.3230, rewards: 0.4605\n",
      "Iteration: 3880, Train loss: -0.3530, rewards: 0.4419\n",
      "Iteration: 3890, Train loss: -0.2981, rewards: 0.4119\n",
      "Iteration: 3900, Train loss: -0.3337, rewards: 0.4622\n",
      "Eval:\n",
      "Hits@1: 0.4291, Hits@3: 0.4832, Hits@10: 0.5260, MRR: 0.4619\n",
      "------------------------------------------------------------\n",
      "Iteration: 3910, Train loss: -0.3590, rewards: 0.4263\n",
      "Iteration: 3920, Train loss: -0.3312, rewards: 0.4634\n",
      "Iteration: 3930, Train loss: -0.3386, rewards: 0.4372\n",
      "Iteration: 3940, Train loss: -0.2995, rewards: 0.4458\n",
      "Iteration: 3950, Train loss: -0.3578, rewards: 0.4520\n",
      "Iteration: 3960, Train loss: -0.3094, rewards: 0.4650\n",
      "Iteration: 3970, Train loss: -0.2592, rewards: 0.4927\n",
      "Iteration: 3980, Train loss: -0.3076, rewards: 0.4736\n",
      "Iteration: 3990, Train loss: -0.3044, rewards: 0.4523\n",
      "Iteration: 4000, Train loss: -0.3022, rewards: 0.4484\n",
      "Eval:\n",
      "Hits@1: 0.3955, Hits@3: 0.4759, Hits@10: 0.5247, MRR: 0.4424\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "from model.ours2 import *\n",
    "params = set_params(2, 32, 0.05, 0.05, 0.0001)\n",
    "trainer = Trainer(params)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "897867ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.4333, Hits@3: 0.4844, Hits@10: 0.5281, MRR: 0.4653\n"
     ]
    }
   ],
   "source": [
    "trainer.agent.eval()\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "test_results = trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "833366c6",
   "metadata": {},
   "source": [
    "# inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e09a25f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "with open('./outputs_WN18RR-tune2/results_table.pk5', 'rb') as f:\n",
    "    results = pickle5.load(f)\n",
    "check = pd.DataFrame([[k, float(v.split(': ')[-1])] for k,v in results.items()], columns = ['config', 'mrr'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "29ad0aea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>config</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>1-64-0.08-0.08-0.0001</td>\n",
       "      <td>0.4659</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1-32-0.05-0.05-5e-05</td>\n",
       "      <td>0.4660</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>111</th>\n",
       "      <td>2-64-0.08-0.08-0.0001</td>\n",
       "      <td>0.4673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>1-32-0.08-0.08-0.0001</td>\n",
       "      <td>0.4678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>101</th>\n",
       "      <td>2-128-0.05-0.05-0.0001</td>\n",
       "      <td>0.4678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1-128-0.05-0.05-0.0005</td>\n",
       "      <td>0.4679</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>1-128-0.08-0.08-0.0001</td>\n",
       "      <td>0.4693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1-32-0.05-0.05-0.0001</td>\n",
       "      <td>0.4694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>91</th>\n",
       "      <td>2-32-0.05-0.05-0.0001</td>\n",
       "      <td>0.4695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>2-64-0.05-0.05-5e-05</td>\n",
       "      <td>0.4711</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     config     mrr\n",
       "21    1-64-0.08-0.08-0.0001  0.4659\n",
       "3      1-32-0.05-0.05-5e-05  0.4660\n",
       "111   2-64-0.08-0.08-0.0001  0.4673\n",
       "16    1-32-0.08-0.08-0.0001  0.4678\n",
       "101  2-128-0.05-0.05-0.0001  0.4678\n",
       "10   1-128-0.05-0.05-0.0005  0.4679\n",
       "26   1-128-0.08-0.08-0.0001  0.4693\n",
       "1     1-32-0.05-0.05-0.0001  0.4694\n",
       "91    2-32-0.05-0.05-0.0001  0.4695\n",
       "98     2-64-0.05-0.05-5e-05  0.4711"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "check.sort_values('mrr').tail(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b8c9442f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    options = {}\n",
    "\n",
    "    #basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "    options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "    options['model_dir'] = './outputs_WN18RR-tune2/'\n",
    "    options['output_dir'] = './outputs_WN18RR-tune2/'\n",
    "\n",
    "    #agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 1\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 100\n",
    "    options['gnn_layer'] = 1\n",
    "\n",
    "    #hyperparameters\n",
    "    options['test_rollouts'] = 100\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 8\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 100\n",
    "    options['total_iterations'] = 2000*(64/batch_size)\n",
    "    options['pool'] = 'max'\n",
    "    \n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bceb168",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/Ours/model/ours2.py:334: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2029, rewards: 0.1583\n",
      "Iteration: 20, Train loss: -0.2739, rewards: 0.2741\n",
      "Iteration: 30, Train loss: -0.3290, rewards: 0.2641\n",
      "Iteration: 40, Train loss: -0.3506, rewards: 0.2992\n"
     ]
    }
   ],
   "source": [
    "from model.ours2 import *\n",
    "params = set_params(2, 32, 0.05, 0.05, 0.0001)\n",
    "trainer = Trainer(params)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aca32b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.agent.eval()\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "test_results = trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13881e64",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
