{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle5\n",
    "\n",
    "# Expose only GPU index 3 to this kernel. Set here, ahead of the later cell\n",
    "# that imports the model code.\n",
    "VISIBLE_GPU = \"3\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = VISIBLE_GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    \"\"\"Build the options dictionary for one training run on the NELL data.\n",
    "\n",
    "    Args:\n",
    "        LSTM_layers: value stored under options['LSTM_layers'].\n",
    "        batch_size: value stored under options['batch_size']; it also sets\n",
    "            options['total_iterations'] to 2000 * (64 / batch_size).\n",
    "        beta: value stored under options['beta'].\n",
    "        Lambda: value stored under options['Lambda'].\n",
    "        learning_rate: value stored under options['learning_rate'].\n",
    "\n",
    "    Returns:\n",
    "        dict with paths, the two loaded vocabularies, agent settings and\n",
    "        hyperparameters.\n",
    "\n",
    "    Raises:\n",
    "        OSError: if a vocabulary file cannot be opened.\n",
    "        json.JSONDecodeError: if a vocabulary file is not valid JSON.\n",
    "        ZeroDivisionError: if batch_size is 0.\n",
    "    \"\"\"\n",
    "    # Imported locally so the function does not rely on `json` reaching the\n",
    "    # notebook namespace through a wildcard import in another cell.\n",
    "    import json\n",
    "\n",
    "    def _load_json(path):\n",
    "        # The context manager closes the handle as soon as parsing is done.\n",
    "        with open(path) as handle:\n",
    "            return json.load(handle)\n",
    "\n",
    "    options = {}\n",
    "\n",
    "    #basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/nell/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/nell/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    options['relation_vocab'] = _load_json(options['vocab_dir'] + '/relation_vocab.json')\n",
    "    options['entity_vocab'] = _load_json(options['vocab_dir'] + '/entity_vocab.json')\n",
    "    # Checkpoints and result files share one directory.\n",
    "    options['model_dir'] = './outputs_NELL-995_v7-tune/'\n",
    "    options['output_dir'] = './outputs_NELL-995_v7-tune/'\n",
    "\n",
    "    #agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 0\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 40\n",
    "\n",
    "    #hyperparameters\n",
    "    options['test_rollouts'] = 40\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 32\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 100\n",
    "    # True division: the result is a float (e.g. 4000.0 for batch_size=32).\n",
    "    # The count is scaled so that batch_size * total_iterations stays at 2000 * 64.\n",
    "    options['total_iterations'] = 2000*(64/batch_size)\n",
    "    options['pool'] = 'max'\n",
    "\n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc64eafe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/Ours/model/ours.py:333: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.2211, rewards: 0.1023\n",
      "Iteration: 20, Train loss: -0.2648, rewards: 0.1505\n",
      "Iteration: 30, Train loss: -0.2404, rewards: 0.2039\n",
      "Iteration: 40, Train loss: -0.2208, rewards: 0.2736\n",
      "Iteration: 50, Train loss: -0.1568, rewards: 0.2787\n",
      "Iteration: 60, Train loss: -0.1901, rewards: 0.3400\n",
      "Iteration: 70, Train loss: -0.1360, rewards: 0.3598\n",
      "Iteration: 80, Train loss: -0.2420, rewards: 0.3277\n",
      "Iteration: 90, Train loss: -0.2126, rewards: 0.3808\n",
      "Iteration: 100, Train loss: -0.2255, rewards: 0.3744\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/Research/GraphRL/Ours/model/ours.py:635: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.4917, Hits@3: 0.5838, Hits@10: 0.6262, MRR: 0.5385\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.1748, rewards: 0.4217\n",
      "Iteration: 120, Train loss: -0.1777, rewards: 0.4888\n",
      "Iteration: 130, Train loss: -0.1935, rewards: 0.4283\n",
      "Iteration: 140, Train loss: -0.1644, rewards: 0.5047\n",
      "Iteration: 150, Train loss: -0.1654, rewards: 0.4461\n",
      "Iteration: 160, Train loss: -0.1630, rewards: 0.4470\n",
      "Iteration: 170, Train loss: -0.2140, rewards: 0.4989\n",
      "Iteration: 180, Train loss: -0.1592, rewards: 0.4764\n",
      "Iteration: 190, Train loss: -0.1058, rewards: 0.4655\n",
      "Iteration: 200, Train loss: -0.1020, rewards: 0.4642\n",
      "Eval:\n",
      "Hits@1: 0.5562, Hits@3: 0.6354, Hits@10: 0.6667, MRR: 0.5961\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.1424, rewards: 0.5214\n",
      "Iteration: 220, Train loss: -0.1115, rewards: 0.4788\n",
      "Iteration: 230, Train loss: -0.1156, rewards: 0.5000\n",
      "Iteration: 240, Train loss: -0.1266, rewards: 0.4581\n",
      "Iteration: 250, Train loss: -0.1240, rewards: 0.5261\n",
      "Iteration: 260, Train loss: -0.0940, rewards: 0.4972\n",
      "Iteration: 270, Train loss: -0.1466, rewards: 0.4966\n",
      "Iteration: 280, Train loss: -0.1670, rewards: 0.4873\n",
      "Iteration: 290, Train loss: -0.1362, rewards: 0.4955\n",
      "Iteration: 300, Train loss: -0.0922, rewards: 0.4902\n",
      "Eval:\n",
      "Hits@1: 0.5617, Hits@3: 0.6372, Hits@10: 0.6630, MRR: 0.6007\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.1291, rewards: 0.5586\n",
      "Iteration: 320, Train loss: -0.0896, rewards: 0.5091\n",
      "Iteration: 330, Train loss: -0.0614, rewards: 0.5530\n",
      "Iteration: 340, Train loss: -0.1232, rewards: 0.5094\n",
      "Iteration: 350, Train loss: -0.1020, rewards: 0.4744\n",
      "Iteration: 360, Train loss: -0.0563, rewards: 0.4798\n",
      "Iteration: 370, Train loss: -0.0992, rewards: 0.5061\n",
      "Iteration: 380, Train loss: -0.1254, rewards: 0.5627\n",
      "Iteration: 390, Train loss: -0.0961, rewards: 0.5175\n",
      "Iteration: 400, Train loss: -0.0946, rewards: 0.5350\n",
      "Eval:\n",
      "Hits@1: 0.5691, Hits@3: 0.6372, Hits@10: 0.6630, MRR: 0.6056\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.0538, rewards: 0.5378\n",
      "Iteration: 420, Train loss: -0.0372, rewards: 0.5248\n",
      "Iteration: 430, Train loss: -0.1284, rewards: 0.5256\n",
      "Iteration: 440, Train loss: -0.1296, rewards: 0.4847\n",
      "Iteration: 450, Train loss: -0.0854, rewards: 0.5173\n",
      "Iteration: 460, Train loss: -0.0451, rewards: 0.5395\n",
      "Iteration: 470, Train loss: -0.0770, rewards: 0.5328\n",
      "Iteration: 480, Train loss: -0.0777, rewards: 0.5162\n",
      "Iteration: 490, Train loss: -0.1441, rewards: 0.5177\n",
      "Iteration: 500, Train loss: -0.1028, rewards: 0.5327\n",
      "Eval:\n",
      "Hits@1: 0.5635, Hits@3: 0.6446, Hits@10: 0.6722, MRR: 0.6053\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.1608, rewards: 0.5409\n",
      "Iteration: 520, Train loss: -0.0844, rewards: 0.5530\n",
      "Iteration: 530, Train loss: -0.0652, rewards: 0.5298\n",
      "Iteration: 540, Train loss: -0.0592, rewards: 0.5706\n",
      "Iteration: 550, Train loss: -0.1441, rewards: 0.5027\n",
      "Iteration: 560, Train loss: -0.1530, rewards: 0.5370\n",
      "Iteration: 570, Train loss: -0.0297, rewards: 0.5288\n",
      "Iteration: 580, Train loss: -0.0148, rewards: 0.5400\n",
      "Iteration: 590, Train loss: -0.0853, rewards: 0.5359\n",
      "Iteration: 600, Train loss: -0.0579, rewards: 0.5155\n",
      "Eval:\n",
      "Hits@1: 0.5801, Hits@3: 0.6483, Hits@10: 0.6648, MRR: 0.6152\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.0563, rewards: 0.5030\n",
      "Iteration: 620, Train loss: -0.0531, rewards: 0.5075\n",
      "Iteration: 630, Train loss: -0.0397, rewards: 0.5311\n",
      "Iteration: 640, Train loss: -0.0578, rewards: 0.5497\n",
      "Iteration: 650, Train loss: -0.0766, rewards: 0.4981\n",
      "Iteration: 660, Train loss: -0.0921, rewards: 0.5633\n",
      "Iteration: 670, Train loss: -0.1227, rewards: 0.5494\n",
      "Iteration: 680, Train loss: -0.1045, rewards: 0.5339\n",
      "Iteration: 690, Train loss: -0.0851, rewards: 0.5142\n",
      "Iteration: 700, Train loss: -0.2199, rewards: 0.5347\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6446, Hits@10: 0.6575, MRR: 0.6145\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.0602, rewards: 0.5273\n",
      "Iteration: 720, Train loss: -0.1229, rewards: 0.5555\n",
      "Iteration: 730, Train loss: -0.0623, rewards: 0.5711\n",
      "Iteration: 740, Train loss: -0.0679, rewards: 0.5872\n",
      "Iteration: 750, Train loss: -0.1087, rewards: 0.5530\n",
      "Iteration: 760, Train loss: -0.1002, rewards: 0.5711\n",
      "Iteration: 770, Train loss: -0.0443, rewards: 0.5095\n",
      "Iteration: 780, Train loss: -0.0590, rewards: 0.5706\n",
      "Iteration: 790, Train loss: -0.0559, rewards: 0.5691\n",
      "Iteration: 800, Train loss: -0.0066, rewards: 0.5306\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6427, Hits@10: 0.6648, MRR: 0.6152\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.0201, rewards: 0.5969\n",
      "Iteration: 820, Train loss: -0.0543, rewards: 0.5948\n",
      "Iteration: 830, Train loss: -0.0697, rewards: 0.5159\n",
      "Iteration: 840, Train loss: -0.1374, rewards: 0.5569\n",
      "Iteration: 850, Train loss: -0.0656, rewards: 0.5216\n",
      "Iteration: 860, Train loss: -0.0647, rewards: 0.5541\n",
      "Iteration: 870, Train loss: -0.0949, rewards: 0.5575\n",
      "Iteration: 880, Train loss: -0.0358, rewards: 0.5112\n",
      "Iteration: 890, Train loss: -0.1050, rewards: 0.5275\n",
      "Iteration: 900, Train loss: -0.0195, rewards: 0.5867\n",
      "Eval:\n",
      "Hits@1: 0.5783, Hits@3: 0.6409, Hits@10: 0.6648, MRR: 0.6113\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.0518, rewards: 0.5478\n",
      "Iteration: 920, Train loss: -0.0238, rewards: 0.5705\n",
      "Iteration: 930, Train loss: -0.0661, rewards: 0.5930\n",
      "Iteration: 940, Train loss: 0.0291, rewards: 0.5384\n",
      "Iteration: 950, Train loss: -0.0761, rewards: 0.5687\n",
      "Iteration: 960, Train loss: -0.0794, rewards: 0.5283\n",
      "Iteration: 970, Train loss: -0.0916, rewards: 0.5469\n",
      "Iteration: 980, Train loss: -0.0654, rewards: 0.5556\n",
      "Iteration: 990, Train loss: -0.0371, rewards: 0.5303\n",
      "Iteration: 1000, Train loss: -0.0139, rewards: 0.5277\n",
      "Eval:\n",
      "Hits@1: 0.5764, Hits@3: 0.6483, Hits@10: 0.6777, MRR: 0.6149\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.1157, rewards: 0.5919\n",
      "Iteration: 1020, Train loss: -0.0335, rewards: 0.5637\n",
      "Iteration: 1030, Train loss: -0.0085, rewards: 0.5698\n",
      "Iteration: 1040, Train loss: -0.0089, rewards: 0.5319\n",
      "Iteration: 1050, Train loss: -0.0102, rewards: 0.5119\n",
      "Iteration: 1060, Train loss: -0.1123, rewards: 0.5464\n",
      "Iteration: 1070, Train loss: -0.0062, rewards: 0.5556\n",
      "Iteration: 1080, Train loss: -0.1054, rewards: 0.6027\n",
      "Iteration: 1090, Train loss: -0.1039, rewards: 0.5373\n",
      "Iteration: 1100, Train loss: -0.0997, rewards: 0.5794\n",
      "Eval:\n",
      "Hits@1: 0.5801, Hits@3: 0.6519, Hits@10: 0.6759, MRR: 0.6176\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.0506, rewards: 0.5711\n",
      "Iteration: 1120, Train loss: -0.0161, rewards: 0.5697\n",
      "Iteration: 1130, Train loss: 0.0279, rewards: 0.5622\n",
      "Iteration: 1140, Train loss: -0.0409, rewards: 0.5586\n",
      "Iteration: 1150, Train loss: 0.0103, rewards: 0.5581\n",
      "Iteration: 1160, Train loss: -0.0575, rewards: 0.5672\n",
      "Iteration: 1170, Train loss: -0.0248, rewards: 0.5806\n",
      "Iteration: 1180, Train loss: 0.0155, rewards: 0.5641\n",
      "Iteration: 1190, Train loss: -0.0063, rewards: 0.5800\n",
      "Iteration: 1200, Train loss: -0.0505, rewards: 0.5834\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6427, Hits@10: 0.6685, MRR: 0.6144\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.0143, rewards: 0.5563\n",
      "Iteration: 1220, Train loss: -0.0535, rewards: 0.5873\n",
      "Iteration: 1230, Train loss: -0.0513, rewards: 0.5787\n",
      "Iteration: 1240, Train loss: -0.0285, rewards: 0.5831\n",
      "Iteration: 1250, Train loss: -0.0102, rewards: 0.5647\n",
      "Iteration: 1260, Train loss: -0.0880, rewards: 0.5983\n",
      "Iteration: 1270, Train loss: 0.0277, rewards: 0.5506\n",
      "Iteration: 1280, Train loss: -0.0938, rewards: 0.5734\n",
      "Iteration: 1290, Train loss: -0.0721, rewards: 0.5663\n",
      "Iteration: 1300, Train loss: 0.0011, rewards: 0.5442\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6483, Hits@10: 0.6685, MRR: 0.6185\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: 0.0324, rewards: 0.5469\n",
      "Iteration: 1320, Train loss: -0.0716, rewards: 0.5578\n",
      "Iteration: 1330, Train loss: -0.0124, rewards: 0.5767\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1340, Train loss: -0.0155, rewards: 0.5558\n",
      "Iteration: 1350, Train loss: -0.0429, rewards: 0.5561\n",
      "Iteration: 1360, Train loss: -0.0500, rewards: 0.5486\n",
      "Iteration: 1370, Train loss: -0.0177, rewards: 0.5398\n",
      "Iteration: 1380, Train loss: -0.0332, rewards: 0.6088\n",
      "Iteration: 1390, Train loss: 0.0053, rewards: 0.5073\n",
      "Iteration: 1400, Train loss: -0.0357, rewards: 0.5805\n",
      "Eval:\n",
      "Hits@1: 0.5820, Hits@3: 0.6483, Hits@10: 0.6703, MRR: 0.6161\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: 0.0158, rewards: 0.5977\n",
      "Iteration: 1420, Train loss: 0.0488, rewards: 0.5527\n",
      "Iteration: 1430, Train loss: -0.0200, rewards: 0.6125\n",
      "Iteration: 1440, Train loss: -0.0072, rewards: 0.5787\n",
      "Iteration: 1450, Train loss: 0.0404, rewards: 0.5486\n",
      "Iteration: 1460, Train loss: -0.0624, rewards: 0.6094\n",
      "Iteration: 1470, Train loss: 0.0489, rewards: 0.5548\n",
      "Iteration: 1480, Train loss: -0.0380, rewards: 0.5948\n",
      "Iteration: 1490, Train loss: 0.0145, rewards: 0.5545\n",
      "Iteration: 1500, Train loss: 0.0488, rewards: 0.5392\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6464, Hits@10: 0.6685, MRR: 0.6197\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.0159, rewards: 0.5586\n",
      "Iteration: 1520, Train loss: 0.0292, rewards: 0.5766\n",
      "Iteration: 1530, Train loss: -0.0071, rewards: 0.6217\n",
      "Iteration: 1540, Train loss: -0.0318, rewards: 0.5531\n",
      "Iteration: 1550, Train loss: 0.0023, rewards: 0.5386\n",
      "Iteration: 1560, Train loss: -0.0679, rewards: 0.5908\n",
      "Iteration: 1570, Train loss: 0.0419, rewards: 0.5642\n",
      "Iteration: 1580, Train loss: 0.0164, rewards: 0.6286\n",
      "Iteration: 1590, Train loss: -0.0466, rewards: 0.5923\n",
      "Iteration: 1600, Train loss: -0.0065, rewards: 0.5498\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6483, Hits@10: 0.6667, MRR: 0.6172\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.0361, rewards: 0.5966\n",
      "Iteration: 1620, Train loss: 0.0287, rewards: 0.5867\n",
      "Iteration: 1630, Train loss: -0.0384, rewards: 0.5528\n",
      "Iteration: 1640, Train loss: 0.0576, rewards: 0.6142\n",
      "Iteration: 1650, Train loss: -0.0174, rewards: 0.6002\n",
      "Iteration: 1660, Train loss: -0.0167, rewards: 0.5487\n",
      "Iteration: 1670, Train loss: 0.0266, rewards: 0.5803\n",
      "Iteration: 1680, Train loss: 0.0443, rewards: 0.5530\n",
      "Iteration: 1690, Train loss: -0.0183, rewards: 0.5625\n",
      "Iteration: 1700, Train loss: -0.0361, rewards: 0.5506\n",
      "Eval:\n",
      "Hits@1: 0.5893, Hits@3: 0.6501, Hits@10: 0.6648, MRR: 0.6205\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.0560, rewards: 0.5364\n",
      "Iteration: 1720, Train loss: -0.0133, rewards: 0.6078\n",
      "Iteration: 1730, Train loss: -0.0766, rewards: 0.5503\n",
      "Iteration: 1740, Train loss: -0.0148, rewards: 0.5739\n",
      "Iteration: 1750, Train loss: -0.0529, rewards: 0.5791\n",
      "Iteration: 1760, Train loss: -0.0789, rewards: 0.5850\n",
      "Iteration: 1770, Train loss: -0.0208, rewards: 0.5653\n",
      "Iteration: 1780, Train loss: -0.0590, rewards: 0.5917\n",
      "Iteration: 1790, Train loss: -0.0856, rewards: 0.6081\n",
      "Iteration: 1800, Train loss: -0.0222, rewards: 0.5952\n",
      "Eval:\n",
      "Hits@1: 0.5801, Hits@3: 0.6446, Hits@10: 0.6667, MRR: 0.6141\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.0203, rewards: 0.5941\n",
      "Iteration: 1820, Train loss: -0.0876, rewards: 0.5661\n",
      "Iteration: 1830, Train loss: -0.0153, rewards: 0.5244\n",
      "Iteration: 1840, Train loss: -0.0603, rewards: 0.5892\n",
      "Iteration: 1850, Train loss: 0.0140, rewards: 0.6138\n",
      "Iteration: 1860, Train loss: -0.0407, rewards: 0.5714\n",
      "Iteration: 1870, Train loss: -0.0583, rewards: 0.5878\n",
      "Iteration: 1880, Train loss: -0.0962, rewards: 0.5803\n",
      "Iteration: 1890, Train loss: -0.0359, rewards: 0.5830\n",
      "Iteration: 1900, Train loss: 0.0241, rewards: 0.5695\n",
      "Eval:\n",
      "Hits@1: 0.5875, Hits@3: 0.6538, Hits@10: 0.6703, MRR: 0.6209\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.0497, rewards: 0.5563\n",
      "Iteration: 1920, Train loss: -0.0909, rewards: 0.5906\n",
      "Iteration: 1930, Train loss: -0.0155, rewards: 0.5344\n",
      "Iteration: 1940, Train loss: -0.0353, rewards: 0.5641\n",
      "Iteration: 1950, Train loss: -0.0364, rewards: 0.5942\n",
      "Iteration: 1960, Train loss: -0.0497, rewards: 0.5833\n",
      "Iteration: 1970, Train loss: -0.0246, rewards: 0.5827\n",
      "Iteration: 1980, Train loss: -0.0233, rewards: 0.5975\n",
      "Iteration: 1990, Train loss: -0.0048, rewards: 0.5647\n",
      "Iteration: 2000, Train loss: -0.0421, rewards: 0.5831\n",
      "Eval:\n",
      "Hits@1: 0.5856, Hits@3: 0.6464, Hits@10: 0.6648, MRR: 0.6192\n",
      "------------------------------------------------------------\n",
      "Iteration: 2010, Train loss: 0.0355, rewards: 0.5677\n",
      "Iteration: 2020, Train loss: -0.0213, rewards: 0.5906\n",
      "Iteration: 2030, Train loss: 0.0168, rewards: 0.5623\n",
      "Iteration: 2040, Train loss: -0.0399, rewards: 0.5778\n",
      "Iteration: 2050, Train loss: -0.0086, rewards: 0.5936\n",
      "Iteration: 2060, Train loss: 0.0020, rewards: 0.6164\n",
      "Iteration: 2070, Train loss: -0.0320, rewards: 0.6144\n",
      "Iteration: 2080, Train loss: -0.0147, rewards: 0.4909\n",
      "Iteration: 2090, Train loss: -0.0584, rewards: 0.5672\n",
      "Iteration: 2100, Train loss: -0.0410, rewards: 0.5647\n",
      "Eval:\n",
      "Hits@1: 0.5709, Hits@3: 0.6390, Hits@10: 0.6593, MRR: 0.6088\n",
      "------------------------------------------------------------\n",
      "Iteration: 2110, Train loss: -0.0550, rewards: 0.6214\n",
      "Iteration: 2120, Train loss: 0.0115, rewards: 0.5297\n",
      "Iteration: 2130, Train loss: 0.0526, rewards: 0.5333\n",
      "Iteration: 2140, Train loss: -0.0033, rewards: 0.5697\n",
      "Iteration: 2150, Train loss: -0.0038, rewards: 0.6161\n",
      "Iteration: 2160, Train loss: -0.0661, rewards: 0.6027\n",
      "Iteration: 2170, Train loss: -0.0032, rewards: 0.5859\n",
      "Iteration: 2180, Train loss: -0.0225, rewards: 0.5556\n",
      "Iteration: 2190, Train loss: -0.0621, rewards: 0.5573\n",
      "Iteration: 2200, Train loss: -0.0706, rewards: 0.6009\n",
      "Eval:\n",
      "Hits@1: 0.5838, Hits@3: 0.6483, Hits@10: 0.6575, MRR: 0.6159\n",
      "------------------------------------------------------------\n",
      "Iteration: 2210, Train loss: -0.0113, rewards: 0.5523\n",
      "Iteration: 2220, Train loss: -0.0173, rewards: 0.6177\n",
      "Iteration: 2230, Train loss: -0.0051, rewards: 0.5534\n",
      "Iteration: 2240, Train loss: -0.0490, rewards: 0.5705\n",
      "Iteration: 2250, Train loss: -0.0373, rewards: 0.5875\n",
      "Iteration: 2260, Train loss: -0.0916, rewards: 0.6106\n",
      "Iteration: 2270, Train loss: -0.0059, rewards: 0.5420\n",
      "Iteration: 2280, Train loss: 0.0086, rewards: 0.5850\n",
      "Iteration: 2290, Train loss: 0.0892, rewards: 0.6070\n",
      "Iteration: 2300, Train loss: -0.0308, rewards: 0.5881\n",
      "Eval:\n",
      "Hits@1: 0.5783, Hits@3: 0.6390, Hits@10: 0.6611, MRR: 0.6120\n",
      "------------------------------------------------------------\n",
      "Iteration: 2310, Train loss: -0.0238, rewards: 0.6014\n",
      "Iteration: 2320, Train loss: -0.0711, rewards: 0.5709\n",
      "Iteration: 2330, Train loss: -0.0504, rewards: 0.6241\n",
      "Iteration: 2340, Train loss: -0.0284, rewards: 0.5719\n",
      "Iteration: 2350, Train loss: -0.0592, rewards: 0.6080\n",
      "Iteration: 2360, Train loss: -0.0059, rewards: 0.5297\n",
      "Iteration: 2370, Train loss: -0.0246, rewards: 0.6219\n",
      "Iteration: 2380, Train loss: 0.0012, rewards: 0.6012\n",
      "Iteration: 2390, Train loss: -0.0239, rewards: 0.6466\n",
      "Iteration: 2400, Train loss: -0.0117, rewards: 0.5708\n",
      "Eval:\n",
      "Hits@1: 0.5783, Hits@3: 0.6427, Hits@10: 0.6667, MRR: 0.6140\n",
      "------------------------------------------------------------\n",
      "Iteration: 2410, Train loss: -0.0266, rewards: 0.5733\n",
      "Iteration: 2420, Train loss: -0.0214, rewards: 0.6219\n",
      "Iteration: 2430, Train loss: -0.0175, rewards: 0.5916\n",
      "Iteration: 2440, Train loss: -0.0119, rewards: 0.6031\n",
      "Iteration: 2450, Train loss: -0.0662, rewards: 0.5773\n",
      "Iteration: 2460, Train loss: -0.0681, rewards: 0.5988\n",
      "Iteration: 2470, Train loss: -0.0693, rewards: 0.5917\n",
      "Iteration: 2480, Train loss: 0.0174, rewards: 0.5616\n",
      "Iteration: 2490, Train loss: -0.0138, rewards: 0.6134\n",
      "Iteration: 2500, Train loss: -0.0407, rewards: 0.6117\n",
      "Eval:\n",
      "Hits@1: 0.5801, Hits@3: 0.6464, Hits@10: 0.6648, MRR: 0.6141\n",
      "------------------------------------------------------------\n",
      "Iteration: 2510, Train loss: 0.0466, rewards: 0.6067\n",
      "Iteration: 2520, Train loss: -0.0051, rewards: 0.5938\n",
      "Iteration: 2530, Train loss: 0.0149, rewards: 0.5981\n",
      "Iteration: 2540, Train loss: 0.0110, rewards: 0.6356\n",
      "Iteration: 2550, Train loss: 0.0253, rewards: 0.5339\n",
      "Iteration: 2560, Train loss: 0.0291, rewards: 0.5509\n",
      "Iteration: 2570, Train loss: -0.0102, rewards: 0.5927\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 2580, Train loss: -0.0121, rewards: 0.5900\n",
      "Iteration: 2590, Train loss: -0.0438, rewards: 0.6288\n",
      "Iteration: 2600, Train loss: -0.0285, rewards: 0.5966\n",
      "Eval:\n",
      "Hits@1: 0.5580, Hits@3: 0.6317, Hits@10: 0.6556, MRR: 0.5969\n",
      "------------------------------------------------------------\n",
      "Iteration: 2610, Train loss: -0.0691, rewards: 0.6031\n",
      "Iteration: 2620, Train loss: -0.0489, rewards: 0.6000\n",
      "Iteration: 2630, Train loss: 0.0489, rewards: 0.6431\n",
      "Iteration: 2640, Train loss: 0.0341, rewards: 0.6106\n",
      "Iteration: 2650, Train loss: -0.0577, rewards: 0.6042\n",
      "Iteration: 2660, Train loss: 0.0786, rewards: 0.6033\n",
      "Iteration: 2670, Train loss: -0.0061, rewards: 0.6267\n",
      "Iteration: 2680, Train loss: -0.0188, rewards: 0.6222\n",
      "Iteration: 2690, Train loss: -0.0360, rewards: 0.6039\n",
      "Iteration: 2700, Train loss: 0.0488, rewards: 0.5961\n",
      "Eval:\n",
      "Hits@1: 0.5635, Hits@3: 0.6335, Hits@10: 0.6648, MRR: 0.6018\n",
      "------------------------------------------------------------\n",
      "Iteration: 2710, Train loss: 0.0023, rewards: 0.5881\n",
      "Iteration: 2720, Train loss: -0.0588, rewards: 0.6078\n",
      "Iteration: 2730, Train loss: 0.0230, rewards: 0.6072\n",
      "Iteration: 2740, Train loss: -0.0132, rewards: 0.5561\n",
      "Iteration: 2750, Train loss: -0.0226, rewards: 0.5694\n",
      "Iteration: 2760, Train loss: -0.0229, rewards: 0.6494\n",
      "Iteration: 2770, Train loss: -0.0076, rewards: 0.6156\n",
      "Iteration: 2780, Train loss: 0.0444, rewards: 0.5989\n",
      "Iteration: 2790, Train loss: -0.0117, rewards: 0.6439\n",
      "Iteration: 2800, Train loss: 0.0290, rewards: 0.6080\n",
      "Eval:\n"
     ]
    }
   ],
   "source": [
    "# `gc`, `itertools` and `torch` are imported explicitly so this cell does not\n",
    "# depend on them leaking in through the wildcard import. The wildcard import\n",
    "# stays last, so any name it provides still takes precedence as before.\n",
    "import gc\n",
    "import itertools\n",
    "\n",
    "import torch\n",
    "\n",
    "from model.ours import *\n",
    "\n",
    "# Search grid. `bl` is passed to set_params as both `beta` and `Lambda`.\n",
    "layer_grid = [1, 2]\n",
    "beta_lambda_grid = [0.1, 0.08, 0.12, 0.05, 0.15, 0.02, 0.18]\n",
    "batch_size_grid = [32, 64, 128]\n",
    "learning_rate_grid = [1e-4, 5e-5, 5e-4, 1e-3, 1e-5]\n",
    "\n",
    "results = {}\n",
    "# itertools.product advances the last iterable fastest, which is the same\n",
    "# visiting order as the equivalent nested for-loops.\n",
    "sweep = itertools.product(layer_grid, beta_lambda_grid, batch_size_grid, learning_rate_grid)\n",
    "for layer, bl, bs, lr in sweep:\n",
    "    params = set_params(layer, bs, bl, bl, lr)\n",
    "    name = f'{layer}-{bs}-{bl}-{bl}-{lr}'\n",
    "\n",
    "    trainer = Trainer(params)\n",
    "    trainer.train()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # Reload agent.ckpt from model_dir (presumably written by trainer.train();\n",
    "    # confirm in model.ours) and evaluate it on test_test_environment.\n",
    "    trainer.agent.load_state_dict(torch.load(params['model_dir'] + 'agent.ckpt'))\n",
    "    trainer.agent.eval()\n",
    "    trainer.test_environment = trainer.test_test_environment\n",
    "    metrics = trainer.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "    print(name)\n",
    "    print(metrics)\n",
    "    print('-------------')\n",
    "    results[name] = metrics\n",
    "\n",
    "    # Rewrite the full table after every configuration so finished runs\n",
    "    # survive an interrupted sweep.\n",
    "    with open(params['output_dir'] + 'results_table.pk5', 'wb') as f:\n",
    "        pickle5.dump(results, f)\n",
    "\n",
    "    # Drop the trainer and release cached GPU memory before the next run.\n",
    "    del trainer\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
