{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import json\n",
    "import os\n",
    "\n",
    "import pickle5\n",
    "\n",
    "# `json` and `gc` are used directly by later cells, so import them here rather\n",
    "# than relying on the `from model.ours import *` cell to leak them in.\n",
    "# Pin this kernel to one MIG slice of the GPU. This must run before torch\n",
    "# initialises CUDA (torch only arrives via the star-import cell further down).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-6abedaa4-16cd-51b2-9b2f-043073ed897a\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    options = {}\n",
    "\n",
    "    #basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/FB15K-237/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/FB15K-237/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "    options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "    options['model_dir'] = './outputs_FB15K-237_v7-tune2/'\n",
    "    options['output_dir'] = './outputs_FB15K-237_v7-tune2/'\n",
    "\n",
    "    #agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 0\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 100\n",
    "\n",
    "    #hyperparameters\n",
    "    options['test_rollouts'] = 100\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 12\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 200\n",
    "    options['total_iterations'] = 2000*(64/batch_size)\n",
    "    options['pool'] = 'max'\n",
    "    \n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc64eafe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/Ours/model/ours.py:333: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 20, Train loss: -0.4368, rewards: 0.0161\n",
      "Iteration: 40, Train loss: -0.4059, rewards: 0.0114\n",
      "Iteration: 60, Train loss: -0.4102, rewards: 0.0117\n",
      "Iteration: 80, Train loss: -0.4531, rewards: 0.0208\n",
      "Iteration: 100, Train loss: -0.4281, rewards: 0.0136\n",
      "Iteration: 120, Train loss: -0.4309, rewards: 0.0120\n",
      "Iteration: 140, Train loss: -0.4480, rewards: 0.0172\n",
      "Iteration: 160, Train loss: -0.4443, rewards: 0.0167\n",
      "Iteration: 180, Train loss: -0.4268, rewards: 0.0173\n",
      "Iteration: 200, Train loss: -0.4514, rewards: 0.0178\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/Ours/model/ours.py:635: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.0063, Hits@3: 0.0203, Hits@10: 0.0625, MRR: 0.0249\n",
      "------------------------------------------------------------\n",
      "Iteration: 220, Train loss: -0.4569, rewards: 0.0205\n",
      "Iteration: 240, Train loss: -0.4422, rewards: 0.0187\n",
      "Iteration: 260, Train loss: -0.4333, rewards: 0.0166\n",
      "Iteration: 280, Train loss: -0.4228, rewards: 0.0170\n",
      "Iteration: 300, Train loss: -0.4247, rewards: 0.0130\n",
      "Iteration: 320, Train loss: -0.4365, rewards: 0.0145\n",
      "Iteration: 340, Train loss: -0.4463, rewards: 0.0178\n",
      "Iteration: 360, Train loss: -0.4512, rewards: 0.0181\n",
      "Iteration: 380, Train loss: -0.4427, rewards: 0.0170\n",
      "Iteration: 400, Train loss: -0.4493, rewards: 0.0220\n",
      "Eval:\n",
      "Hits@1: 0.0068, Hits@3: 0.0245, Hits@10: 0.0706, MRR: 0.0277\n",
      "------------------------------------------------------------\n",
      "Iteration: 420, Train loss: -0.4499, rewards: 0.0200\n",
      "Iteration: 440, Train loss: -0.4487, rewards: 0.0187\n",
      "Iteration: 460, Train loss: -0.4758, rewards: 0.0248\n",
      "Iteration: 480, Train loss: -0.4513, rewards: 0.0214\n",
      "Iteration: 500, Train loss: -0.4193, rewards: 0.0166\n",
      "Iteration: 520, Train loss: -0.4419, rewards: 0.0208\n",
      "Iteration: 540, Train loss: -0.4953, rewards: 0.0266\n",
      "Iteration: 560, Train loss: -0.4388, rewards: 0.0184\n",
      "Iteration: 580, Train loss: -0.4527, rewards: 0.0203\n",
      "Iteration: 600, Train loss: -0.4592, rewards: 0.0189\n",
      "Eval:\n",
      "Hits@1: 0.0092, Hits@3: 0.0277, Hits@10: 0.0818, MRR: 0.0319\n",
      "------------------------------------------------------------\n",
      "Iteration: 620, Train loss: -0.4658, rewards: 0.0242\n",
      "Iteration: 640, Train loss: -0.4565, rewards: 0.0244\n",
      "Iteration: 660, Train loss: -0.4561, rewards: 0.0227\n",
      "Iteration: 680, Train loss: -0.4898, rewards: 0.0258\n",
      "Iteration: 700, Train loss: -0.4683, rewards: 0.0231\n",
      "Iteration: 720, Train loss: -0.4438, rewards: 0.0183\n",
      "Iteration: 740, Train loss: -0.4384, rewards: 0.0170\n",
      "Iteration: 760, Train loss: -0.4988, rewards: 0.0253\n",
      "Iteration: 780, Train loss: -0.4384, rewards: 0.0167\n",
      "Iteration: 800, Train loss: -0.4569, rewards: 0.0200\n",
      "Eval:\n",
      "Hits@1: 0.0105, Hits@3: 0.0345, Hits@10: 0.0935, MRR: 0.0379\n",
      "------------------------------------------------------------\n",
      "Iteration: 820, Train loss: -0.4442, rewards: 0.0145\n",
      "Iteration: 840, Train loss: -0.4394, rewards: 0.0186\n",
      "Iteration: 860, Train loss: -0.4594, rewards: 0.0238\n",
      "Iteration: 880, Train loss: -0.4688, rewards: 0.0209\n",
      "Iteration: 900, Train loss: -0.4724, rewards: 0.0283\n",
      "Iteration: 920, Train loss: -0.4596, rewards: 0.0272\n",
      "Iteration: 940, Train loss: -0.4729, rewards: 0.0267\n",
      "Iteration: 960, Train loss: -0.4621, rewards: 0.0206\n",
      "Iteration: 980, Train loss: -0.4596, rewards: 0.0211\n",
      "Iteration: 1000, Train loss: -0.4763, rewards: 0.0316\n",
      "Eval:\n",
      "Hits@1: 0.0183, Hits@3: 0.0450, Hits@10: 0.1185, MRR: 0.0483\n",
      "------------------------------------------------------------\n",
      "Iteration: 1020, Train loss: -0.4601, rewards: 0.0217\n",
      "Iteration: 1040, Train loss: -0.4928, rewards: 0.0338\n",
      "Iteration: 1060, Train loss: -0.4763, rewards: 0.0234\n",
      "Iteration: 1080, Train loss: -0.5094, rewards: 0.0297\n",
      "Iteration: 1100, Train loss: -0.4683, rewards: 0.0258\n",
      "Iteration: 1120, Train loss: -0.4430, rewards: 0.0194\n",
      "Iteration: 1140, Train loss: -0.4920, rewards: 0.0275\n",
      "Iteration: 1160, Train loss: -0.4501, rewards: 0.0170\n",
      "Iteration: 1180, Train loss: -0.4742, rewards: 0.0206\n",
      "Iteration: 1200, Train loss: -0.4516, rewards: 0.0252\n",
      "Eval:\n",
      "Hits@1: 0.0274, Hits@3: 0.0576, Hits@10: 0.1394, MRR: 0.0615\n",
      "------------------------------------------------------------\n",
      "Iteration: 1220, Train loss: -0.5060, rewards: 0.0323\n",
      "Iteration: 1240, Train loss: -0.4821, rewards: 0.0234\n",
      "Iteration: 1260, Train loss: -0.4618, rewards: 0.0223\n",
      "Iteration: 1280, Train loss: -0.4882, rewards: 0.0347\n",
      "Iteration: 1300, Train loss: -0.4816, rewards: 0.0233\n",
      "Iteration: 1320, Train loss: -0.5158, rewards: 0.0341\n",
      "Iteration: 1340, Train loss: -0.4605, rewards: 0.0258\n",
      "Iteration: 1360, Train loss: -0.4927, rewards: 0.0325\n",
      "Iteration: 1380, Train loss: -0.4517, rewards: 0.0192\n",
      "Iteration: 1400, Train loss: -0.4957, rewards: 0.0306\n",
      "Eval:\n",
      "Hits@1: 0.0329, Hits@3: 0.0668, Hits@10: 0.1551, MRR: 0.0694\n",
      "------------------------------------------------------------\n",
      "Iteration: 1420, Train loss: -0.4935, rewards: 0.0344\n",
      "Iteration: 1440, Train loss: -0.4716, rewards: 0.0275\n",
      "Iteration: 1460, Train loss: -0.4992, rewards: 0.0278\n",
      "Iteration: 1480, Train loss: -0.5158, rewards: 0.0416\n",
      "Iteration: 1500, Train loss: -0.4702, rewards: 0.0300\n",
      "Iteration: 1520, Train loss: -0.5294, rewards: 0.0484\n",
      "Iteration: 1540, Train loss: -0.5016, rewards: 0.0302\n",
      "Iteration: 1560, Train loss: -0.4847, rewards: 0.0312\n",
      "Iteration: 1580, Train loss: -0.4920, rewards: 0.0342\n",
      "Iteration: 1600, Train loss: -0.4816, rewards: 0.0306\n",
      "Eval:\n",
      "Hits@1: 0.0554, Hits@3: 0.0976, Hits@10: 0.1941, MRR: 0.0965\n",
      "------------------------------------------------------------\n",
      "Iteration: 1620, Train loss: -0.4914, rewards: 0.0305\n",
      "Iteration: 1640, Train loss: -0.4900, rewards: 0.0262\n",
      "Iteration: 1660, Train loss: -0.5178, rewards: 0.0322\n",
      "Iteration: 1680, Train loss: -0.5098, rewards: 0.0397\n",
      "Iteration: 1700, Train loss: -0.4912, rewards: 0.0391\n",
      "Iteration: 1720, Train loss: -0.4936, rewards: 0.0398\n",
      "Iteration: 1740, Train loss: -0.5118, rewards: 0.0355\n",
      "Iteration: 1760, Train loss: -0.5543, rewards: 0.0492\n",
      "Iteration: 1780, Train loss: -0.5418, rewards: 0.0433\n",
      "Iteration: 1800, Train loss: -0.5139, rewards: 0.0355\n",
      "Eval:\n",
      "Hits@1: 0.0743, Hits@3: 0.1299, Hits@10: 0.2334, MRR: 0.1222\n",
      "------------------------------------------------------------\n",
      "Iteration: 1820, Train loss: -0.5148, rewards: 0.0336\n",
      "Iteration: 1840, Train loss: -0.5234, rewards: 0.0403\n",
      "Iteration: 1860, Train loss: -0.5368, rewards: 0.0414\n",
      "Iteration: 1880, Train loss: -0.5106, rewards: 0.0350\n",
      "Iteration: 1900, Train loss: -0.5199, rewards: 0.0431\n",
      "Iteration: 1920, Train loss: -0.5684, rewards: 0.0502\n",
      "Iteration: 1940, Train loss: -0.5467, rewards: 0.0466\n",
      "Iteration: 1960, Train loss: -0.5378, rewards: 0.0578\n",
      "Iteration: 1980, Train loss: -0.5301, rewards: 0.0503\n",
      "Iteration: 2000, Train loss: -0.5139, rewards: 0.0483\n",
      "Eval:\n",
      "Hits@1: 0.0904, Hits@3: 0.1518, Hits@10: 0.2593, MRR: 0.1420\n",
      "------------------------------------------------------------\n",
      "Iteration: 2020, Train loss: -0.5783, rewards: 0.0647\n",
      "Iteration: 2040, Train loss: -0.5481, rewards: 0.0511\n",
      "Iteration: 2060, Train loss: -0.5262, rewards: 0.0452\n",
      "Iteration: 2080, Train loss: -0.5700, rewards: 0.0627\n",
      "Iteration: 2100, Train loss: -0.5273, rewards: 0.0494\n",
      "Iteration: 2120, Train loss: -0.5188, rewards: 0.0422\n",
      "Iteration: 2140, Train loss: -0.5695, rewards: 0.0547\n",
      "Iteration: 2160, Train loss: -0.5462, rewards: 0.0556\n",
      "Iteration: 2180, Train loss: -0.5482, rewards: 0.0514\n",
      "Iteration: 2200, Train loss: -0.5757, rewards: 0.0516\n",
      "Eval:\n",
      "Hits@1: 0.0954, Hits@3: 0.1632, Hits@10: 0.2616, MRR: 0.1482\n",
      "------------------------------------------------------------\n",
      "Iteration: 2220, Train loss: -0.5436, rewards: 0.0552\n",
      "Iteration: 2240, Train loss: -0.5924, rewards: 0.0564\n",
      "Iteration: 2260, Train loss: -0.5324, rewards: 0.0456\n",
      "Iteration: 2280, Train loss: -0.5317, rewards: 0.0486\n",
      "Iteration: 2300, Train loss: -0.5734, rewards: 0.0667\n",
      "Iteration: 2320, Train loss: -0.5855, rewards: 0.0631\n",
      "Iteration: 2340, Train loss: -0.5737, rewards: 0.0639\n",
      "Iteration: 2360, Train loss: -0.5358, rewards: 0.0450\n",
      "Iteration: 2380, Train loss: -0.5626, rewards: 0.0570\n",
      "Iteration: 2400, Train loss: -0.5821, rewards: 0.0773\n",
      "Eval:\n",
      "Hits@1: 0.1081, Hits@3: 0.1905, Hits@10: 0.2934, MRR: 0.1675\n",
      "------------------------------------------------------------\n",
      "Iteration: 2420, Train loss: -0.5546, rewards: 0.0537\n",
      "Iteration: 2440, Train loss: -0.6331, rewards: 0.0672\n",
      "Iteration: 2460, Train loss: -0.6004, rewards: 0.0769\n",
      "Iteration: 2480, Train loss: -0.6356, rewards: 0.0881\n",
      "Iteration: 2500, Train loss: -0.5545, rewards: 0.0572\n",
      "Iteration: 2520, Train loss: -0.5624, rewards: 0.0645\n",
      "Iteration: 2540, Train loss: -0.5853, rewards: 0.0648\n",
      "Iteration: 2560, Train loss: -0.6478, rewards: 0.0814\n",
      "Iteration: 2580, Train loss: -0.6128, rewards: 0.0714\n",
      "Iteration: 2600, Train loss: -0.5540, rewards: 0.0602\n",
      "Eval:\n",
      "Hits@1: 0.1294, Hits@3: 0.2011, Hits@10: 0.3030, MRR: 0.1836\n",
      "------------------------------------------------------------\n",
      "Iteration: 2620, Train loss: -0.5883, rewards: 0.0705\n",
      "Iteration: 2640, Train loss: -0.5737, rewards: 0.0650\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 2660, Train loss: -0.5657, rewards: 0.0552\n",
      "Iteration: 2680, Train loss: -0.5801, rewards: 0.0645\n",
      "Iteration: 2700, Train loss: -0.5800, rewards: 0.0627\n",
      "Iteration: 2720, Train loss: -0.5901, rewards: 0.0664\n",
      "Iteration: 2740, Train loss: -0.5685, rewards: 0.0611\n",
      "Iteration: 2760, Train loss: -0.6148, rewards: 0.0730\n",
      "Iteration: 2780, Train loss: -0.5928, rewards: 0.0733\n",
      "Iteration: 2800, Train loss: -0.5841, rewards: 0.0644\n",
      "Eval:\n",
      "Hits@1: 0.1376, Hits@3: 0.2096, Hits@10: 0.3117, MRR: 0.1919\n",
      "------------------------------------------------------------\n",
      "Iteration: 2820, Train loss: -0.5492, rewards: 0.0486\n",
      "Iteration: 2840, Train loss: -0.6493, rewards: 0.0714\n",
      "Iteration: 2860, Train loss: -0.5839, rewards: 0.0628\n",
      "Iteration: 2880, Train loss: -0.5765, rewards: 0.0758\n",
      "Iteration: 2900, Train loss: -0.5934, rewards: 0.0656\n",
      "Iteration: 2920, Train loss: -0.6306, rewards: 0.0872\n",
      "Iteration: 2940, Train loss: -0.5759, rewards: 0.0733\n",
      "Iteration: 2960, Train loss: -0.5891, rewards: 0.0750\n",
      "Iteration: 2980, Train loss: -0.6170, rewards: 0.0664\n",
      "Iteration: 3000, Train loss: -0.6283, rewards: 0.0808\n",
      "Eval:\n",
      "Hits@1: 0.1473, Hits@3: 0.2258, Hits@10: 0.3252, MRR: 0.2045\n",
      "------------------------------------------------------------\n",
      "Iteration: 3020, Train loss: -0.5915, rewards: 0.0759\n",
      "Iteration: 3040, Train loss: -0.5881, rewards: 0.0688\n",
      "Iteration: 3060, Train loss: -0.6120, rewards: 0.0878\n",
      "Iteration: 3080, Train loss: -0.6156, rewards: 0.0777\n",
      "Iteration: 3100, Train loss: -0.5936, rewards: 0.0747\n",
      "Iteration: 3120, Train loss: -0.6203, rewards: 0.0881\n",
      "Iteration: 3140, Train loss: -0.5734, rewards: 0.0678\n",
      "Iteration: 3160, Train loss: -0.5837, rewards: 0.0653\n",
      "Iteration: 3180, Train loss: -0.5481, rewards: 0.0717\n",
      "Iteration: 3200, Train loss: -0.5918, rewards: 0.0658\n",
      "Eval:\n",
      "Hits@1: 0.1514, Hits@3: 0.2303, Hits@10: 0.3348, MRR: 0.2091\n",
      "------------------------------------------------------------\n",
      "Iteration: 3220, Train loss: -0.5910, rewards: 0.0825\n",
      "Iteration: 3240, Train loss: -0.5664, rewards: 0.0714\n",
      "Iteration: 3260, Train loss: -0.6159, rewards: 0.0784\n",
      "Iteration: 3280, Train loss: -0.6262, rewards: 0.0827\n",
      "Iteration: 3300, Train loss: -0.6083, rewards: 0.0730\n",
      "Iteration: 3320, Train loss: -0.6070, rewards: 0.0892\n",
      "Iteration: 3340, Train loss: -0.6501, rewards: 0.0942\n",
      "Iteration: 3360, Train loss: -0.6478, rewards: 0.1034\n",
      "Iteration: 3380, Train loss: -0.6297, rewards: 0.0862\n",
      "Iteration: 3400, Train loss: -0.5841, rewards: 0.0814\n",
      "Eval:\n",
      "Hits@1: 0.1620, Hits@3: 0.2410, Hits@10: 0.3410, MRR: 0.2192\n",
      "------------------------------------------------------------\n",
      "Iteration: 3420, Train loss: -0.6764, rewards: 0.1033\n",
      "Iteration: 3440, Train loss: -0.6048, rewards: 0.0889\n",
      "Iteration: 3460, Train loss: -0.6224, rewards: 0.0784\n",
      "Iteration: 3480, Train loss: -0.6424, rewards: 0.0912\n",
      "Iteration: 3500, Train loss: -0.6132, rewards: 0.0795\n",
      "Iteration: 3520, Train loss: -0.6396, rewards: 0.0941\n",
      "Iteration: 3540, Train loss: -0.6595, rewards: 0.0902\n",
      "Iteration: 3560, Train loss: -0.6040, rewards: 0.0772\n",
      "Iteration: 3580, Train loss: -0.6386, rewards: 0.0880\n",
      "Iteration: 3600, Train loss: -0.6108, rewards: 0.0791\n",
      "Eval:\n",
      "Hits@1: 0.1644, Hits@3: 0.2439, Hits@10: 0.3425, MRR: 0.2217\n",
      "------------------------------------------------------------\n",
      "Iteration: 3620, Train loss: -0.5904, rewards: 0.0914\n",
      "Iteration: 3640, Train loss: -0.6488, rewards: 0.0858\n",
      "Iteration: 3660, Train loss: -0.6376, rewards: 0.1025\n",
      "Iteration: 3680, Train loss: -0.6340, rewards: 0.1003\n",
      "Iteration: 3700, Train loss: -0.6054, rewards: 0.0677\n",
      "Iteration: 3720, Train loss: -0.6112, rewards: 0.0909\n",
      "Iteration: 3740, Train loss: -0.5839, rewards: 0.0800\n",
      "Iteration: 3760, Train loss: -0.6108, rewards: 0.0884\n",
      "Iteration: 3780, Train loss: -0.6482, rewards: 0.0988\n",
      "Iteration: 3800, Train loss: -0.6356, rewards: 0.0934\n",
      "Eval:\n",
      "Hits@1: 0.1704, Hits@3: 0.2492, Hits@10: 0.3445, MRR: 0.2271\n",
      "------------------------------------------------------------\n",
      "Iteration: 3820, Train loss: -0.6031, rewards: 0.0775\n",
      "Iteration: 3840, Train loss: -0.6290, rewards: 0.0828\n",
      "Iteration: 3860, Train loss: -0.6375, rewards: 0.0989\n",
      "Iteration: 3880, Train loss: -0.6508, rewards: 0.1002\n",
      "Iteration: 3900, Train loss: -0.6639, rewards: 0.0998\n",
      "Iteration: 3920, Train loss: -0.6571, rewards: 0.0920\n",
      "Iteration: 3940, Train loss: -0.6864, rewards: 0.1042\n",
      "Iteration: 3960, Train loss: -0.6244, rewards: 0.0927\n",
      "Iteration: 3980, Train loss: -0.5866, rewards: 0.0711\n",
      "Iteration: 4000, Train loss: -0.6681, rewards: 0.1123\n",
      "Eval:\n",
      "Hits@1: 0.1727, Hits@3: 0.2511, Hits@10: 0.3509, MRR: 0.2298\n",
      "------------------------------------------------------------\n",
      "Iteration: 4020, Train loss: -0.6563, rewards: 0.1166\n",
      "Iteration: 4040, Train loss: -0.6250, rewards: 0.0797\n",
      "Iteration: 4060, Train loss: -0.6902, rewards: 0.0920\n",
      "Iteration: 4080, Train loss: -0.6391, rewards: 0.0855\n",
      "Iteration: 4100, Train loss: -0.6414, rewards: 0.1039\n",
      "Iteration: 4120, Train loss: -0.6857, rewards: 0.1022\n",
      "Iteration: 4140, Train loss: -0.6451, rewards: 0.0900\n",
      "Iteration: 4160, Train loss: -0.6522, rewards: 0.0977\n",
      "Iteration: 4180, Train loss: -0.6026, rewards: 0.0869\n",
      "Iteration: 4200, Train loss: -0.6499, rewards: 0.0966\n",
      "Eval:\n",
      "Hits@1: 0.1694, Hits@3: 0.2481, Hits@10: 0.3422, MRR: 0.2254\n",
      "------------------------------------------------------------\n",
      "Iteration: 4220, Train loss: -0.6546, rewards: 0.0995\n",
      "Iteration: 4240, Train loss: -0.6114, rewards: 0.0925\n",
      "Iteration: 4260, Train loss: -0.6361, rewards: 0.0891\n",
      "Iteration: 4280, Train loss: -0.6727, rewards: 0.0970\n",
      "Iteration: 4300, Train loss: -0.6486, rewards: 0.1163\n",
      "Iteration: 4320, Train loss: -0.6649, rewards: 0.0977\n",
      "Iteration: 4340, Train loss: -0.6415, rewards: 0.0928\n",
      "Iteration: 4360, Train loss: -0.6340, rewards: 0.0938\n",
      "Iteration: 4380, Train loss: -0.6788, rewards: 0.1072\n",
      "Iteration: 4400, Train loss: -0.6432, rewards: 0.1089\n",
      "Eval:\n"
     ]
    }
   ],
   "source": [
    "from model.ours import *\n",
    "\n",
    "# Explicit imports for names this cell uses, so they don't depend on what the\n",
    "# star import happens to re-export.\n",
    "import gc\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def _load_results(path):\n",
    "    \"\"\"Return the pickled results table at `path`, or {} if none has been saved yet.\"\"\"\n",
    "    # First run of the sweep: start empty instead of raising FileNotFoundError.\n",
    "    if not os.path.exists(path):\n",
    "        return {}\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle5.load(f)\n",
    "\n",
    "\n",
    "def _save_results(table, path):\n",
    "    \"\"\"Pickle `table` to `path` via a temp file + rename so an interrupted write can't corrupt it.\"\"\"\n",
    "    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)\n",
    "    tmp_path = path + '.tmp'\n",
    "    with open(tmp_path, 'wb') as f:\n",
    "        pickle5.dump(table, f)\n",
    "    os.replace(tmp_path, path)\n",
    "\n",
    "\n",
    "# Read and write the table in the directory set_params() uses as output_dir,\n",
    "# so the resume cache and the saved results always agree.\n",
    "RESULTS_PATH = set_params()['output_dir'] + 'results_table.pk5'\n",
    "results = _load_results(RESULTS_PATH)\n",
    "\n",
    "# Same visiting order as nested loops over layers, beta/Lambda, batch size, lr.\n",
    "sweep = itertools.product(\n",
    "    [1, 2],                                    # LSTM layers\n",
    "    [0.1, 0.08, 0.12, 0.05, 0.15, 0.2, 0.25],  # beta; Lambda takes the same value\n",
    "    [8, 16, 32],                               # batch size\n",
    "    [1e-5, 5e-5],                              # learning rate\n",
    ")\n",
    "for layer, bl, bs, lr in sweep:\n",
    "    name = f'{layer}-{bs}-{bl}-{bl}-{lr}'\n",
    "    # Check the cache first: set_params() re-reads the vocab files.\n",
    "    if results.get(name) is not None:\n",
    "        continue\n",
    "    params = set_params(layer, bs, bl, bl, lr)\n",
    "\n",
    "    trainer = Trainer(params)\n",
    "    trainer.train()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # NOTE(review): every config shares model_dir, so this loads whatever agent.ckpt\n",
    "    # is on disk -- confirm trainer.train() always (re)writes it for this config.\n",
    "    trainer.agent.load_state_dict(torch.load(params['model_dir'] + 'agent.ckpt'))\n",
    "    trainer.agent.eval()\n",
    "    trainer.test_environment = trainer.test_test_environment\n",
    "    test_metrics = trainer.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "    print(name)\n",
    "    print(test_metrics)\n",
    "    print('-------------')\n",
    "    results[name] = test_metrics\n",
    "    _save_results(results, RESULTS_PATH)\n",
    "\n",
    "    del trainer\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
