{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce63a777",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle5\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-dc67f6ae-4c27-5869-bcd2-8560a2da46c7\"\n",
    "\n",
    "from model.ours3 import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "398c7676",
   "metadata": {},
   "outputs": [],
   "source": [
    "# json is otherwise only reachable implicitly through `from model.ours3 import *`.\n",
    "import json\n",
    "import os\n",
    "\n",
    "options = {}\n",
    "\n",
    "# basic setting\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/nell/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/nell/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "# Context managers close the vocab files; os.path.join avoids the doubled\n",
    "# slash that plain string concatenation produced ('vocab//relation_vocab.json').\n",
    "with open(os.path.join(options['vocab_dir'], 'relation_vocab.json')) as vocab_file:\n",
    "    options['relation_vocab'] = json.load(vocab_file)\n",
    "with open(os.path.join(options['vocab_dir'], 'entity_vocab.json')) as vocab_file:\n",
    "    options['entity_vocab'] = json.load(vocab_file)\n",
    "options['model_dir'] = './outputs_nell995-1/'\n",
    "options['output_dir'] = './outputs_nell995-1/'\n",
    "\n",
    "# agent setting\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 1\n",
    "options['train_entity_embeddings'] = 1\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 40\n",
    "options['gnn_layer'] = 2\n",
    "\n",
    "# hyperparameters\n",
    "options['test_rollouts'] = 40\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 64\n",
    "options['eval_batch_size'] = 32\n",
    "options['beta'] = 0.15\n",
    "options['Lambda'] = 0.15\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 0.00005\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "# Scale the iteration budget so iterations * batch_size stays at 2000 * 64;\n",
    "# floor division keeps it an int (the old expression yielded a float, e.g. 2000.0).\n",
    "options['total_iterations'] = 2000 * 64 // options['batch_size']\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d9b9fda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/experiments/model/ours3.py:334: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2849, rewards: 0.0887\n",
      "Iteration: 20, Train loss: -0.3635, rewards: 0.1655\n",
      "Iteration: 30, Train loss: -0.3370, rewards: 0.1962\n",
      "Iteration: 40, Train loss: -0.3271, rewards: 0.2548\n",
      "Iteration: 50, Train loss: -0.3216, rewards: 0.3219\n",
      "Iteration: 60, Train loss: -0.3063, rewards: 0.3034\n",
      "Iteration: 70, Train loss: -0.3267, rewards: 0.3578\n",
      "Iteration: 80, Train loss: -0.3476, rewards: 0.3498\n",
      "Iteration: 90, Train loss: -0.2836, rewards: 0.3776\n",
      "Iteration: 100, Train loss: -0.2867, rewards: 0.3848\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/experiments/model/ours3.py:636: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.5175, Hits@3: 0.5912, Hits@10: 0.6464, MRR: 0.5606\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.2682, rewards: 0.3937\n",
      "Iteration: 120, Train loss: -0.2802, rewards: 0.4339\n",
      "Iteration: 130, Train loss: -0.3108, rewards: 0.4543\n",
      "Iteration: 140, Train loss: -0.2136, rewards: 0.4789\n",
      "Iteration: 150, Train loss: -0.3061, rewards: 0.4744\n",
      "Iteration: 160, Train loss: -0.2989, rewards: 0.4774\n",
      "Iteration: 170, Train loss: -0.3021, rewards: 0.4719\n",
      "Iteration: 180, Train loss: -0.2649, rewards: 0.5019\n",
      "Iteration: 190, Train loss: -0.2646, rewards: 0.4761\n",
      "Iteration: 200, Train loss: -0.1976, rewards: 0.5541\n",
      "Eval:\n",
      "Hits@1: 0.5746, Hits@3: 0.6390, Hits@10: 0.6648, MRR: 0.6104\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.1877, rewards: 0.4798\n",
      "Iteration: 220, Train loss: -0.2330, rewards: 0.4951\n",
      "Iteration: 230, Train loss: -0.2203, rewards: 0.5067\n",
      "Iteration: 240, Train loss: -0.2068, rewards: 0.5319\n",
      "Iteration: 250, Train loss: -0.2151, rewards: 0.4985\n",
      "Iteration: 260, Train loss: -0.2786, rewards: 0.5167\n",
      "Iteration: 270, Train loss: -0.2483, rewards: 0.5384\n",
      "Iteration: 280, Train loss: -0.2198, rewards: 0.5285\n",
      "Iteration: 290, Train loss: -0.2307, rewards: 0.5009\n",
      "Iteration: 300, Train loss: -0.2038, rewards: 0.5414\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6390, Hits@10: 0.6722, MRR: 0.6194\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.2190, rewards: 0.5395\n",
      "Iteration: 320, Train loss: -0.1759, rewards: 0.5405\n",
      "Iteration: 330, Train loss: -0.2200, rewards: 0.5400\n",
      "Iteration: 340, Train loss: -0.1307, rewards: 0.5144\n",
      "Iteration: 350, Train loss: -0.1777, rewards: 0.5427\n",
      "Iteration: 360, Train loss: -0.2044, rewards: 0.5526\n",
      "Iteration: 370, Train loss: -0.2145, rewards: 0.5141\n",
      "Iteration: 380, Train loss: -0.2240, rewards: 0.5570\n",
      "Iteration: 390, Train loss: -0.2025, rewards: 0.5312\n",
      "Iteration: 400, Train loss: -0.1438, rewards: 0.5155\n",
      "Eval:\n",
      "Hits@1: 0.5930, Hits@3: 0.6409, Hits@10: 0.6740, MRR: 0.6222\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.1758, rewards: 0.5492\n",
      "Iteration: 420, Train loss: -0.1613, rewards: 0.5316\n",
      "Iteration: 430, Train loss: -0.2134, rewards: 0.5268\n",
      "Iteration: 440, Train loss: -0.1520, rewards: 0.5395\n",
      "Iteration: 450, Train loss: -0.1677, rewards: 0.5993\n",
      "Iteration: 460, Train loss: -0.1929, rewards: 0.5452\n",
      "Iteration: 470, Train loss: -0.1577, rewards: 0.5237\n",
      "Iteration: 480, Train loss: -0.1779, rewards: 0.5484\n",
      "Iteration: 490, Train loss: -0.1778, rewards: 0.5312\n",
      "Iteration: 500, Train loss: -0.1936, rewards: 0.5634\n",
      "Eval:\n",
      "Hits@1: 0.5820, Hits@3: 0.6335, Hits@10: 0.6685, MRR: 0.6129\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.2138, rewards: 0.5585\n",
      "Iteration: 520, Train loss: -0.2549, rewards: 0.5404\n",
      "Iteration: 530, Train loss: -0.1956, rewards: 0.5713\n",
      "Iteration: 540, Train loss: -0.2011, rewards: 0.5480\n",
      "Iteration: 550, Train loss: -0.1435, rewards: 0.5890\n",
      "Iteration: 560, Train loss: -0.1684, rewards: 0.5823\n",
      "Iteration: 570, Train loss: -0.2102, rewards: 0.5770\n",
      "Iteration: 580, Train loss: -0.2303, rewards: 0.5685\n",
      "Iteration: 590, Train loss: -0.1232, rewards: 0.5731\n",
      "Iteration: 600, Train loss: -0.1392, rewards: 0.5559\n",
      "Eval:\n",
      "Hits@1: 0.5985, Hits@3: 0.6409, Hits@10: 0.6740, MRR: 0.6255\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.1706, rewards: 0.5481\n",
      "Iteration: 620, Train loss: -0.1146, rewards: 0.5534\n",
      "Iteration: 630, Train loss: -0.1488, rewards: 0.5789\n",
      "Iteration: 640, Train loss: -0.1431, rewards: 0.5816\n",
      "Iteration: 650, Train loss: -0.1808, rewards: 0.5928\n",
      "Iteration: 660, Train loss: -0.0842, rewards: 0.5462\n",
      "Iteration: 670, Train loss: -0.1027, rewards: 0.5825\n",
      "Iteration: 680, Train loss: -0.1082, rewards: 0.5443\n",
      "Iteration: 690, Train loss: -0.1820, rewards: 0.6072\n",
      "Iteration: 700, Train loss: -0.0957, rewards: 0.5708\n",
      "Eval:\n",
      "Hits@1: 0.5709, Hits@3: 0.6446, Hits@10: 0.6777, MRR: 0.6106\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.1615, rewards: 0.5977\n",
      "Iteration: 720, Train loss: -0.1155, rewards: 0.5556\n",
      "Iteration: 730, Train loss: -0.1611, rewards: 0.5674\n",
      "Iteration: 740, Train loss: -0.1606, rewards: 0.5477\n",
      "Iteration: 750, Train loss: -0.1430, rewards: 0.5838\n",
      "Iteration: 760, Train loss: -0.1628, rewards: 0.5818\n",
      "Iteration: 770, Train loss: -0.1974, rewards: 0.5556\n",
      "Iteration: 780, Train loss: -0.1134, rewards: 0.5658\n",
      "Iteration: 790, Train loss: -0.1073, rewards: 0.5887\n",
      "Iteration: 800, Train loss: -0.1114, rewards: 0.5507\n",
      "Eval:\n",
      "Iteration: 850, Train loss: -0.1722, rewards: 0.5705\n",
      "Iteration: 860, Train loss: -0.1339, rewards: 0.5845\n",
      "Iteration: 870, Train loss: -0.1266, rewards: 0.5882\n",
      "Iteration: 880, Train loss: -0.0869, rewards: 0.5973\n",
      "Iteration: 890, Train loss: -0.1373, rewards: 0.5924\n",
      "Iteration: 900, Train loss: -0.1082, rewards: 0.5525\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6464, Hits@10: 0.6759, MRR: 0.6192\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.1689, rewards: 0.5800\n",
      "Iteration: 920, Train loss: -0.1181, rewards: 0.6146\n",
      "Iteration: 930, Train loss: -0.1487, rewards: 0.5996\n",
      "Iteration: 940, Train loss: -0.0744, rewards: 0.6410\n",
      "Iteration: 950, Train loss: -0.1066, rewards: 0.5998\n",
      "Iteration: 960, Train loss: -0.1514, rewards: 0.6194\n",
      "Iteration: 970, Train loss: -0.1208, rewards: 0.5751\n",
      "Iteration: 980, Train loss: -0.1091, rewards: 0.6008\n",
      "Iteration: 990, Train loss: -0.1388, rewards: 0.5835\n",
      "Iteration: 1000, Train loss: -0.1183, rewards: 0.6105\n",
      "Eval:\n",
      "Hits@1: 0.5967, Hits@3: 0.6372, Hits@10: 0.6759, MRR: 0.6229\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.1142, rewards: 0.6209\n",
      "Iteration: 1020, Train loss: -0.1292, rewards: 0.5695\n",
      "Iteration: 1030, Train loss: -0.1514, rewards: 0.5789\n",
      "Iteration: 1040, Train loss: -0.1245, rewards: 0.5905\n",
      "Iteration: 1050, Train loss: -0.0664, rewards: 0.6027\n",
      "Iteration: 1060, Train loss: -0.0895, rewards: 0.5974\n",
      "Iteration: 1070, Train loss: -0.0917, rewards: 0.5595\n",
      "Iteration: 1080, Train loss: -0.1678, rewards: 0.6015\n",
      "Iteration: 1090, Train loss: -0.0823, rewards: 0.5906\n",
      "Iteration: 1100, Train loss: -0.1145, rewards: 0.6145\n",
      "Eval:\n",
      "Hits@1: 0.5801, Hits@3: 0.6335, Hits@10: 0.6685, MRR: 0.6114\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.1201, rewards: 0.5942\n",
      "Iteration: 1120, Train loss: -0.0813, rewards: 0.5804\n",
      "Iteration: 1130, Train loss: -0.1167, rewards: 0.5807\n",
      "Iteration: 1140, Train loss: -0.1805, rewards: 0.5634\n",
      "Iteration: 1150, Train loss: -0.1540, rewards: 0.5947\n",
      "Iteration: 1160, Train loss: -0.0989, rewards: 0.6034\n",
      "Iteration: 1170, Train loss: -0.1971, rewards: 0.5824\n",
      "Iteration: 1180, Train loss: -0.0734, rewards: 0.6062\n",
      "Iteration: 1190, Train loss: -0.1430, rewards: 0.5812\n",
      "Iteration: 1200, Train loss: -0.0965, rewards: 0.6032\n",
      "Eval:\n",
      "Hits@1: 0.5875, Hits@3: 0.6390, Hits@10: 0.6703, MRR: 0.6170\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.0923, rewards: 0.6088\n",
      "Iteration: 1220, Train loss: -0.1042, rewards: 0.5926\n",
      "Iteration: 1230, Train loss: -0.1276, rewards: 0.6269\n",
      "Iteration: 1240, Train loss: -0.1426, rewards: 0.5967\n",
      "Iteration: 1250, Train loss: -0.1079, rewards: 0.5717\n",
      "Iteration: 1260, Train loss: -0.1178, rewards: 0.5927\n",
      "Iteration: 1270, Train loss: -0.1072, rewards: 0.5945\n",
      "Iteration: 1280, Train loss: -0.1147, rewards: 0.6024\n",
      "Iteration: 1290, Train loss: -0.1396, rewards: 0.5934\n",
      "Iteration: 1300, Train loss: -0.1322, rewards: 0.6012\n",
      "Eval:\n",
      "Hits@1: 0.6004, Hits@3: 0.6519, Hits@10: 0.6759, MRR: 0.6276\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.1321, rewards: 0.5874\n",
      "Iteration: 1320, Train loss: -0.1309, rewards: 0.6066\n",
      "Iteration: 1330, Train loss: -0.1450, rewards: 0.6141\n",
      "Iteration: 1340, Train loss: -0.0802, rewards: 0.6098\n",
      "Iteration: 1350, Train loss: -0.1172, rewards: 0.6111\n",
      "Iteration: 1360, Train loss: -0.1223, rewards: 0.5962\n",
      "Iteration: 1370, Train loss: -0.1705, rewards: 0.5682\n",
      "Iteration: 1380, Train loss: -0.1734, rewards: 0.6423\n",
      "Iteration: 1390, Train loss: -0.0417, rewards: 0.5802\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1400, Train loss: -0.0824, rewards: 0.5892\n",
      "Eval:\n",
      "Hits@1: 0.5948, Hits@3: 0.6409, Hits@10: 0.6759, MRR: 0.6224\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.1303, rewards: 0.6060\n",
      "Iteration: 1420, Train loss: -0.1331, rewards: 0.5897\n",
      "Iteration: 1430, Train loss: -0.1471, rewards: 0.6061\n",
      "Iteration: 1440, Train loss: -0.0843, rewards: 0.6116\n",
      "Iteration: 1450, Train loss: -0.1204, rewards: 0.6196\n",
      "Iteration: 1460, Train loss: -0.0891, rewards: 0.5923\n",
      "Iteration: 1470, Train loss: -0.1803, rewards: 0.6040\n",
      "Iteration: 1480, Train loss: -0.0836, rewards: 0.6253\n",
      "Iteration: 1490, Train loss: -0.0573, rewards: 0.5867\n",
      "Iteration: 1500, Train loss: -0.1739, rewards: 0.5954\n",
      "Eval:\n",
      "Hits@1: 0.5948, Hits@3: 0.6501, Hits@10: 0.6796, MRR: 0.6258\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.0921, rewards: 0.5984\n",
      "Iteration: 1520, Train loss: -0.1107, rewards: 0.5923\n",
      "Iteration: 1530, Train loss: -0.1095, rewards: 0.6153\n",
      "Iteration: 1540, Train loss: -0.1400, rewards: 0.6033\n",
      "Iteration: 1550, Train loss: -0.1364, rewards: 0.6312\n",
      "Iteration: 1560, Train loss: -0.0707, rewards: 0.6208\n",
      "Iteration: 1570, Train loss: -0.0970, rewards: 0.6205\n",
      "Iteration: 1580, Train loss: -0.1381, rewards: 0.6224\n",
      "Iteration: 1590, Train loss: -0.1412, rewards: 0.6293\n",
      "Iteration: 1600, Train loss: -0.1132, rewards: 0.6330\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6427, Hits@10: 0.6777, MRR: 0.6187\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.1007, rewards: 0.5994\n",
      "Iteration: 1620, Train loss: -0.0715, rewards: 0.6368\n",
      "Iteration: 1630, Train loss: -0.1466, rewards: 0.6166\n",
      "Iteration: 1640, Train loss: -0.0783, rewards: 0.6162\n",
      "Iteration: 1650, Train loss: -0.1520, rewards: 0.6265\n",
      "Iteration: 1660, Train loss: -0.1072, rewards: 0.6044\n",
      "Iteration: 1670, Train loss: -0.1011, rewards: 0.6108\n",
      "Iteration: 1680, Train loss: -0.1061, rewards: 0.6184\n",
      "Iteration: 1690, Train loss: -0.0989, rewards: 0.6063\n",
      "Iteration: 1700, Train loss: -0.0595, rewards: 0.6262\n",
      "Eval:\n",
      "Hits@1: 0.5746, Hits@3: 0.6464, Hits@10: 0.6722, MRR: 0.6118\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.1005, rewards: 0.6038\n",
      "Iteration: 1720, Train loss: -0.1153, rewards: 0.6198\n",
      "Iteration: 1730, Train loss: -0.0780, rewards: 0.6290\n",
      "Iteration: 1740, Train loss: -0.0690, rewards: 0.6288\n",
      "Iteration: 1750, Train loss: -0.0485, rewards: 0.6202\n",
      "Iteration: 1760, Train loss: -0.0836, rewards: 0.6367\n",
      "Iteration: 1770, Train loss: -0.1054, rewards: 0.6048\n",
      "Iteration: 1780, Train loss: -0.1500, rewards: 0.6173\n",
      "Iteration: 1790, Train loss: -0.1416, rewards: 0.6434\n",
      "Iteration: 1800, Train loss: -0.1207, rewards: 0.6243\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6501, Hits@10: 0.6777, MRR: 0.6210\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.1312, rewards: 0.6240\n",
      "Iteration: 1820, Train loss: -0.1398, rewards: 0.6297\n",
      "Iteration: 1830, Train loss: -0.1299, rewards: 0.6060\n",
      "Iteration: 1840, Train loss: -0.1137, rewards: 0.6227\n",
      "Iteration: 1850, Train loss: -0.1606, rewards: 0.6180\n",
      "Iteration: 1860, Train loss: -0.1302, rewards: 0.6210\n",
      "Iteration: 1870, Train loss: -0.0065, rewards: 0.6180\n",
      "Iteration: 1880, Train loss: -0.1234, rewards: 0.6319\n",
      "Iteration: 1890, Train loss: -0.0879, rewards: 0.5993\n",
      "Iteration: 1900, Train loss: -0.1389, rewards: 0.6077\n",
      "Eval:\n",
      "Hits@1: 0.5672, Hits@3: 0.6464, Hits@10: 0.6777, MRR: 0.6089\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.1249, rewards: 0.6080\n",
      "Iteration: 1920, Train loss: -0.1385, rewards: 0.6230\n",
      "Iteration: 1930, Train loss: -0.0379, rewards: 0.6017\n",
      "Iteration: 1940, Train loss: -0.0576, rewards: 0.6024\n",
      "Iteration: 1950, Train loss: -0.1046, rewards: 0.5909\n",
      "Iteration: 1960, Train loss: -0.1461, rewards: 0.6551\n",
      "Iteration: 1970, Train loss: -0.1517, rewards: 0.6365\n",
      "Iteration: 1980, Train loss: -0.0747, rewards: 0.6080\n",
      "Iteration: 1990, Train loss: -0.0251, rewards: 0.6137\n",
      "Iteration: 2000, Train loss: -0.0941, rewards: 0.6141\n",
      "Eval:\n",
      "Hits@1: 0.5820, Hits@3: 0.6501, Hits@10: 0.6722, MRR: 0.6171\n",
      "------------------------------------------------------------\n",
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Hits@1: 0.6845, Hits@3: 0.7970, Hits@10: 0.8502, MRR: 0.7473\n",
      "3\n",
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n",
      "Iteration: 10, Train loss: -0.2623, rewards: 0.0762\n",
      "Iteration: 20, Train loss: -0.4014, rewards: 0.1212\n",
      "Iteration: 30, Train loss: -0.4286, rewards: 0.1350\n",
      "Iteration: 40, Train loss: -0.4376, rewards: 0.1762\n",
      "Iteration: 50, Train loss: -0.4064, rewards: 0.1837\n",
      "Iteration: 60, Train loss: -0.3861, rewards: 0.1700\n",
      "Iteration: 70, Train loss: -0.3899, rewards: 0.2687\n",
      "Iteration: 80, Train loss: -0.1682, rewards: 0.1744\n",
      "Iteration: 90, Train loss: -0.3177, rewards: 0.2512\n",
      "Iteration: 100, Train loss: -0.4360, rewards: 0.3406\n",
      "Eval:\n",
      "Hits@1: 0.4401, Hits@3: 0.5470, Hits@10: 0.6262, MRR: 0.5039\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.1675, rewards: 0.2619\n",
      "Iteration: 120, Train loss: -0.4376, rewards: 0.3362\n",
      "Iteration: 130, Train loss: -0.3108, rewards: 0.3050\n",
      "Iteration: 140, Train loss: -0.2154, rewards: 0.3331\n",
      "Iteration: 150, Train loss: -0.1880, rewards: 0.3387\n",
      "Iteration: 160, Train loss: -0.1466, rewards: 0.3575\n",
      "Iteration: 170, Train loss: -0.3067, rewards: 0.4288\n",
      "Iteration: 180, Train loss: -0.4749, rewards: 0.4756\n",
      "Iteration: 190, Train loss: -0.4060, rewards: 0.4019\n",
      "Iteration: 200, Train loss: -0.3236, rewards: 0.4238\n",
      "Eval:\n",
      "Hits@1: 0.4954, Hits@3: 0.5856, Hits@10: 0.6464, MRR: 0.5474\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.3357, rewards: 0.3775\n",
      "Iteration: 220, Train loss: -0.2664, rewards: 0.3625\n",
      "Iteration: 230, Train loss: -0.0394, rewards: 0.4650\n",
      "Iteration: 240, Train loss: -0.2725, rewards: 0.3663\n",
      "Iteration: 250, Train loss: -0.3256, rewards: 0.4788\n",
      "Iteration: 260, Train loss: -0.3194, rewards: 0.4544\n",
      "Iteration: 270, Train loss: -0.3081, rewards: 0.4238\n",
      "Iteration: 280, Train loss: -0.3106, rewards: 0.4431\n",
      "Iteration: 290, Train loss: -0.2279, rewards: 0.3906\n",
      "Iteration: 300, Train loss: -0.3088, rewards: 0.5356\n",
      "Eval:\n",
      "Hits@1: 0.5543, Hits@3: 0.6114, Hits@10: 0.6703, MRR: 0.5925\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.1484, rewards: 0.4669\n",
      "Iteration: 320, Train loss: -0.2931, rewards: 0.4856\n",
      "Iteration: 330, Train loss: -0.1477, rewards: 0.4763\n",
      "Iteration: 340, Train loss: -0.2463, rewards: 0.4888\n",
      "Iteration: 350, Train loss: -0.2124, rewards: 0.5200\n",
      "Iteration: 360, Train loss: -0.3261, rewards: 0.4556\n",
      "Iteration: 370, Train loss: -0.3195, rewards: 0.5194\n",
      "Iteration: 380, Train loss: -0.3004, rewards: 0.4412\n",
      "Iteration: 390, Train loss: -0.0517, rewards: 0.4437\n",
      "Iteration: 400, Train loss: -0.2363, rewards: 0.6412\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6483, Hits@10: 0.6832, MRR: 0.6205\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.1161, rewards: 0.4519\n",
      "Iteration: 420, Train loss: -0.2404, rewards: 0.5275\n",
      "Iteration: 430, Train loss: -0.0892, rewards: 0.4756\n",
      "Iteration: 440, Train loss: -0.0913, rewards: 0.5737\n",
      "Iteration: 450, Train loss: -0.1208, rewards: 0.4019\n",
      "Iteration: 460, Train loss: -0.1946, rewards: 0.4756\n",
      "Iteration: 470, Train loss: -0.2615, rewards: 0.5481\n",
      "Iteration: 480, Train loss: -0.1918, rewards: 0.5038\n",
      "Iteration: 490, Train loss: -0.1304, rewards: 0.4794\n",
      "Iteration: 500, Train loss: -0.0881, rewards: 0.4606\n",
      "Eval:\n",
      "Hits@1: 0.5691, Hits@3: 0.6446, Hits@10: 0.6777, MRR: 0.6113\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.0826, rewards: 0.4625\n",
      "Iteration: 520, Train loss: -0.1948, rewards: 0.4375\n",
      "Iteration: 530, Train loss: -0.2937, rewards: 0.5100\n",
      "Iteration: 540, Train loss: -0.0770, rewards: 0.4562\n",
      "Iteration: 550, Train loss: -0.2839, rewards: 0.5537\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 560, Train loss: -0.2380, rewards: 0.4694\n",
      "Iteration: 570, Train loss: -0.1768, rewards: 0.5200\n",
      "Iteration: 580, Train loss: -0.2802, rewards: 0.4163\n",
      "Iteration: 590, Train loss: -0.1538, rewards: 0.5413\n",
      "Iteration: 600, Train loss: -0.0699, rewards: 0.4637\n",
      "Eval:\n",
      "Hits@1: 0.6004, Hits@3: 0.6556, Hits@10: 0.6814, MRR: 0.6323\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.0609, rewards: 0.5419\n",
      "Iteration: 620, Train loss: -0.2567, rewards: 0.4537\n",
      "Iteration: 630, Train loss: -0.2152, rewards: 0.6188\n",
      "Iteration: 640, Train loss: -0.1715, rewards: 0.4313\n",
      "Iteration: 650, Train loss: -0.1500, rewards: 0.4963\n",
      "Iteration: 660, Train loss: -0.0674, rewards: 0.4869\n",
      "Iteration: 670, Train loss: -0.1575, rewards: 0.4825\n",
      "Iteration: 680, Train loss: -0.1784, rewards: 0.5444\n",
      "Iteration: 690, Train loss: -0.1041, rewards: 0.4100\n",
      "Iteration: 700, Train loss: -0.2428, rewards: 0.4169\n",
      "Eval:\n",
      "Hits@1: 0.5691, Hits@3: 0.6409, Hits@10: 0.6759, MRR: 0.6101\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.3129, rewards: 0.5238\n",
      "Iteration: 720, Train loss: -0.0815, rewards: 0.5312\n",
      "Iteration: 730, Train loss: -0.2197, rewards: 0.4850\n",
      "Iteration: 740, Train loss: -0.2285, rewards: 0.6375\n",
      "Iteration: 750, Train loss: -0.2533, rewards: 0.5156\n",
      "Iteration: 760, Train loss: -0.1872, rewards: 0.4706\n",
      "Iteration: 770, Train loss: -0.2332, rewards: 0.5012\n",
      "Iteration: 780, Train loss: -0.2489, rewards: 0.5244\n",
      "Iteration: 790, Train loss: -0.1192, rewards: 0.6056\n",
      "Iteration: 800, Train loss: -0.3271, rewards: 0.4906\n",
      "Eval:\n",
      "Hits@1: 0.5875, Hits@3: 0.6575, Hits@10: 0.6814, MRR: 0.6247\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.2142, rewards: 0.4306\n",
      "Iteration: 820, Train loss: -0.1622, rewards: 0.4437\n",
      "Iteration: 830, Train loss: -0.2616, rewards: 0.5544\n",
      "Iteration: 840, Train loss: -0.3264, rewards: 0.4594\n",
      "Iteration: 850, Train loss: -0.3025, rewards: 0.5669\n",
      "Iteration: 860, Train loss: -0.1855, rewards: 0.5744\n",
      "Iteration: 870, Train loss: -0.1881, rewards: 0.4300\n",
      "Iteration: 880, Train loss: -0.3241, rewards: 0.5088\n",
      "Iteration: 890, Train loss: -0.3520, rewards: 0.5119\n",
      "Iteration: 900, Train loss: -0.1815, rewards: 0.5394\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6538, Hits@10: 0.6851, MRR: 0.6213\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.2908, rewards: 0.5631\n",
      "Iteration: 920, Train loss: -0.2670, rewards: 0.5262\n",
      "Iteration: 930, Train loss: -0.1412, rewards: 0.3969\n",
      "Iteration: 940, Train loss: -0.1245, rewards: 0.4056\n",
      "Iteration: 950, Train loss: -0.0297, rewards: 0.5038\n",
      "Iteration: 960, Train loss: -0.0374, rewards: 0.5200\n",
      "Iteration: 970, Train loss: -0.2840, rewards: 0.5225\n",
      "Iteration: 980, Train loss: -0.0904, rewards: 0.4975\n",
      "Iteration: 990, Train loss: -0.3110, rewards: 0.5931\n",
      "Iteration: 1000, Train loss: -0.1452, rewards: 0.5856\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6390, Hits@10: 0.6777, MRR: 0.6188\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.1925, rewards: 0.4938\n",
      "Iteration: 1020, Train loss: -0.2177, rewards: 0.5587\n",
      "Iteration: 1030, Train loss: -0.1990, rewards: 0.5312\n",
      "Iteration: 1040, Train loss: -0.1329, rewards: 0.5125\n",
      "Iteration: 1050, Train loss: -0.0779, rewards: 0.5706\n",
      "Iteration: 1060, Train loss: -0.2235, rewards: 0.6006\n",
      "Iteration: 1070, Train loss: -0.1822, rewards: 0.4950\n",
      "Iteration: 1080, Train loss: -0.2333, rewards: 0.6525\n",
      "Iteration: 1090, Train loss: -0.0139, rewards: 0.6406\n",
      "Iteration: 1100, Train loss: -0.2090, rewards: 0.5425\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6446, Hits@10: 0.6888, MRR: 0.6205\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.1910, rewards: 0.5713\n",
      "Iteration: 1120, Train loss: -0.1759, rewards: 0.5737\n",
      "Iteration: 1130, Train loss: -0.2458, rewards: 0.5025\n",
      "Iteration: 1140, Train loss: -0.0345, rewards: 0.5606\n",
      "Iteration: 1150, Train loss: -0.1087, rewards: 0.5406\n",
      "Iteration: 1160, Train loss: -0.0540, rewards: 0.5406\n",
      "Iteration: 1170, Train loss: -0.1634, rewards: 0.4562\n",
      "Iteration: 1180, Train loss: -0.1975, rewards: 0.6569\n",
      "Iteration: 1190, Train loss: -0.2715, rewards: 0.5125\n",
      "Iteration: 1200, Train loss: -0.0674, rewards: 0.5913\n",
      "Eval:\n",
      "Hits@1: 0.6059, Hits@3: 0.6538, Hits@10: 0.6888, MRR: 0.6351\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.1486, rewards: 0.5537\n",
      "Iteration: 1220, Train loss: -0.0908, rewards: 0.5500\n",
      "Iteration: 1230, Train loss: -0.1641, rewards: 0.6231\n",
      "Iteration: 1240, Train loss: -0.1241, rewards: 0.5319\n",
      "Iteration: 1250, Train loss: -0.1348, rewards: 0.6412\n",
      "Iteration: 1260, Train loss: -0.1874, rewards: 0.5500\n",
      "Iteration: 1270, Train loss: -0.2195, rewards: 0.6281\n",
      "Iteration: 1280, Train loss: -0.2037, rewards: 0.5600\n",
      "Iteration: 1290, Train loss: -0.0443, rewards: 0.5956\n",
      "Iteration: 1300, Train loss: -0.1095, rewards: 0.5587\n",
      "Eval:\n",
      "Hits@1: 0.6041, Hits@3: 0.6575, Hits@10: 0.6869, MRR: 0.6345\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.0429, rewards: 0.5088\n",
      "Iteration: 1320, Train loss: -0.1557, rewards: 0.4981\n",
      "Iteration: 1330, Train loss: -0.0772, rewards: 0.4950\n",
      "Iteration: 1340, Train loss: -0.1547, rewards: 0.6000\n",
      "Iteration: 1350, Train loss: -0.2403, rewards: 0.5962\n",
      "Iteration: 1360, Train loss: -0.1357, rewards: 0.5138\n",
      "Iteration: 1370, Train loss: -0.1911, rewards: 0.5144\n",
      "Iteration: 1380, Train loss: -0.1413, rewards: 0.5506\n",
      "Iteration: 1390, Train loss: -0.1688, rewards: 0.4512\n",
      "Iteration: 1400, Train loss: -0.1738, rewards: 0.6025\n",
      "Eval:\n",
      "Hits@1: 0.5875, Hits@3: 0.6575, Hits@10: 0.6869, MRR: 0.6260\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.1969, rewards: 0.5463\n",
      "Iteration: 1420, Train loss: 0.0565, rewards: 0.5325\n",
      "Iteration: 1430, Train loss: -0.2073, rewards: 0.6319\n",
      "Iteration: 1440, Train loss: -0.1560, rewards: 0.5938\n",
      "Iteration: 1450, Train loss: -0.1228, rewards: 0.5931\n",
      "Iteration: 1460, Train loss: 0.0447, rewards: 0.5525\n",
      "Iteration: 1470, Train loss: -0.1294, rewards: 0.6388\n",
      "Iteration: 1480, Train loss: -0.2529, rewards: 0.5494\n",
      "Iteration: 1490, Train loss: -0.0022, rewards: 0.5825\n",
      "Iteration: 1500, Train loss: -0.2922, rewards: 0.5694\n",
      "Eval:\n",
      "Hits@1: 0.5875, Hits@3: 0.6483, Hits@10: 0.6832, MRR: 0.6228\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.1733, rewards: 0.5269\n",
      "Iteration: 1520, Train loss: -0.0263, rewards: 0.5294\n",
      "Iteration: 1530, Train loss: -0.0978, rewards: 0.5325\n",
      "Iteration: 1540, Train loss: -0.1457, rewards: 0.5913\n",
      "Iteration: 1550, Train loss: -0.0016, rewards: 0.5650\n",
      "Iteration: 1560, Train loss: -0.1790, rewards: 0.6081\n",
      "Iteration: 1570, Train loss: -0.1512, rewards: 0.4813\n",
      "Iteration: 1580, Train loss: -0.1843, rewards: 0.6550\n",
      "Iteration: 1590, Train loss: -0.0527, rewards: 0.6138\n",
      "Iteration: 1600, Train loss: -0.2095, rewards: 0.5513\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6464, Hits@10: 0.6814, MRR: 0.6205\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.0485, rewards: 0.5519\n",
      "Iteration: 1620, Train loss: -0.1591, rewards: 0.5938\n",
      "Iteration: 1630, Train loss: 0.0037, rewards: 0.5881\n",
      "Iteration: 1640, Train loss: -0.0045, rewards: 0.5713\n",
      "Iteration: 1650, Train loss: -0.1699, rewards: 0.6813\n",
      "Iteration: 1660, Train loss: -0.0036, rewards: 0.5369\n",
      "Iteration: 1670, Train loss: -0.2333, rewards: 0.5587\n",
      "Iteration: 1680, Train loss: -0.1814, rewards: 0.4975\n",
      "Iteration: 1690, Train loss: -0.2729, rewards: 0.5537\n",
      "Iteration: 1700, Train loss: -0.0457, rewards: 0.5525\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6538, Hits@10: 0.6777, MRR: 0.6200\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.1592, rewards: 0.5506\n",
      "Iteration: 1720, Train loss: -0.2092, rewards: 0.5144\n",
      "Iteration: 1730, Train loss: 0.1253, rewards: 0.5200\n",
      "Iteration: 1740, Train loss: -0.0742, rewards: 0.6325\n",
      "Iteration: 1750, Train loss: -0.1054, rewards: 0.6919\n",
      "Iteration: 1760, Train loss: -0.2051, rewards: 0.6006\n",
      "Iteration: 1770, Train loss: -0.0201, rewards: 0.6150\n",
      "Iteration: 1780, Train loss: -0.0439, rewards: 0.5175\n",
      "Iteration: 1790, Train loss: -0.2225, rewards: 0.5256\n",
      "Iteration: 1800, Train loss: -0.0998, rewards: 0.5500\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.6096, Hits@3: 0.6630, Hits@10: 0.6906, MRR: 0.6389\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.1927, rewards: 0.5644\n",
      "Iteration: 1820, Train loss: -0.1118, rewards: 0.5256\n",
      "Iteration: 1830, Train loss: -0.1485, rewards: 0.4594\n",
      "Iteration: 1840, Train loss: -0.0961, rewards: 0.6831\n",
      "Iteration: 1850, Train loss: -0.1302, rewards: 0.6031\n",
      "Iteration: 1860, Train loss: 0.0289, rewards: 0.5169\n",
      "Iteration: 1870, Train loss: -0.0197, rewards: 0.6544\n",
      "Iteration: 1880, Train loss: -0.1384, rewards: 0.5575\n",
      "Iteration: 1890, Train loss: -0.1673, rewards: 0.6231\n",
      "Iteration: 1900, Train loss: -0.1163, rewards: 0.6687\n",
      "Eval:\n",
      "Hits@1: 0.5967, Hits@3: 0.6575, Hits@10: 0.6869, MRR: 0.6300\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.1377, rewards: 0.5112\n",
      "Iteration: 1920, Train loss: -0.1667, rewards: 0.5988\n",
      "Iteration: 1930, Train loss: -0.0341, rewards: 0.5581\n",
      "Iteration: 1940, Train loss: -0.1029, rewards: 0.6500\n",
      "Iteration: 1950, Train loss: -0.0518, rewards: 0.5344\n",
      "Iteration: 1960, Train loss: 0.1024, rewards: 0.4975\n",
      "Iteration: 1970, Train loss: -0.1095, rewards: 0.5594\n",
      "Iteration: 1980, Train loss: -0.1901, rewards: 0.5119\n",
      "Iteration: 1990, Train loss: -0.0838, rewards: 0.5819\n",
      "Iteration: 2000, Train loss: -0.3228, rewards: 0.5650\n",
      "Eval:\n",
      "Hits@1: 0.6151, Hits@3: 0.6538, Hits@10: 0.6888, MRR: 0.6388\n",
      "------------------------------------------------------------\n",
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Hits@1: 0.6934, Hits@3: 0.8048, Hits@10: 0.8570, MRR: 0.7556\n",
      "4\n",
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n",
      "Iteration: 10, Train loss: -0.3441, rewards: 0.1306\n",
      "Iteration: 20, Train loss: -0.4138, rewards: 0.1306\n",
      "Iteration: 30, Train loss: -0.3130, rewards: 0.1431\n",
      "Iteration: 40, Train loss: -0.2773, rewards: 0.2044\n",
      "Iteration: 50, Train loss: -0.2413, rewards: 0.2162\n",
      "Iteration: 60, Train loss: -0.2948, rewards: 0.2869\n",
      "Iteration: 70, Train loss: -0.3701, rewards: 0.3287\n",
      "Iteration: 80, Train loss: -0.2614, rewards: 0.2850\n",
      "Iteration: 90, Train loss: -0.2608, rewards: 0.4050\n",
      "Iteration: 100, Train loss: -0.4347, rewards: 0.3644\n",
      "Eval:\n",
      "Hits@1: 0.4236, Hits@3: 0.5212, Hits@10: 0.6004, MRR: 0.4834\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.3437, rewards: 0.3756\n",
      "Iteration: 120, Train loss: -0.3464, rewards: 0.3819\n",
      "Iteration: 130, Train loss: -0.4041, rewards: 0.4281\n",
      "Iteration: 140, Train loss: -0.2003, rewards: 0.3613\n",
      "Iteration: 150, Train loss: -0.2254, rewards: 0.3169\n",
      "Iteration: 160, Train loss: -0.3315, rewards: 0.3319\n",
      "Iteration: 170, Train loss: -0.4806, rewards: 0.4050\n",
      "Iteration: 180, Train loss: -0.2658, rewards: 0.3762\n",
      "Iteration: 190, Train loss: -0.4646, rewards: 0.3113\n",
      "Iteration: 200, Train loss: -0.3870, rewards: 0.3569\n",
      "Eval:\n",
      "Hits@1: 0.5046, Hits@3: 0.5967, Hits@10: 0.6667, MRR: 0.5598\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.2952, rewards: 0.3706\n",
      "Iteration: 220, Train loss: -0.4450, rewards: 0.4275\n",
      "Iteration: 230, Train loss: -0.2210, rewards: 0.3463\n",
      "Iteration: 240, Train loss: -0.2137, rewards: 0.3975\n",
      "Iteration: 250, Train loss: -0.2555, rewards: 0.5225\n",
      "Iteration: 260, Train loss: -0.1936, rewards: 0.4756\n",
      "Iteration: 270, Train loss: -0.4685, rewards: 0.4188\n",
      "Iteration: 280, Train loss: -0.3011, rewards: 0.4088\n",
      "Iteration: 290, Train loss: -0.2451, rewards: 0.4206\n",
      "Iteration: 300, Train loss: -0.4046, rewards: 0.3862\n",
      "Eval:\n",
      "Hits@1: 0.5212, Hits@3: 0.6114, Hits@10: 0.6722, MRR: 0.5732\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.2162, rewards: 0.4106\n",
      "Iteration: 320, Train loss: -0.1616, rewards: 0.4381\n",
      "Iteration: 330, Train loss: -0.3497, rewards: 0.4244\n",
      "Iteration: 340, Train loss: -0.0983, rewards: 0.4587\n",
      "Iteration: 350, Train loss: -0.0899, rewards: 0.4956\n",
      "Iteration: 360, Train loss: -0.2446, rewards: 0.4062\n",
      "Iteration: 370, Train loss: -0.2745, rewards: 0.4888\n",
      "Iteration: 380, Train loss: -0.2683, rewards: 0.5319\n",
      "Iteration: 390, Train loss: -0.2303, rewards: 0.5031\n",
      "Iteration: 400, Train loss: -0.1908, rewards: 0.4188\n",
      "Eval:\n",
      "Hits@1: 0.5543, Hits@3: 0.6409, Hits@10: 0.6832, MRR: 0.6031\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.2286, rewards: 0.4844\n",
      "Iteration: 420, Train loss: -0.0744, rewards: 0.4456\n",
      "Iteration: 430, Train loss: -0.0342, rewards: 0.5038\n",
      "Iteration: 440, Train loss: -0.1831, rewards: 0.5331\n",
      "Iteration: 450, Train loss: -0.2334, rewards: 0.4944\n",
      "Iteration: 460, Train loss: -0.3351, rewards: 0.5575\n",
      "Iteration: 470, Train loss: -0.1118, rewards: 0.6269\n",
      "Iteration: 480, Train loss: -0.0839, rewards: 0.4250\n",
      "Iteration: 490, Train loss: -0.1953, rewards: 0.5300\n",
      "Iteration: 500, Train loss: -0.2758, rewards: 0.5475\n",
      "Eval:\n",
      "Hits@1: 0.5709, Hits@3: 0.6501, Hits@10: 0.6832, MRR: 0.6152\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.1311, rewards: 0.5656\n",
      "Iteration: 520, Train loss: -0.2045, rewards: 0.5125\n",
      "Iteration: 530, Train loss: -0.2692, rewards: 0.5356\n",
      "Iteration: 540, Train loss: -0.1099, rewards: 0.5437\n",
      "Iteration: 550, Train loss: -0.2724, rewards: 0.4981\n",
      "Iteration: 560, Train loss: -0.3368, rewards: 0.4544\n",
      "Iteration: 570, Train loss: -0.2531, rewards: 0.4706\n",
      "Iteration: 580, Train loss: -0.2816, rewards: 0.4669\n",
      "Iteration: 590, Train loss: -0.2889, rewards: 0.4831\n",
      "Iteration: 600, Train loss: -0.3521, rewards: 0.5356\n",
      "Eval:\n",
      "Hits@1: 0.5672, Hits@3: 0.6630, Hits@10: 0.6906, MRR: 0.6168\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.2204, rewards: 0.4881\n",
      "Iteration: 620, Train loss: -0.2370, rewards: 0.4869\n",
      "Iteration: 630, Train loss: -0.2140, rewards: 0.6056\n",
      "Iteration: 640, Train loss: -0.2949, rewards: 0.4731\n",
      "Iteration: 650, Train loss: -0.1006, rewards: 0.4288\n",
      "Iteration: 660, Train loss: -0.2981, rewards: 0.5250\n",
      "Iteration: 670, Train loss: -0.3702, rewards: 0.5050\n",
      "Iteration: 680, Train loss: -0.2217, rewards: 0.5219\n",
      "Iteration: 690, Train loss: -0.0503, rewards: 0.4631\n",
      "Iteration: 700, Train loss: -0.2586, rewards: 0.5206\n",
      "Eval:\n",
      "Hits@1: 0.5746, Hits@3: 0.6538, Hits@10: 0.6924, MRR: 0.6184\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.1117, rewards: 0.4744\n",
      "Iteration: 720, Train loss: -0.2700, rewards: 0.4462\n",
      "Iteration: 730, Train loss: -0.4077, rewards: 0.5181\n",
      "Iteration: 740, Train loss: -0.2320, rewards: 0.4587\n",
      "Iteration: 750, Train loss: -0.4722, rewards: 0.4631\n",
      "Iteration: 760, Train loss: -0.2511, rewards: 0.5844\n",
      "Iteration: 770, Train loss: -0.1684, rewards: 0.5356\n",
      "Iteration: 780, Train loss: -0.2117, rewards: 0.5056\n",
      "Iteration: 790, Train loss: -0.1913, rewards: 0.4719\n",
      "Iteration: 800, Train loss: -0.1601, rewards: 0.5281\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6685, Hits@10: 0.6906, MRR: 0.6304\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.0119, rewards: 0.4950\n",
      "Iteration: 820, Train loss: -0.2116, rewards: 0.5525\n",
      "Iteration: 830, Train loss: -0.2575, rewards: 0.5544\n",
      "Iteration: 840, Train loss: -0.2766, rewards: 0.5619\n",
      "Iteration: 850, Train loss: -0.2464, rewards: 0.6225\n",
      "Iteration: 860, Train loss: -0.2089, rewards: 0.5044\n",
      "Iteration: 870, Train loss: -0.1397, rewards: 0.5600\n",
      "Iteration: 880, Train loss: -0.4067, rewards: 0.6419\n",
      "Iteration: 890, Train loss: -0.1811, rewards: 0.5363\n",
      "Iteration: 900, Train loss: -0.1806, rewards: 0.5256\n",
      "Eval:\n",
      "Hits@1: 0.5912, Hits@3: 0.6630, Hits@10: 0.6869, MRR: 0.6304\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.2499, rewards: 0.5062\n",
      "Iteration: 920, Train loss: -0.1264, rewards: 0.4694\n",
      "Iteration: 930, Train loss: -0.2966, rewards: 0.5756\n",
      "Iteration: 940, Train loss: -0.1332, rewards: 0.5250\n",
      "Iteration: 950, Train loss: -0.1845, rewards: 0.5031\n",
      "Iteration: 960, Train loss: -0.0432, rewards: 0.5500\n",
      "Iteration: 970, Train loss: -0.1338, rewards: 0.5256\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 980, Train loss: -0.3303, rewards: 0.5913\n",
      "Iteration: 990, Train loss: -0.2739, rewards: 0.5644\n",
      "Iteration: 1000, Train loss: -0.4958, rewards: 0.6412\n",
      "Eval:\n",
      "Hits@1: 0.5764, Hits@3: 0.6483, Hits@10: 0.6796, MRR: 0.6168\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.2973, rewards: 0.5106\n",
      "Iteration: 1020, Train loss: -0.2418, rewards: 0.6075\n",
      "Iteration: 1030, Train loss: -0.1840, rewards: 0.4050\n",
      "Iteration: 1040, Train loss: -0.0430, rewards: 0.5600\n",
      "Iteration: 1050, Train loss: -0.2345, rewards: 0.5144\n",
      "Iteration: 1060, Train loss: -0.1743, rewards: 0.4394\n",
      "Iteration: 1070, Train loss: -0.1901, rewards: 0.5288\n",
      "Iteration: 1080, Train loss: -0.1170, rewards: 0.5631\n",
      "Iteration: 1090, Train loss: -0.2108, rewards: 0.5731\n",
      "Iteration: 1100, Train loss: -0.1084, rewards: 0.5206\n",
      "Eval:\n",
      "Hits@1: 0.5985, Hits@3: 0.6703, Hits@10: 0.6888, MRR: 0.6359\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.2256, rewards: 0.4856\n",
      "Iteration: 1120, Train loss: -0.1980, rewards: 0.5750\n",
      "Iteration: 1130, Train loss: -0.1802, rewards: 0.5669\n",
      "Iteration: 1140, Train loss: -0.2017, rewards: 0.5487\n",
      "Iteration: 1150, Train loss: -0.1584, rewards: 0.5606\n",
      "Iteration: 1160, Train loss: -0.2088, rewards: 0.5450\n",
      "Iteration: 1170, Train loss: -0.1318, rewards: 0.4681\n",
      "Iteration: 1180, Train loss: -0.1243, rewards: 0.5056\n",
      "Iteration: 1190, Train loss: -0.1834, rewards: 0.5725\n",
      "Iteration: 1200, Train loss: -0.0586, rewards: 0.4856\n",
      "Eval:\n",
      "Hits@1: 0.5820, Hits@3: 0.6593, Hits@10: 0.6943, MRR: 0.6269\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.0411, rewards: 0.5719\n",
      "Iteration: 1220, Train loss: -0.1117, rewards: 0.5069\n",
      "Iteration: 1230, Train loss: -0.2963, rewards: 0.4769\n",
      "Iteration: 1240, Train loss: -0.2699, rewards: 0.5700\n",
      "Iteration: 1250, Train loss: -0.0534, rewards: 0.6675\n",
      "Iteration: 1260, Train loss: -0.1712, rewards: 0.4988\n",
      "Iteration: 1270, Train loss: -0.0826, rewards: 0.5044\n",
      "Iteration: 1280, Train loss: -0.1037, rewards: 0.6038\n",
      "Iteration: 1290, Train loss: -0.1684, rewards: 0.5663\n",
      "Iteration: 1300, Train loss: -0.2115, rewards: 0.6194\n",
      "Eval:\n",
      "Hits@1: 0.5948, Hits@3: 0.6740, Hits@10: 0.6924, MRR: 0.6369\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: 0.0019, rewards: 0.5238\n",
      "Iteration: 1320, Train loss: -0.3128, rewards: 0.5406\n",
      "Iteration: 1330, Train loss: -0.1659, rewards: 0.5756\n",
      "Iteration: 1340, Train loss: -0.0647, rewards: 0.5256\n",
      "Iteration: 1350, Train loss: -0.2176, rewards: 0.5400\n",
      "Iteration: 1360, Train loss: -0.3484, rewards: 0.5475\n",
      "Iteration: 1370, Train loss: -0.3034, rewards: 0.6381\n",
      "Iteration: 1380, Train loss: -0.0654, rewards: 0.4662\n",
      "Iteration: 1390, Train loss: -0.3018, rewards: 0.5413\n",
      "Iteration: 1400, Train loss: -0.1487, rewards: 0.4394\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6593, Hits@10: 0.6869, MRR: 0.6266\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.1524, rewards: 0.4819\n",
      "Iteration: 1420, Train loss: -0.2445, rewards: 0.5813\n",
      "Iteration: 1430, Train loss: -0.1028, rewards: 0.5100\n",
      "Iteration: 1440, Train loss: -0.2660, rewards: 0.5644\n",
      "Iteration: 1450, Train loss: -0.1495, rewards: 0.6425\n",
      "Iteration: 1460, Train loss: -0.3061, rewards: 0.5062\n",
      "Iteration: 1470, Train loss: -0.1646, rewards: 0.6162\n",
      "Iteration: 1480, Train loss: -0.0837, rewards: 0.4813\n",
      "Iteration: 1490, Train loss: -0.2700, rewards: 0.5575\n",
      "Iteration: 1500, Train loss: -0.1755, rewards: 0.5719\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6667, Hits@10: 0.6924, MRR: 0.6323\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.0602, rewards: 0.5106\n",
      "Iteration: 1520, Train loss: -0.1796, rewards: 0.5406\n",
      "Iteration: 1530, Train loss: -0.0213, rewards: 0.4894\n",
      "Iteration: 1540, Train loss: -0.1943, rewards: 0.5606\n",
      "Iteration: 1550, Train loss: -0.3041, rewards: 0.4931\n",
      "Iteration: 1560, Train loss: -0.1147, rewards: 0.5969\n",
      "Iteration: 1570, Train loss: -0.0802, rewards: 0.5125\n",
      "Iteration: 1580, Train loss: -0.1930, rewards: 0.5737\n",
      "Iteration: 1590, Train loss: 0.0137, rewards: 0.5350\n",
      "Iteration: 1600, Train loss: -0.2439, rewards: 0.5825\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6593, Hits@10: 0.6888, MRR: 0.6251\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.1295, rewards: 0.5200\n",
      "Iteration: 1620, Train loss: -0.1253, rewards: 0.5194\n",
      "Iteration: 1630, Train loss: -0.2655, rewards: 0.5825\n",
      "Iteration: 1640, Train loss: 0.0218, rewards: 0.6338\n",
      "Iteration: 1650, Train loss: -0.2457, rewards: 0.4956\n",
      "Iteration: 1660, Train loss: -0.0313, rewards: 0.5031\n",
      "Iteration: 1670, Train loss: 0.0811, rewards: 0.5525\n",
      "Iteration: 1680, Train loss: 0.0938, rewards: 0.4725\n",
      "Iteration: 1690, Train loss: 0.0068, rewards: 0.6200\n",
      "Iteration: 1700, Train loss: -0.0310, rewards: 0.6262\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6575, Hits@10: 0.6851, MRR: 0.6248\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.2263, rewards: 0.7081\n",
      "Iteration: 1720, Train loss: -0.1530, rewards: 0.5644\n",
      "Iteration: 1730, Train loss: -0.1007, rewards: 0.6088\n",
      "Iteration: 1740, Train loss: -0.1367, rewards: 0.6344\n",
      "Iteration: 1750, Train loss: -0.1898, rewards: 0.6412\n",
      "Iteration: 1760, Train loss: -0.1003, rewards: 0.5856\n",
      "Iteration: 1770, Train loss: -0.0793, rewards: 0.5081\n",
      "Iteration: 1780, Train loss: -0.0148, rewards: 0.4813\n",
      "Iteration: 1790, Train loss: -0.2551, rewards: 0.5750\n",
      "Iteration: 1800, Train loss: -0.0820, rewards: 0.4769\n",
      "Eval:\n",
      "Hits@1: 0.6041, Hits@3: 0.6667, Hits@10: 0.6943, MRR: 0.6395\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.0701, rewards: 0.4700\n",
      "Iteration: 1820, Train loss: -0.2408, rewards: 0.6056\n",
      "Iteration: 1830, Train loss: -0.0910, rewards: 0.5581\n",
      "Iteration: 1840, Train loss: -0.1525, rewards: 0.5625\n",
      "Iteration: 1850, Train loss: -0.0109, rewards: 0.5513\n",
      "Iteration: 1860, Train loss: -0.0625, rewards: 0.6038\n",
      "Iteration: 1870, Train loss: -0.3085, rewards: 0.5806\n",
      "Iteration: 1880, Train loss: -0.1395, rewards: 0.6687\n",
      "Iteration: 1890, Train loss: -0.0944, rewards: 0.6944\n",
      "Iteration: 1900, Train loss: -0.0966, rewards: 0.6306\n",
      "Eval:\n",
      "Hits@1: 0.5985, Hits@3: 0.6759, Hits@10: 0.6961, MRR: 0.6395\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.2251, rewards: 0.6256\n",
      "Iteration: 1920, Train loss: -0.0353, rewards: 0.5719\n",
      "Iteration: 1930, Train loss: -0.1086, rewards: 0.5856\n",
      "Iteration: 1940, Train loss: -0.1670, rewards: 0.5988\n",
      "Iteration: 1950, Train loss: -0.2028, rewards: 0.6869\n",
      "Iteration: 1960, Train loss: -0.1743, rewards: 0.5400\n",
      "Iteration: 1970, Train loss: -0.2226, rewards: 0.6044\n",
      "Iteration: 1980, Train loss: 0.0034, rewards: 0.5706\n",
      "Iteration: 1990, Train loss: -0.2168, rewards: 0.5400\n",
      "Iteration: 2000, Train loss: -0.0878, rewards: 0.5650\n",
      "Eval:\n",
      "Hits@1: 0.6114, Hits@3: 0.6685, Hits@10: 0.6961, MRR: 0.6452\n",
      "------------------------------------------------------------\n",
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Hits@1: 0.6781, Hits@3: 0.7942, Hits@10: 0.8510, MRR: 0.7431\n"
     ]
    }
   ],
   "source": [
    "results = {}\n",
    "for layer in [1, 3, 4]:\n",
    "    \n",
    "    print(layer)\n",
    "    options['gnn_layer'] = layer\n",
    "    options['model_dir'] = './outputs_nell995-1/'\n",
    "    options['output_dir'] = './outputs_nell995-1/'\n",
    "    options['model_dir'] = options['model_dir'] + str(layer) + 'layer' + '/'\n",
    "    options['output_dir'] = options['output_dir'] + str(layer) + 'layer' + '/'\n",
    "    \n",
    "    if not os.path.exists(options['model_dir']):\n",
    "        os.mkdir(options['model_dir'])\n",
    "    if not os.path.exists(options['output_dir']):\n",
    "        os.mkdir(options['output_dir'])\n",
    "\n",
    "    trainer = Trainer(options)\n",
    "    trainer.train()\n",
    "\n",
    "    options['test_rollouts'] = 100\n",
    "    options['max_num_actions'] = 100\n",
    "    options['eval_batch_size'] = 8\n",
    "    tester = Trainer(options)\n",
    "    tester.agent.load_state_dict(torch.load(options['model_dir'] + 'agent.ckpt'))\n",
    "    tester.agent.eval()\n",
    "    tester.test_environment = tester.test_test_environment\n",
    "    test_results = tester.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "    results[layer] = deepcopy(test_results)\n",
    "    \n",
    "    with open(options['model_dir'] + '../layer_results.pk5', 'wb') as f:\n",
    "        pickle5.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ca3d2db1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 'Hits@1: 0.6845, Hits@3: 0.7970, Hits@10: 0.8502, MRR: 0.7473',\n",
       " 3: 'Hits@1: 0.6934, Hits@3: 0.8048, Hits@10: 0.8570, MRR: 0.7556',\n",
       " 4: 'Hits@1: 0.6781, Hits@3: 0.7942, Hits@10: 0.8510, MRR: 0.7431'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8a8a731",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
