{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle5\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-dbef6f87-fd33-5950-84e0-7007d974c9ac\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    options = {}\n",
    "\n",
    "    #basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/nell/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/nell/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "    options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "    options['model_dir'] = './outputs_NELL-995_v7-tune2/'\n",
    "    options['output_dir'] = './outputs_NELL-995_v7-tune2/'\n",
    "\n",
    "    #agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 1\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 40\n",
    "    options['gnn_layer'] = 1\n",
    "\n",
    "    #hyperparameters\n",
    "    options['test_rollouts'] = 40\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 32\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 100\n",
    "    options['total_iterations'] = 2000*(64/batch_size)\n",
    "    options['pool'] = 'max'\n",
    "    \n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fa2d7e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from model.ours2 import *\n",
    "\n",
    "# results = {}\n",
    "# for layer in [1, 2]:\n",
    "#     for bl in [0.05, 0.08, 0.02, 0.1, 0.12, 0.15]:\n",
    "#         for bs in [32, 64, 128]:\n",
    "#             for lr in [5e-4, 1e-4, 1e-3, 5e-5, 2e-3]:\n",
    "#                 params = set_params(layer, bs, bl, bl, lr)\n",
    "#                 name = f'{layer}-{bs}-{bl}-{bl}-{lr}'\n",
    "\n",
    "#                 trainer = Trainer(params)\n",
    "#                 trainer.train()\n",
    "#                 torch.cuda.empty_cache()\n",
    "\n",
    "#                 trainer.agent.load_state_dict(torch.load(params['model_dir'] + 'agent.ckpt'))\n",
    "#                 trainer.agent.eval()\n",
    "#                 trainer.test_environment = trainer.test_test_environment\n",
    "#                 tmp = trainer.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "#                 print(name)\n",
    "#                 print(tmp)\n",
    "#                 print('-------------')\n",
    "#                 results[name] = tmp\n",
    "\n",
    "#                 with open(params['output_dir'] + 'results_table.pk5', 'wb') as f:\n",
    "#                     pickle5.dump(results, f)\n",
    "                    \n",
    "#                 del trainer\n",
    "#                 gc.collect()\n",
    "#                 torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f32d3aa5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/Ours/model/ours2.py:334: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2393, rewards: 0.1539\n",
      "Iteration: 20, Train loss: -0.2489, rewards: 0.2657\n",
      "Iteration: 30, Train loss: -0.2024, rewards: 0.3343\n",
      "Iteration: 40, Train loss: -0.1534, rewards: 0.3567\n",
      "Iteration: 50, Train loss: -0.1759, rewards: 0.4037\n",
      "Iteration: 60, Train loss: -0.1742, rewards: 0.4152\n",
      "Iteration: 70, Train loss: -0.2533, rewards: 0.4197\n",
      "Iteration: 80, Train loss: -0.2102, rewards: 0.4316\n",
      "Iteration: 90, Train loss: -0.2325, rewards: 0.4564\n",
      "Iteration: 100, Train loss: -0.1425, rewards: 0.5050\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/Ours/model/ours2.py:636: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.5635, Hits@3: 0.6188, Hits@10: 0.6575, MRR: 0.5968\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.1045, rewards: 0.4862\n",
      "Iteration: 120, Train loss: -0.1145, rewards: 0.5096\n",
      "Iteration: 130, Train loss: -0.1473, rewards: 0.5255\n",
      "Iteration: 140, Train loss: -0.1455, rewards: 0.5113\n",
      "Iteration: 150, Train loss: -0.1219, rewards: 0.5188\n",
      "Iteration: 160, Train loss: -0.1274, rewards: 0.5380\n",
      "Iteration: 170, Train loss: -0.1677, rewards: 0.5602\n",
      "Iteration: 180, Train loss: -0.1062, rewards: 0.5358\n",
      "Iteration: 190, Train loss: -0.1052, rewards: 0.5258\n",
      "Iteration: 200, Train loss: -0.1074, rewards: 0.5085\n",
      "Eval:\n",
      "Hits@1: 0.5764, Hits@3: 0.6317, Hits@10: 0.6667, MRR: 0.6084\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.0788, rewards: 0.4927\n",
      "Iteration: 220, Train loss: -0.0691, rewards: 0.5439\n",
      "Iteration: 230, Train loss: -0.1100, rewards: 0.5681\n",
      "Iteration: 240, Train loss: -0.1106, rewards: 0.5698\n",
      "Iteration: 250, Train loss: -0.0669, rewards: 0.5316\n",
      "Iteration: 260, Train loss: -0.0654, rewards: 0.5710\n",
      "Iteration: 270, Train loss: -0.0671, rewards: 0.5410\n",
      "Iteration: 280, Train loss: -0.0495, rewards: 0.5615\n",
      "Iteration: 290, Train loss: -0.0977, rewards: 0.5984\n",
      "Iteration: 300, Train loss: -0.0229, rewards: 0.5527\n",
      "Eval:\n",
      "Hits@1: 0.5783, Hits@3: 0.6464, Hits@10: 0.6759, MRR: 0.6153\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.0529, rewards: 0.5668\n",
      "Iteration: 320, Train loss: -0.0766, rewards: 0.5563\n",
      "Iteration: 330, Train loss: -0.0746, rewards: 0.5914\n",
      "Iteration: 340, Train loss: -0.1019, rewards: 0.5998\n",
      "Iteration: 350, Train loss: -0.0679, rewards: 0.5755\n",
      "Iteration: 360, Train loss: -0.0502, rewards: 0.5777\n",
      "Iteration: 370, Train loss: -0.0368, rewards: 0.5756\n",
      "Iteration: 380, Train loss: -0.0505, rewards: 0.5667\n",
      "Iteration: 390, Train loss: 0.0050, rewards: 0.5772\n",
      "Iteration: 400, Train loss: -0.0638, rewards: 0.5350\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6409, Hits@10: 0.6759, MRR: 0.6140\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.0575, rewards: 0.5977\n",
      "Iteration: 420, Train loss: -0.0533, rewards: 0.5901\n",
      "Iteration: 430, Train loss: -0.0197, rewards: 0.5612\n",
      "Iteration: 440, Train loss: -0.0191, rewards: 0.5932\n",
      "Iteration: 450, Train loss: -0.0486, rewards: 0.5963\n",
      "Iteration: 460, Train loss: -0.0368, rewards: 0.5745\n",
      "Iteration: 470, Train loss: -0.0105, rewards: 0.5920\n",
      "Iteration: 480, Train loss: 0.0403, rewards: 0.5571\n",
      "Iteration: 490, Train loss: -0.0085, rewards: 0.5804\n",
      "Iteration: 500, Train loss: -0.0376, rewards: 0.5905\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6483, Hits@10: 0.6740, MRR: 0.6220\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.0327, rewards: 0.6038\n",
      "Iteration: 520, Train loss: -0.0207, rewards: 0.5924\n",
      "Iteration: 530, Train loss: -0.0264, rewards: 0.6000\n",
      "Iteration: 540, Train loss: 0.0296, rewards: 0.5666\n",
      "Iteration: 550, Train loss: -0.0306, rewards: 0.5673\n",
      "Iteration: 560, Train loss: -0.0187, rewards: 0.6138\n",
      "Iteration: 570, Train loss: 0.0128, rewards: 0.6141\n",
      "Iteration: 580, Train loss: -0.0257, rewards: 0.6073\n",
      "Iteration: 590, Train loss: 0.0075, rewards: 0.5884\n",
      "Iteration: 600, Train loss: 0.0255, rewards: 0.6222\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6409, Hits@10: 0.6722, MRR: 0.6200\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.0392, rewards: 0.5687\n",
      "Iteration: 620, Train loss: -0.0237, rewards: 0.5820\n",
      "Iteration: 630, Train loss: 0.0182, rewards: 0.5952\n",
      "Iteration: 640, Train loss: -0.0226, rewards: 0.5988\n",
      "Iteration: 650, Train loss: -0.0287, rewards: 0.6130\n",
      "Iteration: 660, Train loss: -0.0131, rewards: 0.5743\n",
      "Iteration: 670, Train loss: -0.0328, rewards: 0.6060\n",
      "Iteration: 680, Train loss: -0.0321, rewards: 0.5941\n",
      "Iteration: 690, Train loss: -0.0035, rewards: 0.6076\n",
      "Iteration: 700, Train loss: 0.0059, rewards: 0.5884\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6446, Hits@10: 0.6722, MRR: 0.6216\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.0170, rewards: 0.6211\n",
      "Iteration: 720, Train loss: -0.0657, rewards: 0.6388\n",
      "Iteration: 730, Train loss: -0.0255, rewards: 0.5988\n",
      "Iteration: 740, Train loss: -0.0283, rewards: 0.6299\n",
      "Iteration: 750, Train loss: 0.0215, rewards: 0.6510\n",
      "Iteration: 760, Train loss: -0.0082, rewards: 0.6265\n",
      "Iteration: 770, Train loss: 0.0376, rewards: 0.6211\n",
      "Iteration: 780, Train loss: 0.0090, rewards: 0.6155\n",
      "Iteration: 790, Train loss: 0.0001, rewards: 0.6101\n",
      "Iteration: 800, Train loss: 0.0038, rewards: 0.6209\n",
      "Eval:\n",
      "Hits@1: 0.5691, Hits@3: 0.6427, Hits@10: 0.6685, MRR: 0.6076\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.0298, rewards: 0.6241\n",
      "Iteration: 820, Train loss: 0.0149, rewards: 0.6089\n",
      "Iteration: 830, Train loss: 0.0284, rewards: 0.6072\n",
      "Iteration: 840, Train loss: -0.0474, rewards: 0.6209\n",
      "Iteration: 850, Train loss: 0.0372, rewards: 0.6494\n",
      "Iteration: 860, Train loss: 0.0380, rewards: 0.6112\n",
      "Iteration: 870, Train loss: 0.0447, rewards: 0.5866\n",
      "Iteration: 880, Train loss: 0.0098, rewards: 0.6316\n",
      "Iteration: 890, Train loss: 0.0201, rewards: 0.6159\n",
      "Iteration: 900, Train loss: 0.0026, rewards: 0.6305\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6464, Hits@10: 0.6759, MRR: 0.6207\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: 0.0007, rewards: 0.6266\n",
      "Iteration: 920, Train loss: -0.0122, rewards: 0.6338\n",
      "Iteration: 930, Train loss: -0.0014, rewards: 0.6320\n",
      "Iteration: 940, Train loss: 0.0415, rewards: 0.5951\n",
      "Iteration: 950, Train loss: -0.0047, rewards: 0.6273\n",
      "Iteration: 960, Train loss: -0.0008, rewards: 0.6300\n",
      "Iteration: 970, Train loss: -0.0054, rewards: 0.6082\n",
      "Iteration: 980, Train loss: 0.0671, rewards: 0.6015\n",
      "Iteration: 990, Train loss: 0.0503, rewards: 0.6523\n",
      "Iteration: 1000, Train loss: 0.0243, rewards: 0.6185\n",
      "Eval:\n",
      "Hits@1: 0.5820, Hits@3: 0.6446, Hits@10: 0.6648, MRR: 0.6148\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.0097, rewards: 0.6110\n",
      "Iteration: 1020, Train loss: 0.0104, rewards: 0.6025\n",
      "Iteration: 1030, Train loss: -0.0564, rewards: 0.6221\n",
      "Iteration: 1040, Train loss: -0.0060, rewards: 0.6155\n",
      "Iteration: 1050, Train loss: 0.0377, rewards: 0.6049\n",
      "Iteration: 1060, Train loss: 0.0307, rewards: 0.6164\n",
      "Iteration: 1070, Train loss: 0.0149, rewards: 0.6158\n",
      "Iteration: 1080, Train loss: 0.0114, rewards: 0.5970\n",
      "Iteration: 1090, Train loss: 0.0252, rewards: 0.6055\n",
      "Iteration: 1100, Train loss: 0.0068, rewards: 0.6188\n",
      "Eval:\n",
      "Hits@1: 0.5967, Hits@3: 0.6427, Hits@10: 0.6667, MRR: 0.6228\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: 0.0625, rewards: 0.6205\n",
      "Iteration: 1120, Train loss: -0.0088, rewards: 0.6545\n",
      "Iteration: 1130, Train loss: 0.0020, rewards: 0.6466\n",
      "Iteration: 1140, Train loss: -0.0093, rewards: 0.6431\n",
      "Iteration: 1150, Train loss: 0.0375, rewards: 0.5997\n",
      "Iteration: 1160, Train loss: -0.0087, rewards: 0.6335\n",
      "Iteration: 1170, Train loss: 0.0026, rewards: 0.6276\n",
      "Iteration: 1180, Train loss: 0.0375, rewards: 0.5802\n",
      "Iteration: 1190, Train loss: -0.0110, rewards: 0.6297\n",
      "Iteration: 1200, Train loss: 0.0054, rewards: 0.6334\n",
      "Eval:\n",
      "Hits@1: 0.5967, Hits@3: 0.6354, Hits@10: 0.6667, MRR: 0.6214\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: 0.0381, rewards: 0.6174\n",
      "Iteration: 1220, Train loss: -0.0042, rewards: 0.6352\n",
      "Iteration: 1230, Train loss: 0.0648, rewards: 0.6358\n",
      "Iteration: 1240, Train loss: 0.0147, rewards: 0.6326\n",
      "Iteration: 1250, Train loss: 0.0046, rewards: 0.6581\n",
      "Iteration: 1260, Train loss: 0.0441, rewards: 0.6521\n",
      "Iteration: 1270, Train loss: 0.0019, rewards: 0.6048\n",
      "Iteration: 1280, Train loss: 0.0020, rewards: 0.6070\n",
      "Iteration: 1290, Train loss: -0.0548, rewards: 0.6289\n",
      "Iteration: 1300, Train loss: 0.0288, rewards: 0.6350\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6390, Hits@10: 0.6703, MRR: 0.6157\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: 0.0583, rewards: 0.6473\n",
      "Iteration: 1320, Train loss: -0.0151, rewards: 0.6337\n",
      "Iteration: 1330, Train loss: -0.0210, rewards: 0.6152\n",
      "Iteration: 1340, Train loss: -0.0273, rewards: 0.6358\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1350, Train loss: 0.0255, rewards: 0.6477\n",
      "Iteration: 1360, Train loss: 0.0055, rewards: 0.6209\n",
      "Iteration: 1370, Train loss: 0.0294, rewards: 0.6038\n",
      "Iteration: 1380, Train loss: 0.0301, rewards: 0.6303\n",
      "Iteration: 1390, Train loss: 0.0032, rewards: 0.6216\n",
      "Iteration: 1400, Train loss: 0.0831, rewards: 0.6198\n",
      "Eval:\n",
      "Hits@1: 0.5783, Hits@3: 0.6409, Hits@10: 0.6703, MRR: 0.6121\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: 0.0360, rewards: 0.6084\n",
      "Iteration: 1420, Train loss: 0.0064, rewards: 0.6007\n",
      "Iteration: 1430, Train loss: 0.0425, rewards: 0.6370\n",
      "Iteration: 1440, Train loss: 0.0081, rewards: 0.6320\n",
      "Iteration: 1450, Train loss: 0.0015, rewards: 0.6252\n",
      "Iteration: 1460, Train loss: -0.0307, rewards: 0.6225\n",
      "Iteration: 1470, Train loss: 0.0652, rewards: 0.5968\n",
      "Iteration: 1480, Train loss: 0.0283, rewards: 0.6482\n",
      "Iteration: 1490, Train loss: 0.0487, rewards: 0.6434\n",
      "Iteration: 1500, Train loss: 0.0440, rewards: 0.5970\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6298, Hits@10: 0.6630, MRR: 0.6105\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: 0.0129, rewards: 0.6492\n",
      "Iteration: 1520, Train loss: 0.0373, rewards: 0.6548\n",
      "Iteration: 1530, Train loss: 0.0510, rewards: 0.6273\n",
      "Iteration: 1540, Train loss: -0.0326, rewards: 0.6467\n",
      "Iteration: 1550, Train loss: 0.0507, rewards: 0.6560\n",
      "Iteration: 1560, Train loss: 0.0386, rewards: 0.6510\n",
      "Iteration: 1570, Train loss: -0.0080, rewards: 0.6683\n",
      "Iteration: 1580, Train loss: 0.0411, rewards: 0.6226\n",
      "Iteration: 1590, Train loss: 0.0190, rewards: 0.6375\n",
      "Iteration: 1600, Train loss: 0.0511, rewards: 0.6657\n",
      "Eval:\n",
      "Hits@1: 0.6004, Hits@3: 0.6409, Hits@10: 0.6740, MRR: 0.6254\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: 0.0218, rewards: 0.6229\n",
      "Iteration: 1620, Train loss: 0.0484, rewards: 0.6513\n",
      "Iteration: 1630, Train loss: 0.0147, rewards: 0.6434\n",
      "Iteration: 1640, Train loss: 0.0179, rewards: 0.6412\n",
      "Iteration: 1650, Train loss: 0.0320, rewards: 0.6314\n",
      "Iteration: 1660, Train loss: -0.0517, rewards: 0.6519\n",
      "Iteration: 1670, Train loss: -0.0136, rewards: 0.6137\n",
      "Iteration: 1680, Train loss: 0.0127, rewards: 0.5876\n",
      "Iteration: 1690, Train loss: 0.0008, rewards: 0.6593\n",
      "Iteration: 1700, Train loss: 0.0256, rewards: 0.6538\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6446, Hits@10: 0.6740, MRR: 0.6197\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: 0.0158, rewards: 0.6161\n",
      "Iteration: 1720, Train loss: 0.0182, rewards: 0.6083\n",
      "Iteration: 1730, Train loss: 0.0044, rewards: 0.6378\n",
      "Iteration: 1740, Train loss: 0.0597, rewards: 0.6379\n",
      "Iteration: 1750, Train loss: 0.0122, rewards: 0.6355\n",
      "Iteration: 1760, Train loss: 0.0146, rewards: 0.6453\n",
      "Iteration: 1770, Train loss: -0.0010, rewards: 0.6463\n",
      "Iteration: 1780, Train loss: -0.0005, rewards: 0.6608\n",
      "Iteration: 1790, Train loss: -0.0349, rewards: 0.6702\n",
      "Iteration: 1800, Train loss: 0.0161, rewards: 0.6492\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6372, Hits@10: 0.6740, MRR: 0.6177\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: 0.0237, rewards: 0.6152\n",
      "Iteration: 1820, Train loss: 0.0514, rewards: 0.6279\n",
      "Iteration: 1830, Train loss: -0.0306, rewards: 0.6323\n",
      "Iteration: 1840, Train loss: 0.0478, rewards: 0.6125\n",
      "Iteration: 1850, Train loss: 0.0149, rewards: 0.6420\n",
      "Iteration: 1860, Train loss: -0.0154, rewards: 0.6252\n",
      "Iteration: 1870, Train loss: 0.0458, rewards: 0.6175\n",
      "Iteration: 1880, Train loss: 0.0012, rewards: 0.6644\n",
      "Iteration: 1890, Train loss: 0.0692, rewards: 0.6453\n",
      "Iteration: 1900, Train loss: -0.0124, rewards: 0.6283\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6427, Hits@10: 0.6703, MRR: 0.6189\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: 0.0143, rewards: 0.6830\n",
      "Iteration: 1920, Train loss: 0.0023, rewards: 0.6404\n",
      "Iteration: 1930, Train loss: 0.0187, rewards: 0.6491\n",
      "Iteration: 1940, Train loss: 0.0222, rewards: 0.6416\n",
      "Iteration: 1950, Train loss: 0.0026, rewards: 0.6626\n",
      "Iteration: 1960, Train loss: 0.0525, rewards: 0.6371\n",
      "Iteration: 1970, Train loss: 0.0514, rewards: 0.6445\n",
      "Iteration: 1980, Train loss: 0.0186, rewards: 0.6590\n",
      "Iteration: 1990, Train loss: 0.0248, rewards: 0.6480\n",
      "Iteration: 2000, Train loss: -0.0219, rewards: 0.6686\n",
      "Eval:\n",
      "Hits@1: 0.5709, Hits@3: 0.6372, Hits@10: 0.6685, MRR: 0.6073\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "from model.ours2 import *\n",
    "params = set_params(2, 64, 0.1, 0.1, 0.0001)\n",
    "trainer = Trainer(params)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71ccaaec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.6852, Hits@3: 0.7825, Hits@10: 0.8226, MRR: 0.7390\n"
     ]
    }
   ],
   "source": [
    "trainer.agent.eval()\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "test_results = trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8233d673",
   "metadata": {},
   "source": [
    "# Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "39ee57c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "with open('./outputs_NELL-995_v7-tune2/results_table.pk5', 'rb') as f:\n",
    "    results = pickle5.load(f)\n",
    "check = pd.DataFrame([[k, float(v.split(': ')[-1])] for k,v in results.items()], columns = ['config', 'mrr'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b1db462f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>config</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>2-32-0.12-0.12-0.0001</td>\n",
       "      <td>0.7389</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>115</th>\n",
       "      <td>2-128-0.1-0.1-0.0001</td>\n",
       "      <td>0.7392</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>1-64-0.18-0.18-5e-05</td>\n",
       "      <td>0.7394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>1-32-0.12-0.12-0.0001</td>\n",
       "      <td>0.7397</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>1-32-0.18-0.18-0.0001</td>\n",
       "      <td>0.7398</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1-64-0.1-0.1-5e-05</td>\n",
       "      <td>0.7407</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>140</th>\n",
       "      <td>2-64-0.12-0.12-0.0001</td>\n",
       "      <td>0.7409</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>195</th>\n",
       "      <td>2-32-0.18-0.18-0.0001</td>\n",
       "      <td>0.7412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>200</th>\n",
       "      <td>2-64-0.18-0.18-0.0001</td>\n",
       "      <td>0.7416</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>110</th>\n",
       "      <td>2-64-0.1-0.1-0.0001</td>\n",
       "      <td>0.7427</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    config     mrr\n",
       "135  2-32-0.12-0.12-0.0001  0.7389\n",
       "115   2-128-0.1-0.1-0.0001  0.7392\n",
       "96    1-64-0.18-0.18-5e-05  0.7394\n",
       "30   1-32-0.12-0.12-0.0001  0.7397\n",
       "90   1-32-0.18-0.18-0.0001  0.7398\n",
       "6       1-64-0.1-0.1-5e-05  0.7407\n",
       "140  2-64-0.12-0.12-0.0001  0.7409\n",
       "195  2-32-0.18-0.18-0.0001  0.7412\n",
       "200  2-64-0.18-0.18-0.0001  0.7416\n",
       "110    2-64-0.1-0.1-0.0001  0.7427"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "check.sort_values('mrr').tail(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "af12e6f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    options = {}\n",
    "\n",
    "    #basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/nell/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/nell/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "    options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "    options['model_dir'] = './outputs_NELL-995_v7-tune2/'\n",
    "    options['output_dir'] = './outputs_NELL-995_v7-tune2/'\n",
    "\n",
    "    #agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 1\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 100\n",
    "    options['gnn_layer'] = 1\n",
    "\n",
    "    #hyperparameters\n",
    "    options['test_rollouts'] = 100\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 8\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 100\n",
    "    options['total_iterations'] = 2000*(64/batch_size)\n",
    "    options['pool'] = 'max'\n",
    "    \n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ef0f8101",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA out of memory. Tried to allocate 4.77 GiB (GPU 0; 39.50 GiB total capacity; 35.90 GiB already allocated; 1.87 GiB free; 36.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m params \u001b[38;5;241m=\u001b[39m set_params(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m32\u001b[39m, \u001b[38;5;241m0.18\u001b[39m, \u001b[38;5;241m0.18\u001b[39m, \u001b[38;5;241m0.0001\u001b[39m)\n\u001b[1;32m      3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(params)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Research/GraphRL/Ours/model/ours2.py:510\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    507\u001b[0m next_neighbors \u001b[38;5;241m=\u001b[39m next_neighbors[next_possible_entities\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()]\n\u001b[1;32m    508\u001b[0m next_neighbors \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mLongTensor(next_neighbors)\u001b[38;5;241m.\u001b[39mto(current_entities\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m--> 510\u001b[0m loss, entity_state_emb, logits, idx, chosen_relation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    511\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnext_possible_relations\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    512\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnext_possible_entities\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mentity_state_emb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    513\u001b[0m \u001b[43m    \u001b[49m\u001b[43mprev_relation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_relation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    514\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcurrent_entities\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    515\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnext_neighbors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprev_gnn_state\u001b[49m\n\u001b[1;32m    516\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    518\u001b[0m entity_state \u001b[38;5;241m=\u001b[39m entity_episode(idx\u001b[38;5;241m.\u001b[39mcpu())\n\u001b[1;32m    519\u001b[0m next_possible_relations \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(entity_state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnext_relations\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;241m.\u001b[39mlong()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n",
      "File \u001b[0;32m~/Research/GraphRL/Ours/model/ours2.py:302\u001b[0m, in \u001b[0;36mAgent.step\u001b[0;34m(self, next_relations, next_entities, prev_state, prev_relation, query_embedding, current_entities, next_neighbors, prev_gnn_state)\u001b[0m\n\u001b[1;32m    299\u001b[0m     state \u001b[38;5;241m=\u001b[39m output\n\u001b[1;32m    301\u001b[0m candidate_action_embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_encoder(next_relations, next_entities)\n\u001b[0;32m--> 302\u001b[0m gnn_embedding \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mneighbour_aggregation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprev_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnext_neighbors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    303\u001b[0m candidate_action_embeddings \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([candidate_action_embeddings, gnn_embedding], \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m    305\u001b[0m query_embedding \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrel_embed[query_embedding]\n",
      "File \u001b[0;32m~/Research/GraphRL/Ours/model/ours2.py:266\u001b[0m, in \u001b[0;36mAgent.neighbour_aggregation\u001b[0;34m(self, query_relation, prev_state, next_neighbors)\u001b[0m\n\u001b[1;32m    263\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate_encoder(torch\u001b[38;5;241m.\u001b[39mcat([lstm_state, query_embedding], \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m    264\u001b[0m neighbor_embedding \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([t_embed, r_embed], \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 266\u001b[0m att \u001b[38;5;241m=\u001b[39m (\u001b[43mstate\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mneighbor_embedding\u001b[49m)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m/\u001b[39mnp\u001b[38;5;241m.\u001b[39msqrt(state\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m    267\u001b[0m att \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39msoftmax(att \u001b[38;5;241m-\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m mask)\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m1e8\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m    268\u001b[0m update_embedding \u001b[38;5;241m=\u001b[39m (neighbor_embedding\u001b[38;5;241m*\u001b[39matt\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m))\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m2\u001b[39m)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 4.77 GiB (GPU 0; 39.50 GiB total capacity; 35.90 GiB already allocated; 1.87 GiB free; 36.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "from model.ours2 import *\n",
    "params = set_params(2, 32, 0.18, 0.18, 0.0001)\n",
    "trainer = Trainer(params)\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
