{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
    "\n",
    "from model.trainer import Trainer\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/cwong/miniconda3/envs/tracy/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from tqdm.notebook import tqdm\n",
    "import json\n",
    "import time\n",
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "from model.agent import Agent\n",
    "from model.options import read_options\n",
    "from model.environment import env\n",
    "import codecs\n",
    "from collections import defaultdict\n",
    "import gc\n",
    "import resource\n",
    "import sys\n",
    "from model.baseline import ReactiveBaseline\n",
    "from scipy.special import logsumexp as lse\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam, SGD, AdamW\n",
    "from copy import deepcopy\n",
    "from model.nell_eval import nell_eval\n",
    "\n",
    "def get_logger(output_dir):\n",
    "    \"\"\"Return this module's logger, echoing to the console and to ``<output_dir>train.log``.\n",
    "\n",
    "    Args:\n",
    "        output_dir: Prefix of the log file path. It is concatenated verbatim\n",
    "            with ``'train'``, so pass a directory ending in a path separator.\n",
    "\n",
    "    Returns:\n",
    "        The ``logging.Logger`` named ``__name__``, set to ``INFO``.\n",
    "\n",
    "    Handlers are attached at most once per target: calling this again (for\n",
    "    example by re-running the cell) no longer duplicates every log line.\n",
    "    \"\"\"\n",
    "    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter\n",
    "    from os.path import abspath\n",
    "\n",
    "    filename = output_dir + 'train'\n",
    "    log_path = abspath(f\"{filename}.log\")\n",
    "    logger = getLogger(__name__)\n",
    "    logger.setLevel(INFO)\n",
    "\n",
    "    # FileHandler subclasses StreamHandler, so compare the exact type here.\n",
    "    if not any(type(h) is StreamHandler for h in logger.handlers):\n",
    "        handler1 = StreamHandler()\n",
    "        handler1.setFormatter(Formatter(\"%(message)s\"))\n",
    "        logger.addHandler(handler1)\n",
    "\n",
    "    # FileHandler keeps its target as an absolute path in ``baseFilename``.\n",
    "    if not any(isinstance(h, FileHandler) and h.baseFilename == log_path\n",
    "               for h in logger.handlers):\n",
    "        handler2 = FileHandler(filename=f\"{filename}.log\")\n",
    "        handler2.setFormatter(Formatter(\"%(message)s\"))\n",
    "        logger.addHandler(handler2)\n",
    "    return logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fc64eafe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.utils as utils\n",
    "from torch_scatter import scatter\n",
    "\n",
    "class node_aggregation(nn.Module):\n",
    "    \"\"\"One attention-weighted message-passing layer that updates node embeddings.\n",
    "\n",
    "    For every edge (head, link, tail) a value vector built from the tail-node\n",
    "    and link embeddings is weighted by a softmax taken over all edges that\n",
    "    share the same head, then summed into that head node. A residual\n",
    "    connection, LayerNorm and a residual MLP follow.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m, embedding_size, hidden_size, layer = 2):\n",
    "        \"\"\"Build the layer.\n",
    "\n",
    "        Args:\n",
    "            m: Stored on the instance; not used by the layers built here.\n",
    "            embedding_size: Every tensor handled here is ``2 * embedding_size`` wide.\n",
    "            hidden_size: Stored on the instance; not used by the layers built here.\n",
    "            layer: Stored on the instance; not used by the layers built here.\n",
    "        \"\"\"\n",
    "        super(node_aggregation, self).__init__()\n",
    "\n",
    "        self.m = m\n",
    "        self.embedding_size = embedding_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.layer = layer\n",
    "\n",
    "        self.node_encoder = nn.Linear(2*embedding_size, 2*embedding_size)\n",
    "        self.link_encoder = nn.Linear(2*embedding_size, 2*embedding_size)\n",
    "        self.value_encoder = nn.Linear(2*2*embedding_size, 2*embedding_size)\n",
    "\n",
    "        self.mlp = nn.Sequential(nn.Linear(2*embedding_size, 2*2*embedding_size), nn.Mish(),\n",
    "                                 nn.Linear(2*2*embedding_size, 2*embedding_size))\n",
    "        self.mlp_norm = nn.LayerNorm(2*embedding_size)\n",
    "        self.norm = nn.LayerNorm(2*embedding_size)\n",
    "\n",
    "    def forward(self, H, L, T, X, Z):\n",
    "        \"\"\"Return updated node embeddings.\n",
    "\n",
    "        Args:\n",
    "            H: Index tensor with the head node of every edge (expected 1-D).\n",
    "            L: Index tensor with the link id of every edge (expected 1-D).\n",
    "            T: Index tensor with the tail node of every edge (expected 1-D).\n",
    "            X: Node embedding matrix, ``2 * embedding_size`` wide.\n",
    "            Z: Link embedding matrix, ``2 * embedding_size`` wide.\n",
    "\n",
    "        Returns:\n",
    "            A new tensor shaped like ``X``. ``X`` itself is left unmodified.\n",
    "        \"\"\"\n",
    "        heads = H.long()\n",
    "        num_nodes = X.shape[0]\n",
    "\n",
    "        h_x = self.node_encoder(X[H])\n",
    "        l_x = self.link_encoder(Z[L])\n",
    "        v = self.value_encoder(torch.cat([X[T], Z[L]], -1))\n",
    "\n",
    "        # Scaled dot product: scale by the feature width. The previous\n",
    "        # ``L.shape[-1]`` is the number of edges, which flattens the scores.\n",
    "        score = (l_x*h_x).sum(-1)/np.sqrt(l_x.shape[-1])\n",
    "\n",
    "        # Softmax over the edges of each head node. Subtracting the per-head\n",
    "        # maximum leaves the result unchanged but keeps exp() from overflowing.\n",
    "        score_max = scatter(score.detach(), heads, dim = 0, dim_size = num_nodes, reduce = 'max')[heads]\n",
    "        exp_score = torch.exp(score - score_max)\n",
    "        norm = scatter(exp_score, heads, dim = 0, dim_size = num_nodes, reduce = 'sum')[heads]\n",
    "        att = exp_score/norm\n",
    "\n",
    "        # dim_size pads nodes without outgoing edges with zeros, so the\n",
    "        # residual can be added out of place instead of writing into ``X``.\n",
    "        x_ = scatter(v*att.unsqueeze(-1), heads, dim = 0, dim_size = num_nodes, reduce = 'sum')\n",
    "\n",
    "        x = self.norm(X + x_)\n",
    "        x = self.mlp_norm(x + self.mlp(x))\n",
    "        return x\n",
    "    \n",
    "class relation_aggregation(nn.Module):\n",
    "    \"\"\"Update relation embeddings from the head nodes of the edges carrying each relation.\n",
    "\n",
    "    Encoded head-node embeddings are averaged per relation id, added to the\n",
    "    relation embedding as a residual, then passed through LayerNorm and a\n",
    "    residual MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m, embedding_size, hidden_size, layer = 2):\n",
    "        \"\"\"Build the layer.\n",
    "\n",
    "        Args:\n",
    "            m: Stored on the instance; not used by the layers built here.\n",
    "            embedding_size: Every tensor handled here is ``2 * embedding_size`` wide.\n",
    "            hidden_size: Stored on the instance; not used by the layers built here.\n",
    "            layer: Stored on the instance; not used by the layers built here.\n",
    "        \"\"\"\n",
    "        super(relation_aggregation, self).__init__()\n",
    "\n",
    "        self.m = m\n",
    "        self.embedding_size = embedding_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.layer = layer\n",
    "\n",
    "        self.node_encoder = nn.Sequential(nn.Linear(2*embedding_size, 2*embedding_size), nn.Mish(),\n",
    "                                          nn.Linear(2*embedding_size, 2*embedding_size))\n",
    "        self.mlp = nn.Sequential(nn.Linear(2*embedding_size, 2*2*embedding_size), nn.Mish(),\n",
    "                                 nn.Linear(2*2*embedding_size, 2*embedding_size))\n",
    "        self.mlp_norm = nn.LayerNorm(2*embedding_size)\n",
    "        self.norm = nn.LayerNorm(2*embedding_size)\n",
    "\n",
    "    def forward(self, H, L, T, X, Z):\n",
    "        \"\"\"Return updated relation embeddings.\n",
    "\n",
    "        Args:\n",
    "            H: Index tensor with the head node of every edge.\n",
    "            L: Index tensor with the relation id of every edge.\n",
    "            T: Not used by this layer.\n",
    "            X: Node embedding matrix, indexed by ``H``.\n",
    "            Z: Relation embedding matrix; every id in ``L`` must be a valid row.\n",
    "\n",
    "        Returns:\n",
    "            A new tensor shaped like ``Z``. ``Z`` itself is left unmodified.\n",
    "        \"\"\"\n",
    "        h_x = self.node_encoder(X[H])\n",
    "\n",
    "        # dim_size pads relations that label no edge with zeros.\n",
    "        z_ = scatter(h_x, L.long(), dim = 0, dim_size = Z.shape[0], reduce = 'mean')\n",
    "\n",
    "        # Out-of-place residual: ``z = Z; z[:len(z_)] += z_`` aliased the\n",
    "        # argument and therefore wrote into the caller's tensor.\n",
    "        z = self.norm(Z + z_)\n",
    "        z = self.mlp_norm(z + self.mlp(z))\n",
    "        return z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "64ac7797",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.utils as utils\n",
    "\n",
    "class Policy_step(nn.Module):\n",
    "    \"\"\"Recurrent core of the policy.\n",
    "\n",
    "    The previous action embedding goes through a linear+ReLU layer into an\n",
    "    LSTM cell; both tensors returned by the cell are then passed through\n",
    "    their own linear+ReLU layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m, embedding_size, hidden_size):\n",
    "        \"\"\"Create the layers for inputs ``m * embedding_size`` wide and states ``m * hidden_size`` wide.\"\"\"\n",
    "        super().__init__()\n",
    "        action_dim = m * embedding_size\n",
    "        state_dim = m * hidden_size\n",
    "\n",
    "        # Construction order is kept so seeded initialisation is unchanged.\n",
    "        self.batch_norm = nn.BatchNorm1d(state_dim)  # not used in forward()\n",
    "        self.lstm_cell = nn.LSTMCell(input_size=action_dim, hidden_size=state_dim)\n",
    "        self.l1 = nn.Linear(action_dim, action_dim)\n",
    "        self.l2 = nn.Linear(state_dim, state_dim)\n",
    "        self.l3 = nn.Linear(state_dim, state_dim)\n",
    "\n",
    "    def forward(self, prev_gnn, prev_action, prev_state):\n",
    "        \"\"\"Advance the recurrent state by one step.\n",
    "\n",
    "        Args:\n",
    "            prev_gnn: Accepted but not used.\n",
    "            prev_action: Embedding of the previous action, ``(batch, m * embedding_size)``.\n",
    "            prev_state: Pair of LSTM state tensors handed to the LSTM cell.\n",
    "\n",
    "        Returns:\n",
    "            ``(output, state)`` where ``output`` is ``(batch, m * hidden_size)``\n",
    "            and ``state`` stacks the two transformed LSTM tensors as\n",
    "            ``(1, 2, batch, m * hidden_size)``.\n",
    "        \"\"\"\n",
    "        lstm_input = F.relu(self.l1(prev_action))\n",
    "        hidden, cell = self.lstm_cell(lstm_input, prev_state)\n",
    "\n",
    "        output = F.relu(self.l2(hidden))\n",
    "        cell = F.relu(self.l3(cell))\n",
    "\n",
    "        # (2, batch, width) -> (1, 2, batch, width)\n",
    "        packed_state = torch.stack([output, cell], dim=0).unsqueeze(0)\n",
    "        return output, packed_state\n",
    "\n",
    "class Policy_mlp(nn.Module):\n",
    "    \"\"\"Two-layer ReLU MLP applied to the concatenated state/query vector.\"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size, m, embedding_size):\n",
    "        \"\"\"Create the two linear layers.\n",
    "\n",
    "        The input is ``m * hidden_size + 4 * embedding_size`` wide and the\n",
    "        output is ``2 * m * embedding_size`` wide.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.hidden_size = hidden_size\n",
    "        self.m = m\n",
    "        self.embedding_size = embedding_size\n",
    "\n",
    "        state_dim = m * hidden_size\n",
    "        in_features = state_dim + 2 * 2 * embedding_size\n",
    "        out_features = m * embedding_size * 2\n",
    "        self.mlp_l1 = nn.Linear(in_features, state_dim, bias=True)\n",
    "        self.mlp_l2 = nn.Linear(state_dim, out_features, bias=True)\n",
    "\n",
    "    def forward(self, state_query):\n",
    "        \"\"\"Return ``relu(l2(relu(l1(state_query))))``.\"\"\"\n",
    "        return F.relu(self.mlp_l2(F.relu(self.mlp_l1(state_query))))\n",
    "\n",
    "class Agent(nn.Module):\n",
    "    \"\"\"Policy network that picks the next (relation, entity) action at each hop.\n",
    "\n",
    "    Every call to ``step`` recomputes entity and relation embeddings with one\n",
    "    round of graph aggregation, advances the recurrent policy state and\n",
    "    samples an action from the masked scores of the candidate actions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params):\n",
    "        \"\"\"Create embeddings and sub-networks.\n",
    "\n",
    "        Args:\n",
    "            params: Configuration dict. Keys read here: ``relation_vocab`` and\n",
    "                ``entity_vocab`` (both need a ``'PAD'`` entry, the former also\n",
    "                ``'DUMMY_START_RELATION'``), ``embedding_size``,\n",
    "                ``hidden_size``, ``use_entity_embeddings``,\n",
    "                ``train_entity_embeddings``, ``train_relation_embeddings``,\n",
    "                ``device``, ``num_rollouts``, ``test_rollouts``,\n",
    "                ``LSTM_layers`` and ``batch_size``.\n",
    "        \"\"\"\n",
    "        super(Agent, self).__init__()\n",
    "        self.action_vocab_size = len(params['relation_vocab'])\n",
    "        self.entity_vocab_size = len(params['entity_vocab'])\n",
    "        self.embedding_size = params['embedding_size']\n",
    "        self.hidden_size = params['hidden_size']\n",
    "        self.ePAD = params['entity_vocab']['PAD']\n",
    "        self.rPAD = params['relation_vocab']['PAD']\n",
    "        self.use_entity_embeddings = params['use_entity_embeddings']\n",
    "        self.train_entity_embeddings = params['train_entity_embeddings']\n",
    "        self.train_relation_embeddings = params['train_relation_embeddings']\n",
    "        self.device = params['device']\n",
    "\n",
    "        # Entity table: Xavier-initialised when entity embeddings are used,\n",
    "        # all zeros otherwise; frozen unless training them was requested.\n",
    "        self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size)\n",
    "        if not self.train_entity_embeddings:\n",
    "            self.entity_embedding.requires_grad_(False)\n",
    "        if self.use_entity_embeddings:\n",
    "            torch.nn.init.xavier_uniform_(self.entity_embedding.weight)\n",
    "        else:\n",
    "            torch.nn.init.constant_(self.entity_embedding.weight, 0.0)\n",
    "\n",
    "        # Relation table: always Xavier-initialised, optionally frozen.\n",
    "        self.relation_embedding = nn.Embedding(self.action_vocab_size, 2 * self.embedding_size)\n",
    "        if not self.train_relation_embeddings:\n",
    "            self.relation_embedding.requires_grad_(False)\n",
    "        torch.nn.init.xavier_uniform_(self.relation_embedding.weight)\n",
    "\n",
    "        self.num_rollouts = params['num_rollouts']\n",
    "        self.test_rollouts = params['test_rollouts']\n",
    "        self.LSTM_Layers = params['LSTM_layers']\n",
    "        self.batch_size = params['batch_size'] * params['num_rollouts']\n",
    "        self.dummy_start_label = (torch.ones(self.batch_size) * params['relation_vocab']['DUMMY_START_RELATION']).long()\n",
    "        self.entity_embedding_size = self.embedding_size\n",
    "        self.initial_gnn_state = nn.Parameter(torch.zeros((1, self.embedding_size*2*2)))\n",
    "        torch.nn.init.xavier_uniform_(self.initial_gnn_state)\n",
    "\n",
    "        # An action embedding is relation + entity (4 * embedding_size) when\n",
    "        # entity embeddings are used, relation only (2 * embedding_size) otherwise.\n",
    "        if self.use_entity_embeddings:\n",
    "            self.m = 4\n",
    "        else:\n",
    "            self.m = 2\n",
    "\n",
    "        self.policy_step = Policy_step(m=self.m, embedding_size=self.embedding_size, hidden_size=self.hidden_size).to(self.device)\n",
    "        self.policy_mlp = Policy_mlp(self.hidden_size, self.m, self.embedding_size).to(self.device)\n",
    "\n",
    "        self.node_conv = nn.ModuleList()\n",
    "        self.rel_conv = nn.ModuleList()\n",
    "        for i in range(1):\n",
    "            self.node_conv.append(node_aggregation(m=self.m,\n",
    "                                                   embedding_size=self.embedding_size,\n",
    "                                                   hidden_size=self.hidden_size).to(self.device))\n",
    "            self.rel_conv.append(relation_aggregation(m=self.m,\n",
    "                                                      embedding_size=self.embedding_size,\n",
    "                                                      hidden_size=self.hidden_size).to(self.device))\n",
    "\n",
    "        # Not referenced by any method of this class; kept so existing\n",
    "        # state dicts and seeded initialisation stay compatible.\n",
    "        self.gate1_linear = nn.Linear(2*self.hidden_size, 3*2*self.hidden_size)\n",
    "        self.gate2_linear = nn.Linear(2*self.hidden_size, 3*2*self.hidden_size)\n",
    "\n",
    "        self.state_encoder = nn.Sequential(nn.Linear(self.m*self.hidden_size*2 + self.embedding_size*2,\n",
    "                                           self.embedding_size*2*2), nn.Mish(),\n",
    "                                           nn.LayerNorm(self.embedding_size*2*2))\n",
    "\n",
    "\n",
    "    def get_mem_shape(self):\n",
    "        \"\"\"Return ``(LSTM_Layers, 2, None, m * hidden_size)``, the layout of the recurrent state.\"\"\"\n",
    "        return (self.LSTM_Layers, 2, None, self.m * self.hidden_size)\n",
    "\n",
    "\n",
    "    def action_encoder(self, next_relations, next_entities):\n",
    "        \"\"\"Embed actions given as relation ids and entity ids.\n",
    "\n",
    "        Reads ``self.rel_embed`` and ``self.ent_embed``, which ``step`` sets\n",
    "        before calling this method. The entity part is concatenated only when\n",
    "        ``use_entity_embeddings`` is true.\n",
    "        \"\"\"\n",
    "        relation_embedding = self.rel_embed[next_relations]\n",
    "        entity_embedding = self.ent_embed[next_entities]\n",
    "\n",
    "        if self.use_entity_embeddings:\n",
    "            action_embedding = torch.cat([relation_embedding, entity_embedding], dim=-1)\n",
    "        else:\n",
    "            action_embedding = relation_embedding\n",
    "\n",
    "        return action_embedding\n",
    "\n",
    "    def neighbour_aggregation(self, query_relation, prev_state, next_neighbors):\n",
    "        \"\"\"Summarise the neighbourhood of every candidate entity with attention.\n",
    "\n",
    "        Args:\n",
    "            query_relation: Query relation id per sample.\n",
    "            prev_state: Pair of recurrent state tensors, concatenated to form\n",
    "                the attention query together with the query relation.\n",
    "            next_neighbors: Integer tensor whose last axis holds\n",
    "                ``(tail entity, link)`` and whose axis 2 enumerates the\n",
    "                neighbours of each candidate.\n",
    "\n",
    "        Returns:\n",
    "            The attention-weighted sum of ``[tail, link]`` embeddings over\n",
    "            axis 2. Neighbours whose link id is 0 are masked out.\n",
    "        \"\"\"\n",
    "        link = next_neighbors[:, :, :, 1]\n",
    "        tail = next_neighbors[:, :, :, 0]\n",
    "        # NOTE(review): 0 is presumably the PAD link id (self.rPAD) - confirm.\n",
    "        mask = (link != 0).float()\n",
    "\n",
    "        t_embed = self.entity_embedding(tail)\n",
    "        r_embed = self.relation_embedding(link)\n",
    "\n",
    "        lstm_state = torch.cat([prev_state[0], prev_state[1]], -1)\n",
    "        query_embedding = self.relation_embedding(query_relation.long())\n",
    "        state = self.state_encoder(torch.cat([lstm_state, query_embedding], -1))\n",
    "        neighbor_embedding = torch.cat([t_embed, r_embed], -1)\n",
    "\n",
    "        att = (state.unsqueeze(1).unsqueeze(1)*neighbor_embedding).sum(-1)/np.sqrt(state.shape[-1])\n",
    "        # A large negative offset removes masked neighbours from the softmax.\n",
    "        att = F.softmax(att - (1 - mask)*1e8, 2)\n",
    "        update_embedding = (neighbor_embedding*att.unsqueeze(-1)).sum(2)\n",
    "        return update_embedding\n",
    "\n",
    "    def step(self, next_relations, next_entities, prev_state, prev_relation, query_embedding, current_entities,\n",
    "             head, link, tail, next_neighbors, prev_gnn_state):\n",
    "        \"\"\"Score the candidate actions, sample one and advance the policy state.\n",
    "\n",
    "        Args:\n",
    "            next_relations: Candidate relation ids, ``(batch, max_num_actions)``.\n",
    "            next_entities: Candidate entity ids, same shape as ``next_relations``.\n",
    "            prev_state: Recurrent state as returned by the previous call;\n",
    "                axis 1 holds its two tensors.\n",
    "            prev_relation: Relation id taken at the previous hop, per sample.\n",
    "            query_embedding: Query relation *ids* per sample (embedded here).\n",
    "            current_entities: Entity id each sample currently sits on.\n",
    "            head, link, tail: Edge list handed to the aggregation layers.\n",
    "            next_neighbors: Neighbour table of the candidate entities, see\n",
    "                ``neighbour_aggregation``.\n",
    "            prev_gnn_state: Forwarded to ``policy_step``.\n",
    "\n",
    "        Returns:\n",
    "            ``(loss, new_state, log_probs, label_action, chosen_relation)``:\n",
    "            the per-sample negative log-probability of the sampled action, the\n",
    "            new recurrent state, the log-softmax of the masked scores, the\n",
    "            sampled action index and the relation id that index maps to.\n",
    "        \"\"\"\n",
    "        n_ent = len(self.entity_embedding.weight)\n",
    "        n_rel = len(self.relation_embedding.weight)\n",
    "\n",
    "        self.ent_embed = self.entity_embedding(torch.arange(n_ent).to(head.device))\n",
    "        self.rel_embed = self.relation_embedding(torch.arange(n_rel).to(link.device))\n",
    "\n",
    "        # Nodes are refreshed first; relations then see the refreshed nodes.\n",
    "        for node_layer, rel_layer in zip(self.node_conv, self.rel_conv):\n",
    "            self.ent_embed = node_layer(head, link, tail, self.ent_embed, self.rel_embed)\n",
    "            self.rel_embed = rel_layer(head, link, tail, self.ent_embed, self.rel_embed)\n",
    "\n",
    "        prev_action_embedding = self.action_encoder(prev_relation, current_entities) # (original batch_size * num_rollout, 4*self.embedding_size)\n",
    "\n",
    "        prev_state = torch.unbind(prev_state, dim=1)\n",
    "        prev_state = [prev_state[0].squeeze(0), prev_state[1].squeeze(0)]\n",
    "\n",
    "        output, new_state = self.policy_step(prev_gnn_state, prev_action_embedding, prev_state)\n",
    "\n",
    "        prev_entity = self.ent_embed[current_entities]\n",
    "        if self.use_entity_embeddings:\n",
    "            state = torch.cat([output, prev_entity], dim=-1)\n",
    "        else:\n",
    "            state = output\n",
    "\n",
    "        candidate_action_embeddings = self.action_encoder(next_relations, next_entities)\n",
    "        gnn_embedding = self.neighbour_aggregation(query_embedding, prev_state, next_neighbors)\n",
    "        candidate_action_embeddings = torch.cat([candidate_action_embeddings, gnn_embedding], -1)\n",
    "\n",
    "        query_embedding = self.rel_embed[query_embedding]\n",
    "        state_query_concat = torch.cat([state, query_embedding], dim=-1)\n",
    "\n",
    "        # MLP for policy, then dot product with every candidate action.\n",
    "        output = self.policy_mlp(state_query_concat)\n",
    "        output_expanded = torch.unsqueeze(output, dim=1)  # [original batch_size * num_rollout, 1, 2D], D=self.hidden_size\n",
    "        prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)\n",
    "\n",
    "        # Masking PAD actions\n",
    "        pad_mask = next_relations == self.rPAD\n",
    "        scores = prelim_scores.masked_fill(pad_mask, -99999.0)  # [original batch_size * num_rollout, max_num_actions]\n",
    "\n",
    "        # Sample an action index per sample.\n",
    "        action = torch.distributions.categorical.Categorical(logits=scores)\n",
    "        label_action = action.sample() # [original batch_size * num_rollout,]\n",
    "\n",
    "        # Explicit ``dim`` and ``reduction`` replace the deprecated implicit-dim\n",
    "        # log_softmax and ``reduce=False``; the values are the same as before.\n",
    "        log_probs = F.log_softmax(scores, dim=-1)\n",
    "        loss = F.nll_loss(log_probs, label_action, reduction='none')\n",
    "\n",
    "        # Map the sampled index back to its relation id; gather keeps the\n",
    "        # lookup on the tensors' own device.\n",
    "        chosen_relation = next_relations.gather(1, label_action.unsqueeze(1)).squeeze(1)\n",
    "\n",
    "        return loss, new_state, log_probs, label_action, chosen_relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a5a4c2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import time\n",
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "\n",
    "#from model.agent import Agent\n",
    "from model.options import read_options\n",
    "from model.environment import env\n",
    "import codecs\n",
    "from collections import defaultdict\n",
    "import gc\n",
    "import resource\n",
    "import sys\n",
    "from model.baseline import ReactiveBaseline\n",
    "from scipy.special import logsumexp as lse\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from model.nell_eval import nell_eval\n",
    "\n",
    "# Route records of level DEBUG and above to stdout through the root logger.\n",
    "logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)\n",
    "logger = logging.getLogger()\n",
    "\n",
    "\n",
    "class Trainer(object):\n",
    "    def __init__(self, params):\n",
    "        \"\"\"Build the agent, the train/dev/test environments, the baseline and the optimizer.\n",
    "\n",
    "        Every entry of ``params`` is first copied onto the instance as an\n",
    "        attribute of the same name; several entries are then read explicitly.\n",
    "        \"\"\"\n",
    "        # Mirror the whole configuration onto the instance.\n",
    "        for name in params:\n",
    "            setattr(self, name, params[name])\n",
    "\n",
    "        self.device = params['device']\n",
    "        print(self.device)\n",
    "        self.agent = Agent(params).to(self.device)\n",
    "\n",
    "        self.model_dir = params['model_dir']\n",
    "        self.save_path = self.model_dir + \"model\" + '.ckpt'\n",
    "\n",
    "        # One environment per split; evaluation runs on the dev split by default.\n",
    "        self.train_environment = env(params, 'train')\n",
    "        self.dev_test_environment = env(params, 'dev')\n",
    "        self.test_test_environment = env(params, 'test')\n",
    "        self.test_environment = self.dev_test_environment\n",
    "\n",
    "        train_grapher = self.train_environment.grapher\n",
    "        self.rev_relation_vocab = train_grapher.rev_relation_vocab\n",
    "        self.rev_entity_vocab = train_grapher.rev_entity_vocab\n",
    "\n",
    "        self.max_hits_at_10 = 0\n",
    "        # entity_vocab, relation_vocab, beta and Lambda come from the mirrored params.\n",
    "        self.ePAD = self.entity_vocab['PAD']\n",
    "        self.rPAD = self.relation_vocab['PAD']\n",
    "        self.decaying_beta_init = self.beta\n",
    "        self.baseline = ReactiveBaseline(params, self.Lambda)\n",
    "\n",
    "        # Set after the mirroring loop, so this always ends up as None.\n",
    "        self.decay_batch = None\n",
    "        for name in ('gamma', 'grad_clip_norm', 'eval_every', 'total_iterations',\n",
    "                     'learning_rate', 'pool', 'output_dir'):\n",
    "            setattr(self, name, params[name])\n",
    "        self.params = params\n",
    "\n",
    "        self.positive_reward_rates = []\n",
    "        self.optimizer = optim.Adam(list(self.agent.parameters()), lr=self.learning_rate)\n",
    "        self.two_embeds_sim_criterion = torch.nn.KLDivLoss()\n",
    "\n",
    "    def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_reward, decaying_beta, baseline):\n",
    "        \"\"\"Combine per-step losses into the REINFORCE objective with an entropy bonus.\n",
    "\n",
    "        Args:\n",
    "            all_loss: List with one per-sample loss tensor per step.\n",
    "            all_logits: List with one per-step tensor, forwarded to ``entropy_reg_loss``.\n",
    "            cum_discounted_reward: Discounted return per sample and step, ``[batch, T]``.\n",
    "            decaying_beta: Weight of the entropy term.\n",
    "            baseline: Object whose ``get_baseline_value()`` is subtracted from the returns.\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor: mean of ``loss * standardised advantage`` minus the weighted entropy.\n",
    "        \"\"\"\n",
    "        # Advantage = return - baseline, standardised over the whole batch.\n",
    "        advantage = cum_discounted_reward - baseline.get_baseline_value()\n",
    "        centre = torch.mean(advantage)\n",
    "        spread = torch.std(advantage) + 1e-6  # constant added for numerical stability\n",
    "        advantage = (advantage - centre) / spread\n",
    "\n",
    "        weighted = torch.stack(all_loss, dim=1) * advantage  # [original batch_size * num_rollout, T]\n",
    "        entropy_bonus = decaying_beta * self.entropy_reg_loss(all_logits)\n",
    "\n",
    "        return weighted.mean() - entropy_bonus  # scalar\n",
    "    \n",
    "    def entropy_reg_loss(self, all_logits):  # control diversity\n",
    "        \"\"\"Return the mean entropy of the per-step action distributions.\n",
    "\n",
    "        ``all_logits`` is a list of ``[batch, max_num_actions]`` tensors that\n",
    "        are treated as log-probabilities (``exp`` is applied to obtain p).\n",
    "        \"\"\"\n",
    "        log_probs = torch.stack(all_logits, dim=2)  # [original batch_size * num_rollout, max_num_actions, T]\n",
    "        p_log_p = log_probs.exp() * log_probs\n",
    "        # Sum over actions gives -entropy per (sample, step); average the rest.\n",
    "        return -p_log_p.sum(dim=1).mean()  # scalar\n",
    "\n",
    "    def calc_cum_discounted_reward(self, rewards):\n",
    "        \"\"\"Spread the terminal reward back over the path with discount ``self.gamma``.\n",
    "\n",
    "        Args:\n",
    "            rewards: Final reward per sample, ``[batch]``.\n",
    "\n",
    "        Returns:\n",
    "            ``[batch, path_length]`` tensor on ``self.device`` whose column\n",
    "            ``t`` holds ``gamma ** (path_length - 1 - t) * rewards``.\n",
    "        \"\"\"\n",
    "        batch = rewards.size(0)\n",
    "        horizon = self.path_length\n",
    "\n",
    "        discounted = torch.zeros([batch, horizon]).to(self.device)  # [original batch_size * num_rollout, T]\n",
    "        discounted[:, horizon - 1] = rewards  # the reward is received at the last state only\n",
    "\n",
    "        carry = torch.zeros([batch]).to(self.device)  # [original batch_size * num_rollout]\n",
    "        for t in range(horizon - 1, -1, -1):\n",
    "            carry = self.gamma * carry + discounted[:, t]\n",
    "            discounted[:, t] = carry\n",
    "        return discounted\n",
    "\n",
    "    def calc_cum_discounted_reward_credit(self, entity_rewards):\n",
    "        \"\"\"Propagate the terminal reward backwards, writing each running total one step earlier.\n",
    "\n",
    "        Args:\n",
    "            entity_rewards: Final reward per sample, ``[batch]``.\n",
    "\n",
    "        Returns:\n",
    "            ``[batch, path_length]`` tensor on ``self.device``. The last\n",
    "            column is the reward itself; each earlier column ``t - 1``\n",
    "            receives the running total computed while visiting column ``t``.\n",
    "        \"\"\"\n",
    "        batch = entity_rewards.size(0)\n",
    "        horizon = self.path_length\n",
    "\n",
    "        credit = torch.zeros([batch, horizon]).to(self.device)  # [original batch_size * num_rollout, T]\n",
    "        credit[:, horizon - 1] = entity_rewards  # the reward is received at the last state only\n",
    "\n",
    "        carry = torch.zeros([batch]).to(self.device)  # [original batch_size * num_rollout]\n",
    "        # NOTE(review): column t was itself written by the previous iteration,\n",
    "        # so totals compound (e.g. T=3, gamma=0.5, r=1 gives [1.5, 1, 1]) -\n",
    "        # confirm this recurrence is intended.\n",
    "        for t in range(horizon - 1, 0, -1):\n",
    "            carry = self.gamma * carry + credit[:, t]\n",
    "            credit[:, t - 1] = carry\n",
    "\n",
    "        return credit\n",
    "\n",
    "    def get_graph(self, entity_episode, mode = 'train'):\n",
    "        \"\"\"Return the graph's edge list as ``(head, link, tail)`` index tensors.\n",
    "\n",
    "        Args:\n",
    "            entity_episode: Current batch. In train mode its ``start_entities``,\n",
    "                ``end_entities`` and ``query_relation`` identify the query\n",
    "                edges that must be hidden from the agent.\n",
    "            mode: ``'train'`` uses the train environment's graph and removes\n",
    "                the query edges; any other value uses the test environment's\n",
    "                graph unchanged.\n",
    "\n",
    "        Returns:\n",
    "            Three 1-D LongTensors on ``self.device`` of equal length. Slots\n",
    "            whose link id is not positive are dropped.\n",
    "        \"\"\"\n",
    "        if mode == 'train':\n",
    "            grapher = self.train_environment.grapher\n",
    "            # Copy only the edge table (not the whole grapher) so the\n",
    "            # environment's own graph keeps its query edges.\n",
    "            array_store = deepcopy(grapher.array_store)\n",
    "\n",
    "            start_entities = np.asarray(entity_episode.start_entities)\n",
    "            ent_match = array_store[start_entities, :, 0] == np.expand_dims(entity_episode.end_entities, -1)\n",
    "            rel_match = array_store[start_entities, :, 1] == np.expand_dims(entity_episode.query_relation, -1)\n",
    "\n",
    "            # Zero every matching slot directly in the table. Writing a masked\n",
    "            # per-sample copy back with duplicate start entities let the last\n",
    "            # row win, which restored query edges of earlier samples.\n",
    "            rows, slots = np.nonzero(rel_match & ent_match)\n",
    "            array_store[start_entities[rows], slots] = 0\n",
    "        else:\n",
    "            grapher = self.test_environment.grapher\n",
    "            array_store = grapher.array_store  # only read below, no copy needed\n",
    "\n",
    "        head = torch.repeat_interleave(torch.arange(len(grapher.rev_entity_vocab)), self.params['max_num_actions'])\n",
    "        tail = torch.LongTensor(array_store[:, :, 0].reshape(-1))\n",
    "        link = torch.LongTensor(array_store[:, :, 1].reshape(-1))\n",
    "\n",
    "        keep = link > 0\n",
    "        head = head[keep].to(self.device)\n",
    "        tail = tail[keep].to(self.device)\n",
    "        link = link[keep].to(self.device)\n",
    "        return head, link, tail\n",
    "    \n",
    "    def train(self):\n",
    "        \"\"\"Run the REINFORCE training loop over batches of the train environment.\n",
    "\n",
    "        For every batch the agent is rolled out for ``self.path_length`` steps,\n",
    "        the discounted reward is turned into a loss, and one optimizer step is\n",
    "        taken with gradient-norm clipping. Running averages are printed every\n",
    "        ``eval_every // 10`` batches (at least every batch) and ``self.test``\n",
    "        is called every ``eval_every`` batches. The loop stops after\n",
    "        ``self.total_iterations`` batches.\n",
    "        \"\"\"\n",
    "        train_loss = []\n",
    "        train_reward = []\n",
    "\n",
    "        self.batch_counter = 0\n",
    "        current_decay = self.decaying_beta_init\n",
    "        current_decay_count = 0\n",
    "        # eval_every < 10 made ``eval_every // 10`` zero and the modulus below\n",
    "        # raise ZeroDivisionError.\n",
    "        log_every = max(1, self.eval_every // 10)\n",
    "\n",
    "        print('Agent start learning ...')\n",
    "        for entity_episode in self.train_environment.get_episodes():\n",
    "\n",
    "            self.batch_counter += 1\n",
    "\n",
    "            current_decay_count += 1\n",
    "            if current_decay_count == self.decay_batch:\n",
    "                current_decay *= self.decay_rate\n",
    "                current_decay_count = 0\n",
    "\n",
    "            # get initial state for entity agent\n",
    "            head, link, tail = self.get_graph(entity_episode, mode = 'train')\n",
    "\n",
    "            # Neighbour table with this batch's query edges removed. It depends\n",
    "            # only on the episode's start/end/query (assumed fixed during the\n",
    "            # rollout), so it is built once per episode rather than once per\n",
    "            # step. Matching slots are zeroed directly in the table: writing a\n",
    "            # masked per-sample copy back with duplicate start entities let the\n",
    "            # last row win and restored query edges of earlier samples.\n",
    "            masked_store = deepcopy(self.train_environment.grapher.array_store)\n",
    "            start_entities = np.asarray(entity_episode.start_entities)\n",
    "            end_match = masked_store[start_entities, :, 0] == np.expand_dims(entity_episode.end_entities, -1)\n",
    "            rel_match = masked_store[start_entities, :, 1] == np.expand_dims(entity_episode.query_relation, -1)\n",
    "            rows, slots = np.nonzero(end_match & rel_match)\n",
    "            masked_store[start_entities[rows], slots] = 0\n",
    "\n",
    "            entity_state_emb = torch.zeros(1, 2, self.batch_size * self.num_rollouts,\n",
    "                                           self.agent.m * self.hidden_size).to(self.device)\n",
    "            entity_state = entity_episode.get_state()\n",
    "            next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(\n",
    "                self.device)  # original batch_size * num_rollout, max_num_actions\n",
    "            next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "\n",
    "            prev_relation = self.agent.dummy_start_label.to(self.device)  # original batch_size * num_rollout, 1-D, (1...)\n",
    "\n",
    "            query_relation = entity_episode.get_query_relation()\n",
    "            query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "            current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "            prev_gnn_state = self.agent.initial_gnn_state.repeat(len(current_entities), 1)\n",
    "\n",
    "            all_losses = []\n",
    "            all_logits = []\n",
    "\n",
    "            for i in range(self.path_length):\n",
    "                # Fancy indexing copies, so the shared table stays untouched.\n",
    "                next_neighbors = masked_store[next_possible_entities.cpu().numpy()]\n",
    "                next_neighbors = torch.LongTensor(next_neighbors).to(current_entities.device)\n",
    "\n",
    "                loss, entity_state_emb, logits, idx, chosen_relation = self.agent.step(\n",
    "                    next_possible_relations,\n",
    "                    next_possible_entities, entity_state_emb,\n",
    "                    prev_relation, query_relation,\n",
    "                    current_entities,\n",
    "                    head, link, tail,\n",
    "                    next_neighbors, prev_gnn_state\n",
    "                )\n",
    "\n",
    "                # Advance the environment with the sampled action indices.\n",
    "                entity_state = entity_episode(idx.cpu())\n",
    "                next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "                prev_relation = chosen_relation.to(self.device)\n",
    "\n",
    "                all_losses.append(loss)\n",
    "                all_logits.append(logits)\n",
    "\n",
    "            rewards = entity_episode.get_reward()\n",
    "            rewards = torch.tensor(rewards).to(self.device)\n",
    "\n",
    "            cum_discounted_reward = self.calc_cum_discounted_reward(rewards)\n",
    "            reinforce_loss = self.calc_reinforce_loss(all_losses, all_logits, cum_discounted_reward,\n",
    "                                                        current_decay, self.baseline)\n",
    "\n",
    "            self.baseline.update(torch.mean(cum_discounted_reward))\n",
    "\n",
    "            self.optimizer.zero_grad()\n",
    "            reinforce_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=self.grad_clip_norm, norm_type=2)\n",
    "            self.optimizer.step()\n",
    "\n",
    "            train_loss.append(reinforce_loss.detach().cpu().item())\n",
    "            train_reward.append(rewards.cpu().float().tolist())\n",
    "\n",
    "            if self.batch_counter % log_every == 0:\n",
    "                avg_loss = np.mean(train_loss[-log_every:])\n",
    "                avg_reward = np.mean([r for batch in train_reward[-log_every:] for r in batch])\n",
    "                print('Iteration: {}, Train loss: {:.4f}, rewards: {:.4f}'.format(self.batch_counter, avg_loss, avg_reward))\n",
    "                # Only the batches since the previous report are ever averaged,\n",
    "                # so the history can be dropped instead of growing forever.\n",
    "                train_loss.clear()\n",
    "                train_reward.clear()\n",
    "                gc.collect()\n",
    "\n",
    "            if self.batch_counter % self.eval_every == 0:\n",
    "                print('Eval:')\n",
    "                self.test(beam = True)\n",
    "                gc.collect()\n",
    "                print('------------------------------------------------------------')\n",
    "\n",
    "            # ``>`` ran one batch more than ``total_iterations``.\n",
    "            if self.batch_counter >= self.total_iterations:\n",
    "                break\n",
    "\n",
    "    def test(self, beam=False, print_paths=False, save_model=True):\n",
    "\n",
    "        with torch.no_grad():\n",
    "\n",
    "            batch_counter = 0\n",
    "            paths = defaultdict(list)\n",
    "            answers = []\n",
    "            all_final_reward_1 = 0\n",
    "            all_final_reward_3 = 0\n",
    "            all_final_reward_5 = 0\n",
    "            all_final_reward_10 = 0\n",
    "            all_final_reward_20 = 0\n",
    "            auc = 0\n",
    "\n",
    "            total_examples = self.test_environment.total_no_examples\n",
    "\n",
    "            for entity_episode in self.test_environment.get_episodes():\n",
    "                batch_counter += 1\n",
    "\n",
    "                temp_batch_size = entity_episode.no_examples\n",
    "\n",
    "                self.qr = entity_episode.get_query_relation()\n",
    "                query_relation = self.qr\n",
    "                query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "                # set initial beam probs\n",
    "                beam_probs = torch.zeros((temp_batch_size * self.test_rollouts, 1)).to(self.device)\n",
    "\n",
    "                # get initial state for entity agent\n",
    "                entity_state = entity_episode.get_state()\n",
    "                \n",
    "                head, link, tail = self.get_graph(entity_episode, mode = 'test')\n",
    "                next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "\n",
    "                entity_state_emb = torch.zeros(1, 2, temp_batch_size * self.test_rollouts,\n",
    "                                               self.agent.m * self.hidden_size).to(self.device)\n",
    "                prev_relation = (torch.ones(temp_batch_size * self.test_rollouts) * self.relation_vocab[\n",
    "                    'DUMMY_START_RELATION']).long().to(self.device)\n",
    "                prev_gnn_state = self.agent.initial_gnn_state.repeat(len(current_entities), 1)\n",
    "\n",
    "                if print_paths:\n",
    "                    self.entity_trajectory = [current_entities]\n",
    "                    self.relation_trajectory = [prev_relation]   \n",
    "\n",
    "                self.log_probs = np.zeros((temp_batch_size * self.test_rollouts,)) * 1.0\n",
    "                for i in range(self.path_length):\n",
    "                    \n",
    "                    next_neighbors = self.test_environment.grapher.array_store[next_entities.cpu().numpy()]\n",
    "                    next_neighbors = torch.LongTensor(next_neighbors).to(current_entities.device)\n",
    "                    \n",
    "                    loss, entity_state_emb, test_scores, test_action_idx, chosen_relation = self.agent.step(\n",
    "                        next_relations,\n",
    "                        next_entities, entity_state_emb,\n",
    "                        prev_relation, query_relation,\n",
    "                        current_entities,\n",
    "                        head, link, tail,\n",
    "                        next_neighbors, prev_gnn_state\n",
    "                    )\n",
    "                    \n",
    "                    #Mimic original implementation on pytorch\n",
    "                    if beam:\n",
    "                        k = self.test_rollouts\n",
    "                        beam_probs = beam_probs.to(self.device)\n",
    "                        new_scores = test_scores + beam_probs\n",
    "                        new_scores = new_scores.cpu()\n",
    "                        if i == 0:\n",
    "                            idx = np.argsort(new_scores)\n",
    "                            idx = idx[:, -k:]\n",
    "                            ranged_idx = np.tile([b for b in range(k)], temp_batch_size)\n",
    "                            idx = idx[np.arange(k * temp_batch_size), ranged_idx]\n",
    "                        else:\n",
    "                            idx = self.top_k(new_scores, k)\n",
    "\n",
    "                        y = idx // self.max_num_actions\n",
    "                        x = idx % self.max_num_actions\n",
    "\n",
    "                        y += np.repeat([b * k for b in range(temp_batch_size)], k)\n",
    "                        entity_state['current_entities'] = entity_state['current_entities'][y]\n",
    "                        entity_state['next_relations'] = entity_state['next_relations'][y, :]\n",
    "                        entity_state['next_entities'] = entity_state['next_entities'][y, :]\n",
    "                        entity_state_emb = entity_state_emb[:, :, y, :]\n",
    "\n",
    "                        test_action_idx = x\n",
    "                        chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]\n",
    "\n",
    "                        beam_probs = new_scores[y, x]\n",
    "                        beam_probs = beam_probs.reshape((-1, 1))\n",
    "\n",
    "#                     #My implementation to fit arbitrary dimension\n",
    "#                     if beam:\n",
    "#                         k = self.test_rollouts\n",
    "#                         beam_probs = beam_probs.to(self.device)\n",
    "#                         new_scores = test_scores + beam_probs\n",
    "#                         new_scores = new_scores.cpu()\n",
    "#                         if i == 0:\n",
    "#                             reshape_score = new_scores.reshape(temp_batch_size, self.test_rollouts, -1)\n",
    "#                             possible_idx = []\n",
    "#                             for x in reshape_score:\n",
    "#                                 possible_idx.append(torch.LongTensor(np.where(x[0].cpu() > -1000)[0]))\n",
    "#                             idx = []\n",
    "#                             for x in possible_idx:\n",
    "#                                 idx.append(torch.cat(([x]*(self.test_rollouts//len(x) + 1)))[:self.test_rollouts])\n",
    "#                             idx = torch.cat(idx, 0)\n",
    "#                         else:\n",
    "#                             idx = self.top_k(new_scores, k)\n",
    "\n",
    "#                         y = idx // self.max_num_actions\n",
    "#                         x = idx % self.max_num_actions\n",
    "\n",
    "#                         y += np.repeat([b * k for b in range(temp_batch_size)], k)\n",
    "#                         entity_state['current_entities'] = entity_state['current_entities'][y]\n",
    "#                         entity_state['next_relations'] = entity_state['next_relations'][y, :]\n",
    "#                         entity_state['next_entities'] = entity_state['next_entities'][y, :]\n",
    "#                         entity_state_emb = entity_state_emb[:, :, y, :]\n",
    "\n",
    "#                         test_action_idx = x\n",
    "#                         chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]\n",
    "#                         beam_probs = new_scores[y, x]\n",
    "#                         beam_probs = beam_probs.reshape((-1, 1))\n",
    "                        \n",
    "                        if print_paths:\n",
    "                            for j in range(i):\n",
    "                                self.entity_trajectory[j] = self.entity_trajectory[j][y]\n",
    "                                self.relation_trajectory[j] = self.relation_trajectory[j][y]\n",
    "\n",
    "                    entity_state = entity_episode(test_action_idx.cpu().numpy())\n",
    "                    next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                    next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                    current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "                    prev_relation = torch.tensor(chosen_relation).long().to(self.device)\n",
    "                    \n",
    "                    if print_paths:\n",
    "                        self.entity_trajectory.append(entity_state['current_entities'])\n",
    "                        self.relation_trajectory.append(chosen_relation)\n",
    "\n",
    "                    test_scores = test_scores.cpu().numpy()\n",
    "                    self.log_probs += test_scores[np.arange(self.log_probs.shape[0]), test_action_idx.cpu().numpy()]\n",
    "\n",
    "                if beam:\n",
    "                    self.log_probs = beam_probs\n",
    "\n",
    "                rewards = entity_episode.get_reward()  # [B*test_rollouts]\n",
    "                reward_reshape = np.reshape(rewards, (temp_batch_size, self.test_rollouts))  # [orig_batch, test_rollouts]\n",
    "                self.log_probs = np.reshape(self.log_probs, (temp_batch_size, self.test_rollouts))\n",
    "                sorted_indx = np.argsort(-self.log_probs)\n",
    "                final_reward_1 = 0\n",
    "                final_reward_3 = 0\n",
    "                final_reward_5 = 0\n",
    "                final_reward_10 = 0\n",
    "                final_reward_20 = 0\n",
    "                AP = 0\n",
    "                ce = entity_episode.state['current_entities'].reshape((temp_batch_size, self.test_rollouts))\n",
    "                se = entity_episode.start_entities.reshape((temp_batch_size, self.test_rollouts))\n",
    "                for b in range(temp_batch_size):\n",
    "                    answer_pos = None\n",
    "                    seen = set()\n",
    "                    pos=0\n",
    "                    if self.pool == 'max':\n",
    "                        for r in sorted_indx[b]:\n",
    "                            if reward_reshape[b,r] == self.positive_reward:\n",
    "                                answer_pos = pos\n",
    "                                break\n",
    "                            if ce[b, r] not in seen:\n",
    "                                seen.add(ce[b, r])\n",
    "                                pos += 1\n",
    "                    if self.pool == 'sum':\n",
    "                        scores = defaultdict(list)\n",
    "                        answer = ''\n",
    "                        for r in sorted_indx[b]:\n",
    "                            scores[ce[b,r]].append(self.log_probs[b,r])\n",
    "                            if reward_reshape[b,r] == self.positive_reward:\n",
    "                                answer = ce[b,r]\n",
    "                        final_scores = defaultdict(float)\n",
    "                        for e in scores:\n",
    "                            final_scores[e] = lse(scores[e])\n",
    "                        sorted_answers = sorted(final_scores, key=final_scores.get, reverse=True)\n",
    "                        if answer in  sorted_answers:\n",
    "                            answer_pos = sorted_answers.index(answer)\n",
    "                        else:\n",
    "                            answer_pos = None\n",
    "\n",
    "\n",
    "                    if answer_pos != None:\n",
    "                        if answer_pos < 20:\n",
    "                            final_reward_20 += 1\n",
    "                            if answer_pos < 10:\n",
    "                                final_reward_10 += 1\n",
    "                                if answer_pos < 5:\n",
    "                                    final_reward_5 += 1\n",
    "                                    if answer_pos < 3:\n",
    "                                        final_reward_3 += 1\n",
    "                                        if answer_pos < 1:\n",
    "                                            final_reward_1 += 1\n",
    "                    if answer_pos == None:\n",
    "                        AP += 0\n",
    "                    else:\n",
    "                        AP += 1.0/((answer_pos+1))\n",
    "                    \n",
    "                    if print_paths:\n",
    "                        qr = self.train_environment.grapher.rev_relation_vocab[self.qr[b * self.test_rollouts]]\n",
    "                        start_e = self.rev_entity_vocab[entity_episode.start_entities[b * self.test_rollouts]]\n",
    "                        end_e = self.rev_entity_vocab[entity_episode.end_entities[b * self.test_rollouts]]\n",
    "                        paths[str(qr)].append(str(start_e) + \"\\t\" + str(end_e) + \"\\n\")\n",
    "                        paths[str(qr)].append(\"Reward:\" + str(1 if answer_pos != None and answer_pos < 10 else 0) + \"\\n\")\n",
    "                        for r in sorted_indx[b]:\n",
    "                            indx = b * self.test_rollouts + r\n",
    "                            if rewards[indx] == self.positive_reward:\n",
    "                                rev = 1\n",
    "                            else:\n",
    "                                rev = -1\n",
    "                            answers.append(self.rev_entity_vocab[se[b,r]]+'\\t'+ self.rev_entity_vocab[ce[b,r]]+'\\t'+ str(self.log_probs[b,r])+'\\n')\n",
    "                            paths[str(qr)].append(\n",
    "                                '\\t'.join([str(self.rev_entity_vocab[e[indx]]) for e in\n",
    "                                           self.entity_trajectory]) + '\\n' + '\\t'.join(\n",
    "                                    [str(self.rev_relation_vocab[re[indx]]) for re in self.relation_trajectory]) + '\\n' + str(\n",
    "                                    rev) + '\\n' + str(\n",
    "                                    self.log_probs[b, r]) + '\\n___' + '\\n')\n",
    "                        paths[str(qr)].append(\"#####################\\n\")\n",
    "\n",
    "                all_final_reward_1 += final_reward_1\n",
    "                all_final_reward_3 += final_reward_3\n",
    "                all_final_reward_5 += final_reward_5\n",
    "                all_final_reward_10 += final_reward_10\n",
    "                all_final_reward_20 += final_reward_20\n",
    "                auc += AP\n",
    "\n",
    "            all_final_reward_1 /= total_examples\n",
    "            all_final_reward_3 /= total_examples\n",
    "            all_final_reward_5 /= total_examples\n",
    "            all_final_reward_10 /= total_examples\n",
    "            all_final_reward_20 /= total_examples\n",
    "            auc /= total_examples\n",
    "            \n",
    "            if save_model:\n",
    "                if all_final_reward_10 >= self.max_hits_at_10:\n",
    "                    self.max_hits_at_10 = all_final_reward_10\n",
    "                    torch.save(self.agent.state_dict(), self.model_dir + \"agent\" + '.ckpt')\n",
    "                    # self.save_path = self.model_dir + \"model\" + '.ckpt'\n",
    "\n",
    "            if print_paths:\n",
    "                logger.info(\"[ printing paths at {} ]\".format(self.output_dir + '/test_beam/'))\n",
    "                for q in paths:\n",
    "                    j = q.replace('/', '-')\n",
    "                    with codecs.open(self.path_logger_file_ + '_' + j, 'a', 'utf-8') as pos_file:\n",
    "                        for p in paths[q]:\n",
    "                            pos_file.write(p)\n",
    "                with open(self.path_logger_file_ + 'answers', 'w') as answer_file:\n",
    "                    for a in answers:\n",
    "                        answer_file.write(a)\n",
    "\n",
    "            with open(self.output_dir + '/scores.txt', 'a') as score_file:\n",
    "                score_file.write(\"Hits@1: {:.4f}\".format(all_final_reward_1))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@3: {:.4f}\".format(all_final_reward_3))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@5: {:.4f}\".format(all_final_reward_5))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@10: {:.4f}\".format(all_final_reward_10))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@20: {:.4f}\".format(all_final_reward_20))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"MRR: {:.4f}\".format(auc))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"------------------------------------\")\n",
    "\n",
    "            print(\"Hits@1: {:.4f}, Hits@3: {:.4f}, Hits@10: {:.4f}, MRR: {:.4f}\".format(all_final_reward_1, \n",
    "                                                                                 all_final_reward_3,\n",
    "                                                                                 all_final_reward_10, auc))\n",
    "            \n",
    "    def top_k(self, scores, k):\n",
    "        scores = scores.reshape(-1, k * self.max_num_actions)  # [B, (k*max_num_actions)]\n",
    "        idx = np.argsort(scores, axis=1)\n",
    "        idx = idx[:, -k:]  # take the last k highest indices # [B , k]\n",
    "        return idx.reshape((-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8bca5f0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "options = {}\n",
    "\n",
    "#basic setting\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/FB15K-237/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/FB15K-237/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "options['model_dir'] = './outputs_FB15K-237_v7/'\n",
    "options['output_dir'] = './outputs_FB15K-237_v7/'\n",
    "\n",
    "#agent setting\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 1\n",
    "options['train_entity_embeddings'] = 0\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 100\n",
    "\n",
    "#hyperparameters\n",
    "options['test_rollouts'] = 100\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 8\n",
    "options['beta'] = 0.05\n",
    "options['Lambda'] = 0.05\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 5e-4\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "options['total_iterations'] = 16000\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dc1e4ea",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/cwong/miniconda3/envs/tracy/lib/python3.9/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/tmp/ipykernel_2375638/1090773659.py:224: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.1961, rewards: 0.0056\n",
      "Iteration: 20, Train loss: -0.3018, rewards: 0.0381\n",
      "Iteration: 30, Train loss: -0.2450, rewards: 0.0256\n",
      "Iteration: 40, Train loss: -0.2230, rewards: 0.0156\n",
      "Iteration: 50, Train loss: -0.2544, rewards: 0.0331\n",
      "Iteration: 60, Train loss: -0.3190, rewards: 0.0494\n",
      "Iteration: 70, Train loss: -0.2311, rewards: 0.0163\n",
      "Iteration: 80, Train loss: -0.2593, rewards: 0.0262\n",
      "Iteration: 90, Train loss: -0.2299, rewards: 0.0169\n",
      "Iteration: 100, Train loss: -0.2167, rewards: 0.0163\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2375638/598236006.py:327: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.0131, Hits@3: 0.0362, Hits@10: 0.0836, MRR: 0.0352\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.2406, rewards: 0.0275\n",
      "Iteration: 120, Train loss: -0.2514, rewards: 0.0213\n",
      "Iteration: 130, Train loss: -0.2796, rewards: 0.0231\n",
      "Iteration: 140, Train loss: -0.2289, rewards: 0.0288\n",
      "Iteration: 150, Train loss: -0.1995, rewards: 0.0369\n",
      "Iteration: 160, Train loss: -0.2568, rewards: 0.0275\n",
      "Iteration: 170, Train loss: -0.2316, rewards: 0.0200\n",
      "Iteration: 180, Train loss: -0.2200, rewards: 0.0163\n",
      "Iteration: 190, Train loss: -0.2035, rewards: 0.0125\n",
      "Iteration: 200, Train loss: -0.2902, rewards: 0.0312\n",
      "Eval:\n",
      "Hits@1: 0.0072, Hits@3: 0.0209, Hits@10: 0.0609, MRR: 0.0249\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.2771, rewards: 0.0312\n",
      "Iteration: 220, Train loss: -0.2891, rewards: 0.0344\n",
      "Iteration: 230, Train loss: -0.2965, rewards: 0.0537\n",
      "Iteration: 240, Train loss: -0.2498, rewards: 0.0244\n",
      "Iteration: 250, Train loss: -0.2512, rewards: 0.0163\n",
      "Iteration: 260, Train loss: -0.2295, rewards: 0.0187\n",
      "Iteration: 270, Train loss: -0.2548, rewards: 0.0288\n",
      "Iteration: 280, Train loss: -0.1943, rewards: 0.0088\n",
      "Iteration: 290, Train loss: -0.2775, rewards: 0.0219\n",
      "Iteration: 300, Train loss: -0.2684, rewards: 0.0175\n",
      "Eval:\n",
      "Hits@1: 0.0068, Hits@3: 0.0206, Hits@10: 0.0567, MRR: 0.0237\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.3011, rewards: 0.0269\n",
      "Iteration: 320, Train loss: -0.2104, rewards: 0.0119\n",
      "Iteration: 330, Train loss: -0.2713, rewards: 0.0312\n",
      "Iteration: 340, Train loss: -0.3453, rewards: 0.0387\n",
      "Iteration: 350, Train loss: -0.3906, rewards: 0.0500\n",
      "Iteration: 360, Train loss: -0.2666, rewards: 0.0219\n",
      "Iteration: 370, Train loss: -0.2573, rewards: 0.0244\n",
      "Iteration: 380, Train loss: -0.2994, rewards: 0.0300\n",
      "Iteration: 390, Train loss: -0.2840, rewards: 0.0462\n",
      "Iteration: 400, Train loss: -0.3023, rewards: 0.0475\n",
      "Eval:\n",
      "Hits@1: 0.0207, Hits@3: 0.0455, Hits@10: 0.0987, MRR: 0.0451\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.1664, rewards: 0.0288\n",
      "Iteration: 420, Train loss: -0.3270, rewards: 0.0387\n",
      "Iteration: 430, Train loss: -0.2850, rewards: 0.0219\n",
      "Iteration: 440, Train loss: -0.2865, rewards: 0.0187\n",
      "Iteration: 450, Train loss: -0.3773, rewards: 0.0631\n",
      "Iteration: 460, Train loss: -0.3021, rewards: 0.0200\n",
      "Iteration: 470, Train loss: -0.4168, rewards: 0.0456\n",
      "Iteration: 480, Train loss: -0.2539, rewards: 0.0150\n",
      "Iteration: 490, Train loss: -0.2614, rewards: 0.0300\n",
      "Iteration: 500, Train loss: -0.2489, rewards: 0.0262\n",
      "Eval:\n",
      "Hits@1: 0.0142, Hits@3: 0.0309, Hits@10: 0.0819, MRR: 0.0343\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.2625, rewards: 0.0175\n",
      "Iteration: 520, Train loss: -0.2686, rewards: 0.0319\n",
      "Iteration: 530, Train loss: -0.2800, rewards: 0.0269\n",
      "Iteration: 540, Train loss: -0.3737, rewards: 0.0431\n",
      "Iteration: 550, Train loss: -0.2515, rewards: 0.0163\n",
      "Iteration: 560, Train loss: -0.2886, rewards: 0.0213\n",
      "Iteration: 570, Train loss: -0.2733, rewards: 0.0319\n",
      "Iteration: 580, Train loss: -0.2300, rewards: 0.0106\n",
      "Iteration: 590, Train loss: -0.1915, rewards: 0.0200\n",
      "Iteration: 600, Train loss: -0.2084, rewards: 0.0213\n",
      "Eval:\n",
      "Hits@1: 0.0069, Hits@3: 0.0194, Hits@10: 0.0532, MRR: 0.0218\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.2334, rewards: 0.0163\n",
      "Iteration: 620, Train loss: -0.2646, rewards: 0.0262\n",
      "Iteration: 630, Train loss: -0.2268, rewards: 0.0300\n",
      "Iteration: 640, Train loss: -0.2712, rewards: 0.0275\n",
      "Iteration: 650, Train loss: -0.2258, rewards: 0.0169\n",
      "Iteration: 660, Train loss: -0.3164, rewards: 0.0444\n",
      "Iteration: 670, Train loss: -0.2619, rewards: 0.0163\n",
      "Iteration: 680, Train loss: -0.2194, rewards: 0.0200\n",
      "Iteration: 690, Train loss: -0.3458, rewards: 0.0425\n",
      "Iteration: 700, Train loss: -0.2350, rewards: 0.0156\n",
      "Eval:\n",
      "Hits@1: 0.0086, Hits@3: 0.0281, Hits@10: 0.0790, MRR: 0.0302\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.3375, rewards: 0.0456\n",
      "Iteration: 720, Train loss: -0.2347, rewards: 0.0119\n",
      "Iteration: 730, Train loss: -0.3066, rewards: 0.0338\n",
      "Iteration: 740, Train loss: -0.3013, rewards: 0.0300\n",
      "Iteration: 750, Train loss: -0.2959, rewards: 0.0425\n",
      "Iteration: 760, Train loss: -0.2519, rewards: 0.0288\n",
      "Iteration: 770, Train loss: -0.2940, rewards: 0.0619\n",
      "Iteration: 780, Train loss: -0.2564, rewards: 0.0200\n",
      "Iteration: 790, Train loss: -0.2562, rewards: 0.0181\n",
      "Iteration: 800, Train loss: -0.2380, rewards: 0.0256\n",
      "Eval:\n",
      "Hits@1: 0.0067, Hits@3: 0.0183, Hits@10: 0.0539, MRR: 0.0227\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.3301, rewards: 0.0381\n",
      "Iteration: 820, Train loss: -0.2664, rewards: 0.0262\n",
      "Iteration: 830, Train loss: -0.2085, rewards: 0.0106\n",
      "Iteration: 840, Train loss: -0.2544, rewards: 0.0219\n",
      "Iteration: 850, Train loss: -0.2594, rewards: 0.0319\n",
      "Iteration: 860, Train loss: -0.3444, rewards: 0.0338\n",
      "Iteration: 870, Train loss: -0.2482, rewards: 0.0206\n",
      "Iteration: 880, Train loss: -0.3294, rewards: 0.0406\n",
      "Iteration: 890, Train loss: -0.3105, rewards: 0.0519\n",
      "Iteration: 900, Train loss: -0.2009, rewards: 0.0281\n",
      "Eval:\n",
      "Hits@1: 0.0218, Hits@3: 0.0556, Hits@10: 0.1026, MRR: 0.0473\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.3666, rewards: 0.0338\n",
      "Iteration: 920, Train loss: -0.2275, rewards: 0.0144\n",
      "Iteration: 930, Train loss: -0.2903, rewards: 0.0294\n",
      "Iteration: 940, Train loss: -0.2457, rewards: 0.0194\n",
      "Iteration: 950, Train loss: -0.3407, rewards: 0.0688\n",
      "Iteration: 960, Train loss: -0.3217, rewards: 0.0431\n",
      "Iteration: 970, Train loss: -0.2467, rewards: 0.0200\n",
      "Iteration: 980, Train loss: -0.2092, rewards: 0.0300\n",
      "Iteration: 990, Train loss: -0.2182, rewards: 0.0213\n",
      "Iteration: 1000, Train loss: -0.2868, rewards: 0.0406\n",
      "Eval:\n",
      "Hits@1: 0.0181, Hits@3: 0.0499, Hits@10: 0.1198, MRR: 0.0480\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.2584, rewards: 0.0194\n",
      "Iteration: 1020, Train loss: -0.2579, rewards: 0.0300\n",
      "Iteration: 1030, Train loss: -0.2435, rewards: 0.0325\n",
      "Iteration: 1040, Train loss: -0.2238, rewards: 0.0275\n",
      "Iteration: 1050, Train loss: -0.2450, rewards: 0.0262\n",
      "Iteration: 1060, Train loss: -0.2782, rewards: 0.0288\n",
      "Iteration: 1070, Train loss: -0.3495, rewards: 0.0387\n",
      "Iteration: 1080, Train loss: -0.3826, rewards: 0.0356\n",
      "Iteration: 1090, Train loss: -0.2182, rewards: 0.0138\n",
      "Iteration: 1100, Train loss: -0.3630, rewards: 0.0325\n",
      "Eval:\n",
      "Hits@1: 0.0164, Hits@3: 0.0389, Hits@10: 0.1009, MRR: 0.0420\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.3073, rewards: 0.0394\n",
      "Iteration: 1120, Train loss: -0.3418, rewards: 0.0356\n",
      "Iteration: 1130, Train loss: -0.3137, rewards: 0.0450\n",
      "Iteration: 1140, Train loss: -0.2791, rewards: 0.0213\n",
      "Iteration: 1150, Train loss: -0.4543, rewards: 0.0619\n",
      "Iteration: 1160, Train loss: -0.3201, rewards: 0.0506\n",
      "Iteration: 1170, Train loss: -0.4333, rewards: 0.0462\n",
      "Iteration: 1180, Train loss: -0.2710, rewards: 0.0294\n",
      "Iteration: 1190, Train loss: -0.3117, rewards: 0.0663\n",
      "Iteration: 1200, Train loss: -0.2948, rewards: 0.0350\n",
      "Eval:\n",
      "Hits@1: 0.0189, Hits@3: 0.0512, Hits@10: 0.1127, MRR: 0.0472\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.2997, rewards: 0.0638\n",
      "Iteration: 1220, Train loss: -0.2883, rewards: 0.0550\n",
      "Iteration: 1230, Train loss: -0.2735, rewards: 0.0344\n",
      "Iteration: 1240, Train loss: -0.2370, rewards: 0.0256\n",
      "Iteration: 1250, Train loss: -0.2404, rewards: 0.0194\n",
      "Iteration: 1260, Train loss: -0.3759, rewards: 0.0825\n",
      "Iteration: 1270, Train loss: -0.2464, rewards: 0.0394\n",
      "Iteration: 1280, Train loss: -0.2942, rewards: 0.0425\n",
      "Iteration: 1290, Train loss: -0.2821, rewards: 0.0719\n",
      "Iteration: 1300, Train loss: -0.3008, rewards: 0.0569\n",
      "Eval:\n",
      "Hits@1: 0.0285, Hits@3: 0.0681, Hits@10: 0.1448, MRR: 0.0622\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.2456, rewards: 0.0344\n",
      "Iteration: 1320, Train loss: -0.3048, rewards: 0.0744\n",
      "Iteration: 1330, Train loss: -0.2701, rewards: 0.0431\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1340, Train loss: -0.2667, rewards: 0.0431\n",
      "Iteration: 1350, Train loss: -0.3731, rewards: 0.0731\n",
      "Iteration: 1360, Train loss: -0.1725, rewards: 0.0262\n",
      "Iteration: 1370, Train loss: -0.1957, rewards: 0.0219\n",
      "Iteration: 1380, Train loss: -0.2900, rewards: 0.0481\n",
      "Iteration: 1390, Train loss: -0.2545, rewards: 0.0344\n",
      "Iteration: 1400, Train loss: -0.2537, rewards: 0.0362\n",
      "Eval:\n",
      "Hits@1: 0.0150, Hits@3: 0.0405, Hits@10: 0.0903, MRR: 0.0380\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.2754, rewards: 0.0556\n",
      "Iteration: 1420, Train loss: -0.2109, rewards: 0.0419\n",
      "Iteration: 1430, Train loss: -0.2734, rewards: 0.0419\n",
      "Iteration: 1440, Train loss: -0.2092, rewards: 0.0250\n",
      "Iteration: 1450, Train loss: -0.3018, rewards: 0.0650\n",
      "Iteration: 1460, Train loss: -0.2419, rewards: 0.0269\n",
      "Iteration: 1470, Train loss: -0.2238, rewards: 0.0356\n",
      "Iteration: 1480, Train loss: -0.2024, rewards: 0.0169\n",
      "Iteration: 1490, Train loss: -0.2998, rewards: 0.0606\n",
      "Iteration: 1500, Train loss: -0.3030, rewards: 0.0356\n",
      "Eval:\n",
      "Hits@1: 0.0094, Hits@3: 0.0341, Hits@10: 0.0963, MRR: 0.0350\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.1933, rewards: 0.0269\n",
      "Iteration: 1520, Train loss: -0.2060, rewards: 0.0312\n",
      "Iteration: 1530, Train loss: -0.2576, rewards: 0.0381\n",
      "Iteration: 1540, Train loss: -0.2470, rewards: 0.0213\n",
      "Iteration: 1550, Train loss: -0.2034, rewards: 0.0131\n",
      "Iteration: 1560, Train loss: -0.3077, rewards: 0.0381\n",
      "Iteration: 1570, Train loss: -0.2369, rewards: 0.0138\n",
      "Iteration: 1580, Train loss: -0.3043, rewards: 0.0381\n",
      "Iteration: 1590, Train loss: -0.3218, rewards: 0.0544\n",
      "Iteration: 1600, Train loss: -0.2323, rewards: 0.0413\n",
      "Eval:\n",
      "Hits@1: 0.0207, Hits@3: 0.0545, Hits@10: 0.1211, MRR: 0.0500\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.2612, rewards: 0.0456\n",
      "Iteration: 1620, Train loss: -0.3552, rewards: 0.0581\n",
      "Iteration: 1630, Train loss: -0.2655, rewards: 0.0356\n",
      "Iteration: 1640, Train loss: -0.3534, rewards: 0.0456\n",
      "Iteration: 1650, Train loss: -0.3386, rewards: 0.0431\n",
      "Iteration: 1660, Train loss: -0.3690, rewards: 0.0669\n",
      "Iteration: 1670, Train loss: -0.2824, rewards: 0.0369\n",
      "Iteration: 1680, Train loss: -0.2321, rewards: 0.0219\n",
      "Iteration: 1690, Train loss: -0.2862, rewards: 0.0544\n",
      "Iteration: 1700, Train loss: -0.2325, rewards: 0.0200\n",
      "Eval:\n",
      "Hits@1: 0.0161, Hits@3: 0.0500, Hits@10: 0.1100, MRR: 0.0443\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.2007, rewards: 0.0213\n",
      "Iteration: 1720, Train loss: -0.3004, rewards: 0.0512\n",
      "Iteration: 1730, Train loss: -0.2758, rewards: 0.0319\n",
      "Iteration: 1740, Train loss: -0.3244, rewards: 0.0519\n",
      "Iteration: 1750, Train loss: -0.2449, rewards: 0.0338\n",
      "Iteration: 1760, Train loss: -0.3466, rewards: 0.0500\n",
      "Iteration: 1770, Train loss: -0.4513, rewards: 0.0944\n",
      "Iteration: 1780, Train loss: -0.2327, rewards: 0.0431\n",
      "Iteration: 1790, Train loss: -0.3023, rewards: 0.0231\n",
      "Iteration: 1800, Train loss: -0.2668, rewards: 0.0250\n",
      "Eval:\n",
      "Hits@1: 0.0114, Hits@3: 0.0298, Hits@10: 0.0789, MRR: 0.0321\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.2502, rewards: 0.0369\n",
      "Iteration: 1820, Train loss: -0.3553, rewards: 0.0481\n",
      "Iteration: 1830, Train loss: -0.3117, rewards: 0.0387\n",
      "Iteration: 1840, Train loss: -0.2891, rewards: 0.0375\n",
      "Iteration: 1850, Train loss: -0.3486, rewards: 0.0475\n",
      "Iteration: 1860, Train loss: -0.3781, rewards: 0.0519\n",
      "Iteration: 1870, Train loss: -0.3241, rewards: 0.0531\n",
      "Iteration: 1880, Train loss: -0.2706, rewards: 0.0525\n",
      "Iteration: 1890, Train loss: -0.2215, rewards: 0.0519\n",
      "Iteration: 1900, Train loss: -0.3058, rewards: 0.0631\n",
      "Eval:\n",
      "Hits@1: 0.0201, Hits@3: 0.0487, Hits@10: 0.1048, MRR: 0.0460\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.3005, rewards: 0.0663\n",
      "Iteration: 1920, Train loss: -0.3727, rewards: 0.0869\n",
      "Iteration: 1930, Train loss: -0.3052, rewards: 0.0462\n",
      "Iteration: 1940, Train loss: -0.3588, rewards: 0.0531\n",
      "Iteration: 1950, Train loss: -0.3209, rewards: 0.0650\n",
      "Iteration: 1960, Train loss: -0.2878, rewards: 0.0312\n",
      "Iteration: 1970, Train loss: -0.3150, rewards: 0.0612\n",
      "Iteration: 1980, Train loss: -0.3560, rewards: 0.0506\n",
      "Iteration: 1990, Train loss: -0.3082, rewards: 0.0406\n",
      "Iteration: 2000, Train loss: -0.2464, rewards: 0.0194\n",
      "Eval:\n",
      "Hits@1: 0.0175, Hits@3: 0.0468, Hits@10: 0.0988, MRR: 0.0425\n",
      "------------------------------------------------------------\n",
      "Iteration: 2010, Train loss: -0.2426, rewards: 0.0256\n",
      "Iteration: 2020, Train loss: -0.2230, rewards: 0.0175\n",
      "Iteration: 2030, Train loss: -0.2607, rewards: 0.0394\n",
      "Iteration: 2040, Train loss: -0.2542, rewards: 0.0419\n",
      "Iteration: 2050, Train loss: -0.3276, rewards: 0.0456\n",
      "Iteration: 2060, Train loss: -0.2683, rewards: 0.0700\n",
      "Iteration: 2070, Train loss: -0.3376, rewards: 0.0663\n",
      "Iteration: 2080, Train loss: -0.3378, rewards: 0.0694\n",
      "Iteration: 2090, Train loss: -0.4082, rewards: 0.1025\n",
      "Iteration: 2100, Train loss: -0.3460, rewards: 0.0819\n",
      "Eval:\n",
      "Hits@1: 0.0266, Hits@3: 0.0583, Hits@10: 0.1130, MRR: 0.0538\n",
      "------------------------------------------------------------\n",
      "Iteration: 2110, Train loss: -0.3848, rewards: 0.0875\n",
      "Iteration: 2120, Train loss: -0.2158, rewards: 0.0325\n",
      "Iteration: 2130, Train loss: -0.2327, rewards: 0.0431\n",
      "Iteration: 2140, Train loss: -0.2690, rewards: 0.0306\n",
      "Iteration: 2150, Train loss: -0.2833, rewards: 0.0400\n",
      "Iteration: 2160, Train loss: -0.2909, rewards: 0.0469\n",
      "Iteration: 2170, Train loss: -0.3121, rewards: 0.0506\n",
      "Iteration: 2180, Train loss: -0.3521, rewards: 0.0481\n",
      "Iteration: 2190, Train loss: -0.3984, rewards: 0.0725\n",
      "Iteration: 2200, Train loss: -0.3568, rewards: 0.0369\n",
      "Eval:\n",
      "Hits@1: 0.0245, Hits@3: 0.0552, Hits@10: 0.1163, MRR: 0.0529\n",
      "------------------------------------------------------------\n",
      "Iteration: 2210, Train loss: -0.3017, rewards: 0.0481\n",
      "Iteration: 2220, Train loss: -0.3163, rewards: 0.0525\n",
      "Iteration: 2230, Train loss: -0.2489, rewards: 0.0225\n",
      "Iteration: 2240, Train loss: -0.3389, rewards: 0.0406\n",
      "Iteration: 2250, Train loss: -0.3896, rewards: 0.0431\n",
      "Iteration: 2260, Train loss: -0.3949, rewards: 0.0475\n",
      "Iteration: 2270, Train loss: -0.3329, rewards: 0.0450\n",
      "Iteration: 2280, Train loss: -0.3705, rewards: 0.0394\n",
      "Iteration: 2290, Train loss: -0.3527, rewards: 0.0494\n",
      "Iteration: 2300, Train loss: -0.2850, rewards: 0.0331\n",
      "Eval:\n",
      "Hits@1: 0.0213, Hits@3: 0.0576, Hits@10: 0.1266, MRR: 0.0530\n",
      "------------------------------------------------------------\n",
      "Iteration: 2310, Train loss: -0.3150, rewards: 0.0419\n",
      "Iteration: 2320, Train loss: -0.1952, rewards: 0.0231\n",
      "Iteration: 2330, Train loss: -0.2980, rewards: 0.0400\n",
      "Iteration: 2340, Train loss: -0.3237, rewards: 0.0462\n",
      "Iteration: 2350, Train loss: -0.2998, rewards: 0.0444\n",
      "Iteration: 2360, Train loss: -0.3799, rewards: 0.0425\n",
      "Iteration: 2370, Train loss: -0.2843, rewards: 0.0300\n",
      "Iteration: 2380, Train loss: -0.2672, rewards: 0.0406\n",
      "Iteration: 2390, Train loss: -0.3195, rewards: 0.0563\n",
      "Iteration: 2400, Train loss: -0.2473, rewards: 0.0362\n",
      "Eval:\n",
      "Hits@1: 0.0239, Hits@3: 0.0528, Hits@10: 0.1111, MRR: 0.0526\n",
      "------------------------------------------------------------\n",
      "Iteration: 2410, Train loss: -0.3279, rewards: 0.0413\n",
      "Iteration: 2420, Train loss: -0.3433, rewards: 0.0356\n",
      "Iteration: 2430, Train loss: -0.4436, rewards: 0.0581\n",
      "Iteration: 2440, Train loss: -0.3012, rewards: 0.0375\n",
      "Iteration: 2450, Train loss: -0.2368, rewards: 0.0362\n",
      "Iteration: 2460, Train loss: -0.2966, rewards: 0.0569\n",
      "Iteration: 2470, Train loss: -0.3087, rewards: 0.0525\n",
      "Iteration: 2480, Train loss: -0.2769, rewards: 0.0375\n",
      "Iteration: 2490, Train loss: -0.3451, rewards: 0.0425\n",
      "Iteration: 2500, Train loss: -0.3211, rewards: 0.1056\n",
      "Eval:\n",
      "Hits@1: 0.0176, Hits@3: 0.0483, Hits@10: 0.1022, MRR: 0.0423\n",
      "------------------------------------------------------------\n",
      "Iteration: 2510, Train loss: -0.2458, rewards: 0.0519\n",
      "Iteration: 2520, Train loss: -0.3577, rewards: 0.0544\n",
      "Iteration: 2530, Train loss: -0.2871, rewards: 0.0475\n",
      "Iteration: 2540, Train loss: -0.3324, rewards: 0.0375\n",
      "Iteration: 2550, Train loss: -0.2254, rewards: 0.0475\n",
      "Iteration: 2560, Train loss: -0.3337, rewards: 0.0431\n",
      "Iteration: 2570, Train loss: -0.4565, rewards: 0.0775\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 2580, Train loss: -0.3595, rewards: 0.0525\n",
      "Iteration: 2590, Train loss: -0.3447, rewards: 0.0550\n",
      "Iteration: 2600, Train loss: -0.4487, rewards: 0.0462\n",
      "Eval:\n",
      "Hits@1: 0.0215, Hits@3: 0.0503, Hits@10: 0.1165, MRR: 0.0511\n",
      "------------------------------------------------------------\n",
      "Iteration: 2610, Train loss: -0.3339, rewards: 0.0462\n",
      "Iteration: 2620, Train loss: -0.3453, rewards: 0.0475\n",
      "Iteration: 2630, Train loss: -0.2621, rewards: 0.0238\n",
      "Iteration: 2640, Train loss: -0.3656, rewards: 0.0475\n",
      "Iteration: 2650, Train loss: -0.3017, rewards: 0.0462\n",
      "Iteration: 2660, Train loss: -0.3460, rewards: 0.0675\n",
      "Iteration: 2670, Train loss: -0.3807, rewards: 0.0719\n",
      "Iteration: 2680, Train loss: -0.2650, rewards: 0.0425\n",
      "Iteration: 2690, Train loss: -0.2800, rewards: 0.0488\n",
      "Iteration: 2700, Train loss: -0.2438, rewards: 0.0306\n",
      "Eval:\n",
      "Hits@1: 0.0363, Hits@3: 0.0825, Hits@10: 0.1829, MRR: 0.0786\n",
      "------------------------------------------------------------\n",
      "Iteration: 2710, Train loss: -0.4688, rewards: 0.0688\n",
      "Iteration: 2720, Train loss: -0.3247, rewards: 0.0525\n",
      "Iteration: 2730, Train loss: -0.3269, rewards: 0.0663\n",
      "Iteration: 2740, Train loss: -0.3453, rewards: 0.0612\n",
      "Iteration: 2750, Train loss: -0.3118, rewards: 0.0419\n",
      "Iteration: 2760, Train loss: -0.2450, rewards: 0.0419\n",
      "Iteration: 2770, Train loss: -0.3818, rewards: 0.0575\n",
      "Iteration: 2780, Train loss: -0.2397, rewards: 0.0219\n",
      "Iteration: 2790, Train loss: -0.3813, rewards: 0.0612\n",
      "Iteration: 2800, Train loss: -0.4155, rewards: 0.0544\n",
      "Eval:\n",
      "Hits@1: 0.0288, Hits@3: 0.0685, Hits@10: 0.1401, MRR: 0.0644\n",
      "------------------------------------------------------------\n",
      "Iteration: 2810, Train loss: -0.3721, rewards: 0.0456\n",
      "Iteration: 2820, Train loss: -0.2603, rewards: 0.0194\n",
      "Iteration: 2830, Train loss: -0.2587, rewards: 0.0369\n",
      "Iteration: 2840, Train loss: -0.2926, rewards: 0.0400\n",
      "Iteration: 2850, Train loss: -0.2877, rewards: 0.0306\n",
      "Iteration: 2860, Train loss: -0.2937, rewards: 0.0331\n",
      "Iteration: 2870, Train loss: -0.4144, rewards: 0.0475\n",
      "Iteration: 2880, Train loss: -0.4453, rewards: 0.0625\n",
      "Iteration: 2890, Train loss: -0.3788, rewards: 0.0731\n",
      "Iteration: 2900, Train loss: -0.3439, rewards: 0.0681\n",
      "Eval:\n",
      "Hits@1: 0.0391, Hits@3: 0.0912, Hits@10: 0.1921, MRR: 0.0845\n",
      "------------------------------------------------------------\n",
      "Iteration: 2910, Train loss: -0.3301, rewards: 0.0387\n",
      "Iteration: 2920, Train loss: -0.3075, rewards: 0.0475\n",
      "Iteration: 2930, Train loss: -0.3949, rewards: 0.0525\n",
      "Iteration: 2940, Train loss: -0.2728, rewards: 0.0594\n",
      "Iteration: 2950, Train loss: -0.3092, rewards: 0.0312\n",
      "Iteration: 2960, Train loss: -0.2573, rewards: 0.0500\n",
      "Iteration: 2970, Train loss: -0.3318, rewards: 0.0544\n",
      "Iteration: 2980, Train loss: -0.2805, rewards: 0.0706\n",
      "Iteration: 2990, Train loss: -0.2512, rewards: 0.0288\n",
      "Iteration: 3000, Train loss: -0.4276, rewards: 0.0750\n",
      "Eval:\n",
      "Hits@1: 0.0260, Hits@3: 0.0623, Hits@10: 0.1262, MRR: 0.0583\n",
      "------------------------------------------------------------\n",
      "Iteration: 3010, Train loss: -0.3027, rewards: 0.0413\n",
      "Iteration: 3020, Train loss: -0.2925, rewards: 0.0581\n",
      "Iteration: 3030, Train loss: -0.2454, rewards: 0.0419\n",
      "Iteration: 3040, Train loss: -0.2261, rewards: 0.0469\n",
      "Iteration: 3050, Train loss: -0.2775, rewards: 0.0512\n",
      "Iteration: 3060, Train loss: -0.2485, rewards: 0.0437\n",
      "Iteration: 3070, Train loss: -0.3036, rewards: 0.0369\n",
      "Iteration: 3080, Train loss: -0.3146, rewards: 0.0706\n",
      "Iteration: 3090, Train loss: -0.2597, rewards: 0.0556\n",
      "Iteration: 3100, Train loss: -0.3613, rewards: 0.0581\n",
      "Eval:\n",
      "Hits@1: 0.0402, Hits@3: 0.0989, Hits@10: 0.1862, MRR: 0.0845\n",
      "------------------------------------------------------------\n",
      "Iteration: 3110, Train loss: -0.2807, rewards: 0.0638\n",
      "Iteration: 3120, Train loss: -0.4327, rewards: 0.0606\n",
      "Iteration: 3130, Train loss: -0.4267, rewards: 0.0762\n",
      "Iteration: 3140, Train loss: -0.4615, rewards: 0.0737\n",
      "Iteration: 3150, Train loss: -0.3544, rewards: 0.0512\n",
      "Iteration: 3160, Train loss: -0.4485, rewards: 0.0875\n",
      "Iteration: 3170, Train loss: -0.3148, rewards: 0.0606\n",
      "Iteration: 3180, Train loss: -0.3596, rewards: 0.0612\n",
      "Iteration: 3190, Train loss: -0.3847, rewards: 0.0425\n",
      "Iteration: 3200, Train loss: -0.2648, rewards: 0.0262\n",
      "Eval:\n",
      "Hits@1: 0.0358, Hits@3: 0.0846, Hits@10: 0.1631, MRR: 0.0744\n",
      "------------------------------------------------------------\n",
      "Iteration: 3210, Train loss: -0.4711, rewards: 0.0831\n",
      "Iteration: 3220, Train loss: -0.3707, rewards: 0.0587\n",
      "Iteration: 3230, Train loss: -0.3287, rewards: 0.0625\n",
      "Iteration: 3240, Train loss: -0.4824, rewards: 0.0750\n",
      "Iteration: 3250, Train loss: -0.6374, rewards: 0.0975\n",
      "Iteration: 3260, Train loss: -0.3068, rewards: 0.0506\n",
      "Iteration: 3270, Train loss: -0.3193, rewards: 0.0375\n",
      "Iteration: 3280, Train loss: -0.3966, rewards: 0.0850\n",
      "Iteration: 3290, Train loss: -0.3522, rewards: 0.0456\n",
      "Iteration: 3300, Train loss: -0.3269, rewards: 0.0369\n",
      "Eval:\n",
      "Hits@1: 0.0326, Hits@3: 0.0648, Hits@10: 0.1388, MRR: 0.0668\n",
      "------------------------------------------------------------\n",
      "Iteration: 3310, Train loss: -0.3143, rewards: 0.0475\n",
      "Iteration: 3320, Train loss: -0.4166, rewards: 0.0794\n",
      "Iteration: 3330, Train loss: -0.3944, rewards: 0.0638\n",
      "Iteration: 3340, Train loss: -0.3920, rewards: 0.0850\n",
      "Iteration: 3350, Train loss: -0.3827, rewards: 0.0663\n",
      "Iteration: 3360, Train loss: -0.1930, rewards: 0.0219\n",
      "Iteration: 3370, Train loss: -0.3348, rewards: 0.0731\n",
      "Iteration: 3380, Train loss: -0.5010, rewards: 0.0569\n",
      "Iteration: 3390, Train loss: -0.3530, rewards: 0.0456\n",
      "Iteration: 3400, Train loss: -0.2727, rewards: 0.0488\n",
      "Eval:\n",
      "Hits@1: 0.0336, Hits@3: 0.0685, Hits@10: 0.1342, MRR: 0.0649\n",
      "------------------------------------------------------------\n",
      "Iteration: 3410, Train loss: -0.2577, rewards: 0.0244\n",
      "Iteration: 3420, Train loss: -0.5736, rewards: 0.0850\n",
      "Iteration: 3430, Train loss: -0.2717, rewards: 0.0594\n",
      "Iteration: 3440, Train loss: -0.2608, rewards: 0.0394\n",
      "Iteration: 3450, Train loss: -0.4825, rewards: 0.0656\n",
      "Iteration: 3460, Train loss: -0.4040, rewards: 0.0681\n",
      "Iteration: 3470, Train loss: -0.3570, rewards: 0.0587\n",
      "Iteration: 3480, Train loss: -0.3426, rewards: 0.0838\n",
      "Iteration: 3490, Train loss: -0.2663, rewards: 0.0394\n",
      "Iteration: 3500, Train loss: -0.2924, rewards: 0.0413\n",
      "Eval:\n",
      "Hits@1: 0.0425, Hits@3: 0.0934, Hits@10: 0.1794, MRR: 0.0844\n",
      "------------------------------------------------------------\n",
      "Iteration: 3510, Train loss: -0.4104, rewards: 0.0644\n",
      "Iteration: 3520, Train loss: -0.3868, rewards: 0.0587\n",
      "Iteration: 3530, Train loss: -0.3787, rewards: 0.0606\n",
      "Iteration: 3540, Train loss: -0.3191, rewards: 0.0494\n",
      "Iteration: 3550, Train loss: -0.3961, rewards: 0.0544\n",
      "Iteration: 3560, Train loss: -0.4352, rewards: 0.0737\n",
      "Iteration: 3570, Train loss: -0.4647, rewards: 0.1006\n",
      "Iteration: 3580, Train loss: -0.4156, rewards: 0.0869\n",
      "Iteration: 3590, Train loss: -0.3978, rewards: 0.0788\n",
      "Iteration: 3600, Train loss: -0.3840, rewards: 0.0762\n",
      "Eval:\n",
      "Hits@1: 0.0345, Hits@3: 0.0907, Hits@10: 0.1891, MRR: 0.0801\n",
      "------------------------------------------------------------\n",
      "Iteration: 3610, Train loss: -0.2778, rewards: 0.0725\n",
      "Iteration: 3620, Train loss: -0.2635, rewards: 0.0612\n",
      "Iteration: 3630, Train loss: -0.3392, rewards: 0.0663\n",
      "Iteration: 3640, Train loss: -0.3176, rewards: 0.0569\n",
      "Iteration: 3650, Train loss: -0.3461, rewards: 0.0512\n",
      "Iteration: 3660, Train loss: -0.3002, rewards: 0.0631\n",
      "Iteration: 3670, Train loss: -0.3580, rewards: 0.0750\n",
      "Iteration: 3680, Train loss: -0.2128, rewards: 0.0331\n",
      "Iteration: 3690, Train loss: -0.2480, rewards: 0.0231\n",
      "Iteration: 3700, Train loss: -0.3288, rewards: 0.1106\n",
      "Eval:\n",
      "Hits@1: 0.0435, Hits@3: 0.1066, Hits@10: 0.2004, MRR: 0.0892\n",
      "------------------------------------------------------------\n",
      "Iteration: 3710, Train loss: -0.4078, rewards: 0.1081\n",
      "Iteration: 3720, Train loss: -0.2532, rewards: 0.0788\n",
      "Iteration: 3730, Train loss: -0.3482, rewards: 0.0550\n",
      "Iteration: 3740, Train loss: -0.4050, rewards: 0.0469\n",
      "Iteration: 3750, Train loss: -0.4201, rewards: 0.0781\n",
      "Iteration: 3760, Train loss: -0.3823, rewards: 0.0912\n",
      "Iteration: 3770, Train loss: -0.3258, rewards: 0.0437\n",
      "Iteration: 3780, Train loss: -0.4096, rewards: 0.0675\n",
      "Iteration: 3790, Train loss: -0.5112, rewards: 0.1450\n",
      "Iteration: 3800, Train loss: -0.2573, rewards: 0.0400\n",
      "Eval:\n",
      "Hits@1: 0.0450, Hits@3: 0.0940, Hits@10: 0.1703, MRR: 0.0832\n",
      "------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 3810, Train loss: -0.2758, rewards: 0.0394\n",
      "Iteration: 3820, Train loss: -0.3678, rewards: 0.0656\n",
      "Iteration: 3830, Train loss: -0.3826, rewards: 0.0744\n",
      "Iteration: 3840, Train loss: -0.4917, rewards: 0.0688\n",
      "Iteration: 3850, Train loss: -0.2001, rewards: 0.0275\n",
      "Iteration: 3860, Train loss: -0.3355, rewards: 0.0512\n",
      "Iteration: 3870, Train loss: -0.4075, rewards: 0.0906\n",
      "Iteration: 3880, Train loss: -0.2803, rewards: 0.0537\n",
      "Iteration: 3890, Train loss: -0.2219, rewards: 0.0550\n",
      "Iteration: 3900, Train loss: -0.2280, rewards: 0.0744\n",
      "Eval:\n",
      "Hits@1: 0.0616, Hits@3: 0.1119, Hits@10: 0.1893, MRR: 0.1002\n",
      "------------------------------------------------------------\n",
      "Iteration: 3910, Train loss: -0.2378, rewards: 0.0400\n",
      "Iteration: 3920, Train loss: -0.2824, rewards: 0.0688\n",
      "Iteration: 3930, Train loss: -0.3046, rewards: 0.0712\n",
      "Iteration: 3940, Train loss: -0.3246, rewards: 0.0700\n",
      "Iteration: 3950, Train loss: -0.4565, rewards: 0.0644\n",
      "Iteration: 3960, Train loss: -0.4210, rewards: 0.1062\n",
      "Iteration: 3970, Train loss: -0.3638, rewards: 0.0712\n",
      "Iteration: 3980, Train loss: -0.3375, rewards: 0.0406\n",
      "Iteration: 3990, Train loss: -0.3052, rewards: 0.0400\n",
      "Iteration: 4000, Train loss: -0.3617, rewards: 0.0494\n",
      "Eval:\n",
      "Hits@1: 0.0556, Hits@3: 0.1074, Hits@10: 0.1971, MRR: 0.0980\n",
      "------------------------------------------------------------\n",
      "Iteration: 4010, Train loss: -0.3156, rewards: 0.0425\n",
      "Iteration: 4020, Train loss: -0.3801, rewards: 0.0650\n",
      "Iteration: 4030, Train loss: -0.3504, rewards: 0.0619\n",
      "Iteration: 4040, Train loss: -0.3544, rewards: 0.0531\n",
      "Iteration: 4050, Train loss: -0.4926, rewards: 0.0744\n",
      "Iteration: 4060, Train loss: -0.4727, rewards: 0.0963\n",
      "Iteration: 4070, Train loss: -0.4498, rewards: 0.0731\n",
      "Iteration: 4080, Train loss: -0.4820, rewards: 0.1037\n",
      "Iteration: 4090, Train loss: -0.3753, rewards: 0.0419\n",
      "Iteration: 4100, Train loss: -0.4054, rewards: 0.0756\n",
      "Eval:\n",
      "Hits@1: 0.0527, Hits@3: 0.1110, Hits@10: 0.1957, MRR: 0.0960\n",
      "------------------------------------------------------------\n",
      "Iteration: 4110, Train loss: -0.4670, rewards: 0.0719\n",
      "Iteration: 4120, Train loss: -0.5409, rewards: 0.0800\n",
      "Iteration: 4130, Train loss: -0.3116, rewards: 0.0244\n",
      "Iteration: 4140, Train loss: -0.4680, rewards: 0.1175\n",
      "Iteration: 4150, Train loss: -0.3754, rewards: 0.0387\n",
      "Iteration: 4160, Train loss: -0.3693, rewards: 0.0494\n",
      "Iteration: 4170, Train loss: -0.4036, rewards: 0.0769\n",
      "Iteration: 4180, Train loss: -0.4125, rewards: 0.0919\n",
      "Iteration: 4190, Train loss: -0.4425, rewards: 0.0950\n",
      "Iteration: 4200, Train loss: -0.3631, rewards: 0.0938\n",
      "Eval:\n",
      "Hits@1: 0.0565, Hits@3: 0.1234, Hits@10: 0.2406, MRR: 0.1110\n",
      "------------------------------------------------------------\n",
      "Iteration: 4210, Train loss: -0.2985, rewards: 0.0681\n",
      "Iteration: 4220, Train loss: -0.4127, rewards: 0.0869\n",
      "Iteration: 4230, Train loss: -0.3208, rewards: 0.0663\n",
      "Iteration: 4240, Train loss: -0.4758, rewards: 0.0775\n",
      "Iteration: 4250, Train loss: -0.5223, rewards: 0.1069\n",
      "Iteration: 4260, Train loss: -0.4411, rewards: 0.0700\n",
      "Iteration: 4270, Train loss: -0.4758, rewards: 0.1031\n",
      "Iteration: 4280, Train loss: -0.2904, rewards: 0.0619\n",
      "Iteration: 4290, Train loss: -0.4658, rewards: 0.0894\n",
      "Iteration: 4300, Train loss: -0.3021, rewards: 0.0394\n",
      "Eval:\n",
      "Hits@1: 0.0586, Hits@3: 0.1167, Hits@10: 0.2147, MRR: 0.1054\n",
      "------------------------------------------------------------\n",
      "Iteration: 4310, Train loss: -0.3921, rewards: 0.0850\n",
      "Iteration: 4320, Train loss: -0.2799, rewards: 0.0656\n",
      "Iteration: 4330, Train loss: -0.3287, rewards: 0.0794\n",
      "Iteration: 4340, Train loss: -0.2825, rewards: 0.0875\n",
      "Iteration: 4350, Train loss: -0.2957, rewards: 0.0500\n",
      "Iteration: 4360, Train loss: -0.4713, rewards: 0.1031\n",
      "Iteration: 4370, Train loss: -0.5436, rewards: 0.0869\n",
      "Iteration: 4380, Train loss: -0.5768, rewards: 0.0825\n",
      "Iteration: 4390, Train loss: -0.3240, rewards: 0.0437\n",
      "Iteration: 4400, Train loss: -0.4258, rewards: 0.0750\n",
      "Eval:\n",
      "Hits@1: 0.0582, Hits@3: 0.1232, Hits@10: 0.2356, MRR: 0.1096\n",
      "------------------------------------------------------------\n",
      "Iteration: 4410, Train loss: -0.4408, rewards: 0.0994\n",
      "Iteration: 4420, Train loss: -0.4547, rewards: 0.1175\n",
      "Iteration: 4430, Train loss: -0.4473, rewards: 0.0794\n",
      "Iteration: 4440, Train loss: -0.4655, rewards: 0.1075\n",
      "Iteration: 4450, Train loss: -0.2996, rewards: 0.0481\n",
      "Iteration: 4460, Train loss: -0.4010, rewards: 0.0925\n",
      "Iteration: 4470, Train loss: -0.3681, rewards: 0.0881\n",
      "Iteration: 4480, Train loss: -0.5139, rewards: 0.1169\n",
      "Iteration: 4490, Train loss: -0.3975, rewards: 0.0563\n",
      "Iteration: 4500, Train loss: -0.3849, rewards: 0.1175\n",
      "Eval:\n",
      "Hits@1: 0.0447, Hits@3: 0.1107, Hits@10: 0.2068, MRR: 0.0924\n",
      "------------------------------------------------------------\n",
      "Iteration: 4510, Train loss: -0.2876, rewards: 0.0425\n",
      "Iteration: 4520, Train loss: -0.4693, rewards: 0.0881\n",
      "Iteration: 4530, Train loss: -0.4014, rewards: 0.0838\n",
      "Iteration: 4540, Train loss: -0.4582, rewards: 0.0756\n",
      "Iteration: 4550, Train loss: -0.3973, rewards: 0.0931\n",
      "Iteration: 4560, Train loss: -0.4170, rewards: 0.0712\n",
      "Iteration: 4570, Train loss: -0.3400, rewards: 0.0594\n",
      "Iteration: 4580, Train loss: -0.4376, rewards: 0.0731\n",
      "Iteration: 4590, Train loss: -0.4401, rewards: 0.0675\n",
      "Iteration: 4600, Train loss: -0.3667, rewards: 0.0925\n",
      "Eval:\n",
      "Hits@1: 0.0742, Hits@3: 0.1422, Hits@10: 0.2339, MRR: 0.1238\n",
      "------------------------------------------------------------\n",
      "Iteration: 4610, Train loss: -0.4959, rewards: 0.0963\n",
      "Iteration: 4620, Train loss: -0.4516, rewards: 0.1056\n",
      "Iteration: 4630, Train loss: -0.3101, rewards: 0.0625\n",
      "Iteration: 4640, Train loss: -0.3112, rewards: 0.0900\n",
      "Iteration: 4650, Train loss: -0.3105, rewards: 0.0575\n",
      "Iteration: 4660, Train loss: -0.3586, rewards: 0.0737\n",
      "Iteration: 4670, Train loss: -0.3182, rewards: 0.0419\n",
      "Iteration: 4680, Train loss: -0.3548, rewards: 0.0862\n",
      "Iteration: 4690, Train loss: -0.3162, rewards: 0.0606\n",
      "Iteration: 4700, Train loss: -0.5729, rewards: 0.1494\n",
      "Eval:\n",
      "Hits@1: 0.0623, Hits@3: 0.1337, Hits@10: 0.2182, MRR: 0.1111\n",
      "------------------------------------------------------------\n",
      "Iteration: 4710, Train loss: -0.3333, rewards: 0.0813\n",
      "Iteration: 4720, Train loss: -0.4049, rewards: 0.0675\n",
      "Iteration: 4730, Train loss: -0.3692, rewards: 0.0606\n",
      "Iteration: 4740, Train loss: -0.3267, rewards: 0.0600\n",
      "Iteration: 4750, Train loss: -0.3833, rewards: 0.0700\n",
      "Iteration: 4760, Train loss: -0.3093, rewards: 0.0900\n",
      "Iteration: 4770, Train loss: -0.4590, rewards: 0.1000\n",
      "Iteration: 4780, Train loss: -0.5777, rewards: 0.1106\n",
      "Iteration: 4790, Train loss: -0.3473, rewards: 0.0800\n",
      "Iteration: 4800, Train loss: -0.2515, rewards: 0.0488\n",
      "Eval:\n",
      "Hits@1: 0.0474, Hits@3: 0.0933, Hits@10: 0.1648, MRR: 0.0835\n",
      "------------------------------------------------------------\n",
      "Iteration: 4810, Train loss: -0.2202, rewards: 0.0712\n",
      "Iteration: 4820, Train loss: -0.2912, rewards: 0.0638\n",
      "Iteration: 4830, Train loss: -0.2918, rewards: 0.0419\n",
      "Iteration: 4840, Train loss: -0.3537, rewards: 0.0481\n",
      "Iteration: 4850, Train loss: -0.3150, rewards: 0.0350\n",
      "Iteration: 4860, Train loss: -0.3542, rewards: 0.0619\n",
      "Iteration: 4870, Train loss: -0.3721, rewards: 0.0650\n",
      "Iteration: 4880, Train loss: -0.3934, rewards: 0.0769\n",
      "Iteration: 4890, Train loss: -0.3730, rewards: 0.0587\n",
      "Iteration: 4900, Train loss: -0.4027, rewards: 0.0862\n",
      "Eval:\n",
      "Hits@1: 0.0556, Hits@3: 0.1034, Hits@10: 0.1887, MRR: 0.0959\n",
      "------------------------------------------------------------\n",
      "Iteration: 4910, Train loss: -0.3462, rewards: 0.0563\n",
      "Iteration: 4920, Train loss: -0.3872, rewards: 0.0756\n",
      "Iteration: 4930, Train loss: -0.4311, rewards: 0.0900\n",
      "Iteration: 4940, Train loss: -0.3950, rewards: 0.0919\n",
      "Iteration: 4950, Train loss: -0.3942, rewards: 0.0844\n",
      "Iteration: 4960, Train loss: -0.4438, rewards: 0.0944\n",
      "Iteration: 4970, Train loss: -0.3226, rewards: 0.0612\n",
      "Iteration: 4980, Train loss: -0.4820, rewards: 0.0919\n",
      "Iteration: 4990, Train loss: -0.4527, rewards: 0.0737\n",
      "Iteration: 5000, Train loss: -0.3536, rewards: 0.0612\n",
      "Eval:\n",
      "Hits@1: 0.0567, Hits@3: 0.1100, Hits@10: 0.1955, MRR: 0.1008\n",
      "------------------------------------------------------------\n",
      "Iteration: 5010, Train loss: -0.3067, rewards: 0.0719\n",
      "Iteration: 5020, Train loss: -0.3443, rewards: 0.0663\n",
      "Iteration: 5030, Train loss: -0.3564, rewards: 0.0587\n",
      "Iteration: 5040, Train loss: -0.3251, rewards: 0.0769\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 5050, Train loss: -0.3794, rewards: 0.0663\n",
      "Iteration: 5060, Train loss: -0.3440, rewards: 0.0769\n",
      "Iteration: 5070, Train loss: -0.4044, rewards: 0.1025\n",
      "Iteration: 5080, Train loss: -0.2735, rewards: 0.0319\n",
      "Iteration: 5090, Train loss: -0.4322, rewards: 0.0688\n",
      "Iteration: 5100, Train loss: -0.3401, rewards: 0.0563\n",
      "Eval:\n",
      "Hits@1: 0.0761, Hits@3: 0.1547, Hits@10: 0.2576, MRR: 0.1338\n",
      "------------------------------------------------------------\n",
      "Iteration: 5110, Train loss: -0.4582, rewards: 0.1306\n",
      "Iteration: 5120, Train loss: -0.3959, rewards: 0.0819\n",
      "Iteration: 5130, Train loss: -0.5937, rewards: 0.0963\n",
      "Iteration: 5140, Train loss: -0.4884, rewards: 0.1006\n",
      "Iteration: 5150, Train loss: -0.4860, rewards: 0.0794\n",
      "Iteration: 5160, Train loss: -0.4957, rewards: 0.1056\n",
      "Iteration: 5170, Train loss: -0.5132, rewards: 0.1050\n",
      "Iteration: 5180, Train loss: -0.5115, rewards: 0.1113\n",
      "Iteration: 5190, Train loss: -0.3158, rewards: 0.0456\n",
      "Iteration: 5200, Train loss: -0.3367, rewards: 0.0481\n",
      "Eval:\n",
      "Hits@1: 0.0551, Hits@3: 0.1268, Hits@10: 0.2105, MRR: 0.1074\n",
      "------------------------------------------------------------\n",
      "Iteration: 5210, Train loss: -0.3999, rewards: 0.0900\n",
      "Iteration: 5220, Train loss: -0.5173, rewards: 0.0956\n",
      "Iteration: 5230, Train loss: -0.4591, rewards: 0.1225\n",
      "Iteration: 5240, Train loss: -0.3759, rewards: 0.0688\n",
      "Iteration: 5250, Train loss: -0.5434, rewards: 0.1069\n",
      "Iteration: 5260, Train loss: -0.3755, rewards: 0.0813\n",
      "Iteration: 5270, Train loss: -0.4105, rewards: 0.0981\n",
      "Iteration: 5280, Train loss: -0.4291, rewards: 0.0819\n",
      "Iteration: 5290, Train loss: -0.5246, rewards: 0.1119\n",
      "Iteration: 5300, Train loss: -0.3188, rewards: 0.0794\n",
      "Eval:\n",
      "Hits@1: 0.0883, Hits@3: 0.1596, Hits@10: 0.2461, MRR: 0.1387\n",
      "------------------------------------------------------------\n",
      "Iteration: 5310, Train loss: -0.4161, rewards: 0.1281\n",
      "Iteration: 5320, Train loss: -0.5996, rewards: 0.1275\n",
      "Iteration: 5330, Train loss: -0.5042, rewards: 0.0844\n",
      "Iteration: 5340, Train loss: -0.3568, rewards: 0.1006\n",
      "Iteration: 5350, Train loss: -0.3354, rewards: 0.0800\n",
      "Iteration: 5360, Train loss: -0.4136, rewards: 0.0813\n",
      "Iteration: 5370, Train loss: -0.3318, rewards: 0.0750\n",
      "Iteration: 5380, Train loss: -0.5035, rewards: 0.1175\n",
      "Iteration: 5390, Train loss: -0.2993, rewards: 0.0556\n",
      "Iteration: 5400, Train loss: -0.1883, rewards: 0.0850\n",
      "Eval:\n",
      "Hits@1: 0.0467, Hits@3: 0.1044, Hits@10: 0.2175, MRR: 0.0948\n",
      "------------------------------------------------------------\n",
      "Iteration: 5410, Train loss: -0.2430, rewards: 0.0788\n",
      "Iteration: 5420, Train loss: -0.3266, rewards: 0.1019\n",
      "Iteration: 5430, Train loss: -0.3483, rewards: 0.1281\n",
      "Iteration: 5440, Train loss: -0.3219, rewards: 0.1013\n",
      "Iteration: 5450, Train loss: -0.2788, rewards: 0.0675\n",
      "Iteration: 5460, Train loss: -0.3923, rewards: 0.0919\n",
      "Iteration: 5470, Train loss: -0.3266, rewards: 0.0612\n",
      "Iteration: 5480, Train loss: -0.3364, rewards: 0.0944\n",
      "Iteration: 5490, Train loss: -0.4415, rewards: 0.1100\n",
      "Iteration: 5500, Train loss: -0.4592, rewards: 0.1025\n",
      "Eval:\n",
      "Hits@1: 0.1120, Hits@3: 0.2012, Hits@10: 0.3110, MRR: 0.1752\n",
      "------------------------------------------------------------\n",
      "Iteration: 5510, Train loss: -0.3818, rewards: 0.0900\n",
      "Iteration: 5520, Train loss: -0.5061, rewards: 0.1462\n",
      "Iteration: 5530, Train loss: -0.3618, rewards: 0.0931\n",
      "Iteration: 5540, Train loss: -0.3838, rewards: 0.0681\n",
      "Iteration: 5550, Train loss: -0.3900, rewards: 0.0719\n",
      "Iteration: 5560, Train loss: -0.4886, rewards: 0.1156\n",
      "Iteration: 5570, Train loss: -0.4819, rewards: 0.1075\n",
      "Iteration: 5580, Train loss: -0.5058, rewards: 0.1150\n",
      "Iteration: 5590, Train loss: -0.5585, rewards: 0.1031\n",
      "Iteration: 5600, Train loss: -0.5770, rewards: 0.0919\n",
      "Eval:\n",
      "Hits@1: 0.0935, Hits@3: 0.1542, Hits@10: 0.2476, MRR: 0.1419\n",
      "------------------------------------------------------------\n",
      "Iteration: 5610, Train loss: -0.4750, rewards: 0.1019\n",
      "Iteration: 5620, Train loss: -0.6467, rewards: 0.1681\n",
      "Iteration: 5630, Train loss: -0.5350, rewards: 0.1406\n",
      "Iteration: 5640, Train loss: -0.3014, rewards: 0.0938\n",
      "Iteration: 5650, Train loss: -0.4074, rewards: 0.1013\n",
      "Iteration: 5660, Train loss: -0.5041, rewards: 0.1400\n",
      "Iteration: 5670, Train loss: -0.4182, rewards: 0.0919\n",
      "Iteration: 5680, Train loss: -0.3479, rewards: 0.0744\n",
      "Iteration: 5690, Train loss: -0.4400, rewards: 0.0881\n",
      "Iteration: 5700, Train loss: -0.3509, rewards: 0.0750\n",
      "Eval:\n",
      "Hits@1: 0.1017, Hits@3: 0.1933, Hits@10: 0.3106, MRR: 0.1674\n",
      "------------------------------------------------------------\n",
      "Iteration: 5710, Train loss: -0.4905, rewards: 0.1663\n",
      "Iteration: 5720, Train loss: -0.5288, rewards: 0.0938\n",
      "Iteration: 5730, Train loss: -0.4986, rewards: 0.0712\n",
      "Iteration: 5740, Train loss: -0.3347, rewards: 0.0631\n",
      "Iteration: 5750, Train loss: -0.3180, rewards: 0.0519\n",
      "Iteration: 5760, Train loss: -0.4795, rewards: 0.1325\n",
      "Iteration: 5770, Train loss: -0.4259, rewards: 0.0963\n",
      "Iteration: 5780, Train loss: -0.4225, rewards: 0.0706\n",
      "Iteration: 5790, Train loss: -0.4293, rewards: 0.1163\n",
      "Iteration: 5800, Train loss: -0.3933, rewards: 0.0894\n",
      "Eval:\n",
      "Hits@1: 0.0936, Hits@3: 0.1680, Hits@10: 0.2677, MRR: 0.1496\n",
      "------------------------------------------------------------\n",
      "Iteration: 5810, Train loss: -0.4784, rewards: 0.1369\n",
      "Iteration: 5820, Train loss: -0.3749, rewards: 0.0963\n",
      "Iteration: 5830, Train loss: -0.4709, rewards: 0.0862\n",
      "Iteration: 5840, Train loss: -0.5611, rewards: 0.1369\n",
      "Iteration: 5850, Train loss: -0.5436, rewards: 0.1837\n",
      "Iteration: 5860, Train loss: -0.4249, rewards: 0.1094\n",
      "Iteration: 5870, Train loss: -0.5375, rewards: 0.0956\n",
      "Iteration: 5880, Train loss: -0.4695, rewards: 0.1388\n",
      "Iteration: 5890, Train loss: -0.5440, rewards: 0.1100\n",
      "Iteration: 5900, Train loss: -0.4584, rewards: 0.1231\n",
      "Eval:\n",
      "Hits@1: 0.0925, Hits@3: 0.1754, Hits@10: 0.2670, MRR: 0.1501\n",
      "------------------------------------------------------------\n",
      "Iteration: 5910, Train loss: -0.5279, rewards: 0.1200\n",
      "Iteration: 5920, Train loss: -0.4889, rewards: 0.1100\n",
      "Iteration: 5930, Train loss: -0.5968, rewards: 0.1638\n",
      "Iteration: 5940, Train loss: -0.3115, rewards: 0.0688\n",
      "Iteration: 5950, Train loss: -0.3444, rewards: 0.0706\n",
      "Iteration: 5960, Train loss: -0.4659, rewards: 0.1212\n",
      "Iteration: 5970, Train loss: -0.4463, rewards: 0.1338\n",
      "Iteration: 5980, Train loss: -0.5306, rewards: 0.1113\n",
      "Iteration: 5990, Train loss: -0.4387, rewards: 0.1269\n",
      "Iteration: 6000, Train loss: -0.4824, rewards: 0.1000\n",
      "Eval:\n",
      "Hits@1: 0.1188, Hits@3: 0.2183, Hits@10: 0.3361, MRR: 0.1888\n",
      "------------------------------------------------------------\n",
      "Iteration: 6010, Train loss: -0.5189, rewards: 0.1369\n",
      "Iteration: 6020, Train loss: -0.3870, rewards: 0.0869\n",
      "Iteration: 6030, Train loss: -0.4629, rewards: 0.1094\n",
      "Iteration: 6040, Train loss: -0.4949, rewards: 0.1237\n",
      "Iteration: 6050, Train loss: -0.4882, rewards: 0.1363\n",
      "Iteration: 6060, Train loss: -0.4014, rewards: 0.0819\n",
      "Iteration: 6070, Train loss: -0.3813, rewards: 0.1056\n",
      "Iteration: 6080, Train loss: -0.3962, rewards: 0.0719\n",
      "Iteration: 6090, Train loss: -0.5721, rewards: 0.1094\n",
      "Iteration: 6100, Train loss: -0.5079, rewards: 0.1050\n",
      "Eval:\n",
      "Hits@1: 0.1121, Hits@3: 0.2004, Hits@10: 0.3041, MRR: 0.1741\n",
      "------------------------------------------------------------\n",
      "Iteration: 6110, Train loss: -0.6082, rewards: 0.1412\n",
      "Iteration: 6120, Train loss: -0.6201, rewards: 0.1081\n",
      "Iteration: 6130, Train loss: -0.7247, rewards: 0.1794\n",
      "Iteration: 6140, Train loss: -0.3889, rewards: 0.1263\n",
      "Iteration: 6150, Train loss: -0.3574, rewards: 0.0925\n",
      "Iteration: 6160, Train loss: -0.3727, rewards: 0.1394\n",
      "Iteration: 6170, Train loss: -0.5210, rewards: 0.1481\n",
      "Iteration: 6180, Train loss: -0.3743, rewards: 0.0994\n",
      "Iteration: 6190, Train loss: -0.4724, rewards: 0.1194\n",
      "Iteration: 6200, Train loss: -0.6155, rewards: 0.1356\n",
      "Eval:\n",
      "Hits@1: 0.1036, Hits@3: 0.1951, Hits@10: 0.3197, MRR: 0.1708\n",
      "------------------------------------------------------------\n",
      "Iteration: 6210, Train loss: -0.4509, rewards: 0.1037\n",
      "Iteration: 6220, Train loss: -0.3832, rewards: 0.0700\n",
      "Iteration: 6230, Train loss: -0.3282, rewards: 0.0650\n",
      "Iteration: 6240, Train loss: -0.3253, rewards: 0.0594\n",
      "Iteration: 6250, Train loss: -0.3221, rewards: 0.0475\n",
      "Iteration: 6260, Train loss: -0.4017, rewards: 0.0694\n",
      "Iteration: 6270, Train loss: -0.5158, rewards: 0.1431\n",
      "Iteration: 6280, Train loss: -0.4538, rewards: 0.0881\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 6290, Train loss: -0.4916, rewards: 0.1031\n",
      "Iteration: 6300, Train loss: -0.4750, rewards: 0.1056\n",
      "Eval:\n",
      "Hits@1: 0.0904, Hits@3: 0.1715, Hits@10: 0.2692, MRR: 0.1492\n",
      "------------------------------------------------------------\n",
      "Iteration: 6310, Train loss: -0.4072, rewards: 0.0569\n",
      "Iteration: 6320, Train loss: -0.4414, rewards: 0.0875\n",
      "Iteration: 6330, Train loss: -0.6210, rewards: 0.1294\n",
      "Iteration: 6340, Train loss: -0.5284, rewards: 0.1269\n",
      "Iteration: 6350, Train loss: -0.4664, rewards: 0.0750\n",
      "Iteration: 6360, Train loss: -0.5348, rewards: 0.0887\n",
      "Iteration: 6370, Train loss: -0.5371, rewards: 0.1019\n",
      "Iteration: 6380, Train loss: -0.7254, rewards: 0.1812\n",
      "Iteration: 6390, Train loss: -0.3619, rewards: 0.0956\n",
      "Iteration: 6400, Train loss: -0.4806, rewards: 0.1462\n",
      "Eval:\n",
      "Hits@1: 0.1208, Hits@3: 0.2284, Hits@10: 0.3438, MRR: 0.1943\n",
      "------------------------------------------------------------\n",
      "Iteration: 6410, Train loss: -0.5762, rewards: 0.1181\n",
      "Iteration: 6420, Train loss: -0.5777, rewards: 0.0869\n",
      "Iteration: 6430, Train loss: -0.6235, rewards: 0.1606\n",
      "Iteration: 6440, Train loss: -0.5738, rewards: 0.1013\n",
      "Iteration: 6450, Train loss: -0.6346, rewards: 0.1313\n",
      "Iteration: 6460, Train loss: -0.6015, rewards: 0.1231\n",
      "Iteration: 6470, Train loss: -0.4911, rewards: 0.1500\n",
      "Iteration: 6480, Train loss: -0.6246, rewards: 0.1013\n",
      "Iteration: 6490, Train loss: -0.4875, rewards: 0.1037\n",
      "Iteration: 6500, Train loss: -0.5857, rewards: 0.0856\n",
      "Eval:\n",
      "Hits@1: 0.0899, Hits@3: 0.1596, Hits@10: 0.2564, MRR: 0.1423\n",
      "------------------------------------------------------------\n",
      "Iteration: 6510, Train loss: -0.5236, rewards: 0.0881\n",
      "Iteration: 6520, Train loss: -0.4671, rewards: 0.0675\n",
      "Iteration: 6530, Train loss: -0.4977, rewards: 0.1256\n",
      "Iteration: 6540, Train loss: -0.4133, rewards: 0.1013\n",
      "Iteration: 6550, Train loss: -0.3863, rewards: 0.0794\n",
      "Iteration: 6560, Train loss: -0.3942, rewards: 0.0831\n",
      "Iteration: 6570, Train loss: -0.4700, rewards: 0.0988\n",
      "Iteration: 6580, Train loss: -0.4996, rewards: 0.1000\n",
      "Iteration: 6590, Train loss: -0.3579, rewards: 0.1206\n",
      "Iteration: 6600, Train loss: -0.3597, rewards: 0.0850\n",
      "Eval:\n",
      "Hits@1: 0.1324, Hits@3: 0.2394, Hits@10: 0.3623, MRR: 0.2068\n",
      "------------------------------------------------------------\n",
      "Iteration: 6610, Train loss: -0.4209, rewards: 0.1113\n",
      "Iteration: 6620, Train loss: -0.4285, rewards: 0.1138\n",
      "Iteration: 6630, Train loss: -0.5465, rewards: 0.0856\n",
      "Iteration: 6640, Train loss: -0.4740, rewards: 0.1006\n",
      "Iteration: 6650, Train loss: -0.5026, rewards: 0.1338\n",
      "Iteration: 6660, Train loss: -0.4955, rewards: 0.1037\n",
      "Iteration: 6670, Train loss: -0.5696, rewards: 0.1450\n",
      "Iteration: 6680, Train loss: -0.4868, rewards: 0.1319\n",
      "Iteration: 6690, Train loss: -0.4518, rewards: 0.0981\n",
      "Iteration: 6700, Train loss: -0.5927, rewards: 0.1138\n",
      "Eval:\n",
      "Hits@1: 0.1178, Hits@3: 0.2005, Hits@10: 0.3187, MRR: 0.1800\n",
      "------------------------------------------------------------\n",
      "Iteration: 6710, Train loss: -0.5558, rewards: 0.1300\n",
      "Iteration: 6720, Train loss: -0.5273, rewards: 0.1606\n",
      "Iteration: 6730, Train loss: -0.3557, rewards: 0.0619\n",
      "Iteration: 6740, Train loss: -0.3865, rewards: 0.0969\n",
      "Iteration: 6750, Train loss: -0.3891, rewards: 0.0788\n",
      "Iteration: 6760, Train loss: -0.6408, rewards: 0.1375\n",
      "Iteration: 6770, Train loss: -0.5498, rewards: 0.0975\n",
      "Iteration: 6780, Train loss: -0.4259, rewards: 0.1131\n",
      "Iteration: 6790, Train loss: -0.6064, rewards: 0.1550\n",
      "Iteration: 6800, Train loss: -0.4591, rewards: 0.1069\n",
      "Eval:\n",
      "Hits@1: 0.1190, Hits@3: 0.1984, Hits@10: 0.3098, MRR: 0.1783\n",
      "------------------------------------------------------------\n",
      "Iteration: 6810, Train loss: -0.4319, rewards: 0.1087\n",
      "Iteration: 6820, Train loss: -0.4288, rewards: 0.0944\n",
      "Iteration: 6830, Train loss: -0.4845, rewards: 0.0900\n",
      "Iteration: 6840, Train loss: -0.6513, rewards: 0.1581\n",
      "Iteration: 6850, Train loss: -0.5382, rewards: 0.1113\n",
      "Iteration: 6860, Train loss: -0.4397, rewards: 0.1094\n",
      "Iteration: 6870, Train loss: -0.6244, rewards: 0.1444\n",
      "Iteration: 6880, Train loss: -0.5305, rewards: 0.1156\n",
      "Iteration: 6890, Train loss: -0.5668, rewards: 0.1075\n",
      "Iteration: 6900, Train loss: -0.3858, rewards: 0.0656\n",
      "Eval:\n",
      "Hits@1: 0.1375, Hits@3: 0.2417, Hits@10: 0.3690, MRR: 0.2103\n",
      "------------------------------------------------------------\n",
      "Iteration: 6910, Train loss: -0.5507, rewards: 0.1212\n",
      "Iteration: 6920, Train loss: -0.3232, rewards: 0.0737\n",
      "Iteration: 6930, Train loss: -0.4151, rewards: 0.1056\n",
      "Iteration: 6940, Train loss: -0.5887, rewards: 0.1406\n",
      "Iteration: 6950, Train loss: -0.6163, rewards: 0.1581\n",
      "Iteration: 6960, Train loss: -0.4724, rewards: 0.0744\n",
      "Iteration: 6970, Train loss: -0.5010, rewards: 0.1019\n",
      "Iteration: 6980, Train loss: -0.5071, rewards: 0.1250\n",
      "Iteration: 6990, Train loss: -0.4577, rewards: 0.1631\n",
      "Iteration: 7000, Train loss: -0.4260, rewards: 0.1250\n",
      "Eval:\n",
      "Hits@1: 0.1351, Hits@3: 0.2303, Hits@10: 0.3479, MRR: 0.2034\n",
      "------------------------------------------------------------\n",
      "Iteration: 7010, Train loss: -0.5808, rewards: 0.1544\n",
      "Iteration: 7020, Train loss: -0.4061, rewards: 0.0825\n",
      "Iteration: 7030, Train loss: -0.5905, rewards: 0.1163\n",
      "Iteration: 7040, Train loss: -0.4908, rewards: 0.1150\n",
      "Iteration: 7050, Train loss: -0.6232, rewards: 0.1425\n",
      "Iteration: 7060, Train loss: -0.6493, rewards: 0.1638\n",
      "Iteration: 7070, Train loss: -0.5618, rewards: 0.1150\n",
      "Iteration: 7080, Train loss: -0.6320, rewards: 0.1206\n",
      "Iteration: 7090, Train loss: -0.4727, rewards: 0.1125\n",
      "Iteration: 7100, Train loss: -0.4422, rewards: 0.0825\n",
      "Eval:\n",
      "Hits@1: 0.1227, Hits@3: 0.2155, Hits@10: 0.3310, MRR: 0.1895\n",
      "------------------------------------------------------------\n",
      "Iteration: 7110, Train loss: -0.3872, rewards: 0.1250\n",
      "Iteration: 7120, Train loss: -0.5121, rewards: 0.1144\n",
      "Iteration: 7130, Train loss: -0.6164, rewards: 0.1437\n",
      "Iteration: 7140, Train loss: -0.6293, rewards: 0.1394\n",
      "Iteration: 7150, Train loss: -0.5529, rewards: 0.0994\n",
      "Iteration: 7160, Train loss: -0.4157, rewards: 0.0606\n",
      "Iteration: 7170, Train loss: -0.6351, rewards: 0.1156\n",
      "Iteration: 7180, Train loss: -0.4061, rewards: 0.0775\n",
      "Iteration: 7190, Train loss: -0.6826, rewards: 0.1106\n",
      "Iteration: 7200, Train loss: -0.5790, rewards: 0.0550\n",
      "Eval:\n",
      "Hits@1: 0.0788, Hits@3: 0.1222, Hits@10: 0.1982, MRR: 0.1171\n",
      "------------------------------------------------------------\n",
      "Iteration: 7210, Train loss: -0.5765, rewards: 0.0850\n",
      "Iteration: 7220, Train loss: -0.5717, rewards: 0.0644\n",
      "Iteration: 7230, Train loss: -0.5391, rewards: 0.0969\n",
      "Iteration: 7240, Train loss: -0.7099, rewards: 0.1113\n",
      "Iteration: 7250, Train loss: -0.5004, rewards: 0.0781\n",
      "Iteration: 7260, Train loss: -0.5428, rewards: 0.1375\n",
      "Iteration: 7270, Train loss: -0.6825, rewards: 0.1494\n",
      "Iteration: 7280, Train loss: -0.5966, rewards: 0.1319\n",
      "Iteration: 7290, Train loss: -0.6340, rewards: 0.1212\n",
      "Iteration: 7300, Train loss: -0.5568, rewards: 0.1206\n",
      "Eval:\n",
      "Hits@1: 0.1428, Hits@3: 0.2315, Hits@10: 0.3480, MRR: 0.2079\n",
      "------------------------------------------------------------\n",
      "Iteration: 7310, Train loss: -0.5208, rewards: 0.1256\n",
      "Iteration: 7320, Train loss: -0.5364, rewards: 0.1356\n",
      "Iteration: 7330, Train loss: -0.4073, rewards: 0.1075\n",
      "Iteration: 7340, Train loss: -0.7002, rewards: 0.1500\n",
      "Iteration: 7350, Train loss: -0.7023, rewards: 0.1237\n",
      "Iteration: 7360, Train loss: -0.5317, rewards: 0.0862\n",
      "Iteration: 7370, Train loss: -0.7140, rewards: 0.1163\n",
      "Iteration: 7380, Train loss: -0.5234, rewards: 0.1288\n",
      "Iteration: 7390, Train loss: -0.4697, rewards: 0.0988\n",
      "Iteration: 7400, Train loss: -0.6271, rewards: 0.1150\n",
      "Eval:\n",
      "Hits@1: 0.1426, Hits@3: 0.2307, Hits@10: 0.3519, MRR: 0.2079\n",
      "------------------------------------------------------------\n",
      "Iteration: 7410, Train loss: -0.4535, rewards: 0.0938\n",
      "Iteration: 7420, Train loss: -0.4695, rewards: 0.0894\n",
      "Iteration: 7430, Train loss: -0.4511, rewards: 0.1150\n",
      "Iteration: 7440, Train loss: -0.5833, rewards: 0.1269\n",
      "Iteration: 7450, Train loss: -0.5554, rewards: 0.1363\n",
      "Iteration: 7460, Train loss: -0.5987, rewards: 0.1762\n",
      "Iteration: 7470, Train loss: -0.5723, rewards: 0.1031\n",
      "Iteration: 7480, Train loss: -0.5057, rewards: 0.0963\n",
      "Iteration: 7490, Train loss: -0.4378, rewards: 0.1475\n",
      "Iteration: 7500, Train loss: -0.5561, rewards: 0.1313\n",
      "Eval:\n",
      "Hits@1: 0.1367, Hits@3: 0.2247, Hits@10: 0.3363, MRR: 0.2005\n",
      "------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 7510, Train loss: -0.5177, rewards: 0.1094\n",
      "Iteration: 7520, Train loss: -0.5469, rewards: 0.1256\n",
      "Iteration: 7530, Train loss: -0.6034, rewards: 0.1281\n",
      "Iteration: 7540, Train loss: -0.5161, rewards: 0.1119\n",
      "Iteration: 7550, Train loss: -0.6517, rewards: 0.1837\n",
      "Iteration: 7560, Train loss: -0.5414, rewards: 0.1094\n",
      "Iteration: 7570, Train loss: -0.5444, rewards: 0.1194\n",
      "Iteration: 7580, Train loss: -0.6975, rewards: 0.1394\n",
      "Iteration: 7590, Train loss: -0.3859, rewards: 0.0744\n",
      "Iteration: 7600, Train loss: -0.4781, rewards: 0.0900\n",
      "Eval:\n",
      "Hits@1: 0.1500, Hits@3: 0.2379, Hits@10: 0.3466, MRR: 0.2134\n",
      "------------------------------------------------------------\n",
      "Iteration: 7610, Train loss: -0.4954, rewards: 0.1256\n",
      "Iteration: 7620, Train loss: -0.4751, rewards: 0.0663\n",
      "Iteration: 7630, Train loss: -0.6167, rewards: 0.1444\n",
      "Iteration: 7640, Train loss: -0.5338, rewards: 0.0838\n",
      "Iteration: 7650, Train loss: -0.6978, rewards: 0.1237\n",
      "Iteration: 7660, Train loss: -0.4819, rewards: 0.1163\n",
      "Iteration: 7670, Train loss: -0.5324, rewards: 0.1500\n",
      "Iteration: 7680, Train loss: -0.5097, rewards: 0.1306\n",
      "Iteration: 7690, Train loss: -0.5489, rewards: 0.1469\n",
      "Iteration: 7700, Train loss: -0.5490, rewards: 0.1094\n",
      "Eval:\n",
      "Hits@1: 0.1595, Hits@3: 0.2533, Hits@10: 0.3681, MRR: 0.2261\n",
      "------------------------------------------------------------\n",
      "Iteration: 7710, Train loss: -0.4687, rewards: 0.1244\n",
      "Iteration: 7720, Train loss: -0.5328, rewards: 0.1212\n",
      "Iteration: 7730, Train loss: -0.5359, rewards: 0.1306\n",
      "Iteration: 7740, Train loss: -0.4489, rewards: 0.0988\n",
      "Iteration: 7750, Train loss: -0.6820, rewards: 0.1131\n",
      "Iteration: 7760, Train loss: -0.5833, rewards: 0.1331\n",
      "Iteration: 7770, Train loss: -0.4899, rewards: 0.1219\n",
      "Iteration: 7780, Train loss: -0.5221, rewards: 0.1519\n",
      "Iteration: 7790, Train loss: -0.5644, rewards: 0.1169\n",
      "Iteration: 7800, Train loss: -0.6063, rewards: 0.1237\n",
      "Eval:\n",
      "Hits@1: 0.1497, Hits@3: 0.2264, Hits@10: 0.3162, MRR: 0.2035\n",
      "------------------------------------------------------------\n",
      "Iteration: 7810, Train loss: -0.5017, rewards: 0.1325\n",
      "Iteration: 7820, Train loss: -0.6314, rewards: 0.1144\n",
      "Iteration: 7830, Train loss: -0.4221, rewards: 0.0800\n",
      "Iteration: 7840, Train loss: -0.4934, rewards: 0.1025\n",
      "Iteration: 7850, Train loss: -0.5638, rewards: 0.1319\n",
      "Iteration: 7860, Train loss: -0.6120, rewards: 0.1300\n",
      "Iteration: 7870, Train loss: -0.4457, rewards: 0.1556\n",
      "Iteration: 7880, Train loss: -0.5532, rewards: 0.1019\n",
      "Iteration: 7890, Train loss: -0.5819, rewards: 0.1425\n",
      "Iteration: 7900, Train loss: -0.6170, rewards: 0.1206\n",
      "Eval:\n",
      "Hits@1: 0.1529, Hits@3: 0.2382, Hits@10: 0.3394, MRR: 0.2128\n",
      "------------------------------------------------------------\n",
      "Iteration: 7910, Train loss: -0.4514, rewards: 0.0900\n",
      "Iteration: 7920, Train loss: -0.6350, rewards: 0.1506\n",
      "Iteration: 7930, Train loss: -0.5364, rewards: 0.1500\n",
      "Iteration: 7940, Train loss: -0.5428, rewards: 0.0981\n",
      "Iteration: 7950, Train loss: -0.5609, rewards: 0.1787\n",
      "Iteration: 7960, Train loss: -0.5178, rewards: 0.1062\n",
      "Iteration: 7970, Train loss: -0.4566, rewards: 0.1056\n",
      "Iteration: 7980, Train loss: -0.5025, rewards: 0.1037\n",
      "Iteration: 7990, Train loss: -0.5826, rewards: 0.1756\n",
      "Iteration: 8000, Train loss: -0.5692, rewards: 0.1044\n",
      "Eval:\n",
      "Hits@1: 0.1611, Hits@3: 0.2543, Hits@10: 0.3660, MRR: 0.2270\n",
      "------------------------------------------------------------\n",
      "Iteration: 8010, Train loss: -0.5383, rewards: 0.1169\n",
      "Iteration: 8020, Train loss: -0.4427, rewards: 0.1212\n",
      "Iteration: 8030, Train loss: -0.6648, rewards: 0.1737\n",
      "Iteration: 8040, Train loss: -0.4482, rewards: 0.0963\n",
      "Iteration: 8050, Train loss: -0.6634, rewards: 0.1537\n",
      "Iteration: 8060, Train loss: -0.5344, rewards: 0.1319\n",
      "Iteration: 8070, Train loss: -0.5718, rewards: 0.1163\n",
      "Iteration: 8080, Train loss: -0.7798, rewards: 0.1306\n",
      "Iteration: 8090, Train loss: -0.6615, rewards: 0.1200\n",
      "Iteration: 8100, Train loss: -0.7760, rewards: 0.1588\n",
      "Eval:\n",
      "Hits@1: 0.1649, Hits@3: 0.2601, Hits@10: 0.3670, MRR: 0.2309\n",
      "------------------------------------------------------------\n",
      "Iteration: 8110, Train loss: -0.6701, rewards: 0.1825\n",
      "Iteration: 8120, Train loss: -0.6074, rewards: 0.1175\n",
      "Iteration: 8130, Train loss: -0.5953, rewards: 0.1306\n",
      "Iteration: 8140, Train loss: -0.4832, rewards: 0.0938\n",
      "Iteration: 8150, Train loss: -0.3606, rewards: 0.0938\n",
      "Iteration: 8160, Train loss: -0.4055, rewards: 0.0794\n",
      "Iteration: 8170, Train loss: -0.7424, rewards: 0.1638\n",
      "Iteration: 8180, Train loss: -0.4516, rewards: 0.0862\n",
      "Iteration: 8190, Train loss: -0.6425, rewards: 0.1250\n",
      "Iteration: 8200, Train loss: -0.6121, rewards: 0.1244\n",
      "Eval:\n",
      "Hits@1: 0.1610, Hits@3: 0.2440, Hits@10: 0.3376, MRR: 0.2189\n",
      "------------------------------------------------------------\n",
      "Iteration: 8210, Train loss: -0.5187, rewards: 0.1444\n",
      "Iteration: 8220, Train loss: -0.5358, rewards: 0.1331\n",
      "Iteration: 8230, Train loss: -0.4287, rewards: 0.1206\n",
      "Iteration: 8240, Train loss: -0.5951, rewards: 0.1325\n",
      "Iteration: 8250, Train loss: -0.5817, rewards: 0.1363\n",
      "Iteration: 8260, Train loss: -0.7389, rewards: 0.1594\n",
      "Iteration: 8270, Train loss: -0.5937, rewards: 0.1419\n",
      "Iteration: 8280, Train loss: -0.7722, rewards: 0.1206\n",
      "Iteration: 8290, Train loss: -0.5800, rewards: 0.1425\n",
      "Iteration: 8300, Train loss: -0.6852, rewards: 0.1450\n",
      "Eval:\n",
      "Hits@1: 0.1580, Hits@3: 0.2438, Hits@10: 0.3438, MRR: 0.2180\n",
      "------------------------------------------------------------\n",
      "Iteration: 8310, Train loss: -0.6191, rewards: 0.1044\n",
      "Iteration: 8320, Train loss: -0.6960, rewards: 0.1281\n",
      "Iteration: 8330, Train loss: -0.6228, rewards: 0.1138\n",
      "Iteration: 8340, Train loss: -0.6516, rewards: 0.1519\n",
      "Iteration: 8350, Train loss: -0.6382, rewards: 0.2000\n",
      "Iteration: 8360, Train loss: -0.5119, rewards: 0.1019\n",
      "Iteration: 8370, Train loss: -0.5685, rewards: 0.1512\n",
      "Iteration: 8380, Train loss: -0.6755, rewards: 0.1313\n",
      "Iteration: 8390, Train loss: -0.6270, rewards: 0.1425\n",
      "Iteration: 8400, Train loss: -0.4184, rewards: 0.0931\n",
      "Eval:\n",
      "Hits@1: 0.1498, Hits@3: 0.2282, Hits@10: 0.3159, MRR: 0.2041\n",
      "------------------------------------------------------------\n",
      "Iteration: 8410, Train loss: -0.5979, rewards: 0.1875\n",
      "Iteration: 8420, Train loss: -0.7327, rewards: 0.1138\n",
      "Iteration: 8430, Train loss: -0.6909, rewards: 0.1638\n",
      "Iteration: 8440, Train loss: -0.5685, rewards: 0.1106\n",
      "Iteration: 8450, Train loss: -0.6257, rewards: 0.1294\n",
      "Iteration: 8460, Train loss: -0.4978, rewards: 0.1144\n",
      "Iteration: 8470, Train loss: -0.6706, rewards: 0.1388\n",
      "Iteration: 8480, Train loss: -0.5127, rewards: 0.1200\n",
      "Iteration: 8490, Train loss: -0.4702, rewards: 0.1412\n",
      "Iteration: 8500, Train loss: -0.4447, rewards: 0.1062\n",
      "Eval:\n",
      "Hits@1: 0.1658, Hits@3: 0.2571, Hits@10: 0.3645, MRR: 0.2296\n",
      "------------------------------------------------------------\n",
      "Iteration: 8510, Train loss: -0.5005, rewards: 0.0900\n",
      "Iteration: 8520, Train loss: -0.5840, rewards: 0.1031\n",
      "Iteration: 8530, Train loss: -0.6887, rewards: 0.1881\n",
      "Iteration: 8540, Train loss: -0.5482, rewards: 0.1144\n",
      "Iteration: 8550, Train loss: -0.7763, rewards: 0.1900\n",
      "Iteration: 8560, Train loss: -0.5287, rewards: 0.1344\n",
      "Iteration: 8570, Train loss: -0.5575, rewards: 0.1406\n",
      "Iteration: 8580, Train loss: -0.5791, rewards: 0.1531\n",
      "Iteration: 8590, Train loss: -0.4619, rewards: 0.1113\n",
      "Iteration: 8600, Train loss: -0.5820, rewards: 0.1431\n",
      "Eval:\n",
      "Hits@1: 0.1654, Hits@3: 0.2585, Hits@10: 0.3681, MRR: 0.2301\n",
      "------------------------------------------------------------\n",
      "Iteration: 8610, Train loss: -0.5273, rewards: 0.1175\n",
      "Iteration: 8620, Train loss: -0.5517, rewards: 0.1138\n",
      "Iteration: 8630, Train loss: -0.5326, rewards: 0.1263\n",
      "Iteration: 8640, Train loss: -0.5180, rewards: 0.1669\n",
      "Iteration: 8650, Train loss: -0.7706, rewards: 0.1588\n",
      "Iteration: 8660, Train loss: -0.5792, rewards: 0.1775\n",
      "Iteration: 8670, Train loss: -0.8151, rewards: 0.1787\n",
      "Iteration: 8680, Train loss: -0.7408, rewards: 0.1069\n",
      "Iteration: 8690, Train loss: -0.6089, rewards: 0.1031\n",
      "Iteration: 8700, Train loss: -0.8145, rewards: 0.1706\n",
      "Eval:\n",
      "Hits@1: 0.1551, Hits@3: 0.2613, Hits@10: 0.3694, MRR: 0.2260\n",
      "------------------------------------------------------------\n",
      "Iteration: 8710, Train loss: -0.6567, rewards: 0.1100\n",
      "Iteration: 8720, Train loss: -0.2552, rewards: 0.0681\n",
      "Iteration: 8730, Train loss: -0.4411, rewards: 0.1062\n",
      "Iteration: 8740, Train loss: -0.4440, rewards: 0.1187\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 8750, Train loss: -0.4010, rewards: 0.0688\n",
      "Iteration: 8760, Train loss: -0.4100, rewards: 0.1062\n",
      "Iteration: 8770, Train loss: -0.3856, rewards: 0.0988\n",
      "Iteration: 8780, Train loss: -0.4046, rewards: 0.0800\n",
      "Iteration: 8790, Train loss: -0.6096, rewards: 0.1169\n",
      "Iteration: 8800, Train loss: -0.6082, rewards: 0.1181\n",
      "Eval:\n",
      "Hits@1: 0.1421, Hits@3: 0.2612, Hits@10: 0.3637, MRR: 0.2181\n",
      "------------------------------------------------------------\n",
      "Iteration: 8810, Train loss: -0.5480, rewards: 0.1156\n",
      "Iteration: 8820, Train loss: -0.6498, rewards: 0.1594\n",
      "Iteration: 8830, Train loss: -0.5052, rewards: 0.1237\n",
      "Iteration: 8840, Train loss: -0.6852, rewards: 0.1613\n",
      "Iteration: 8850, Train loss: -0.6083, rewards: 0.1812\n",
      "Iteration: 8860, Train loss: -0.5588, rewards: 0.1150\n",
      "Iteration: 8870, Train loss: -0.4578, rewards: 0.0794\n",
      "Iteration: 8880, Train loss: -0.5281, rewards: 0.1175\n",
      "Iteration: 8890, Train loss: -0.5159, rewards: 0.1075\n",
      "Iteration: 8900, Train loss: -0.4776, rewards: 0.1325\n",
      "Eval:\n",
      "Hits@1: 0.1392, Hits@3: 0.2548, Hits@10: 0.3551, MRR: 0.2147\n",
      "------------------------------------------------------------\n",
      "Iteration: 8910, Train loss: -0.6694, rewards: 0.1681\n",
      "Iteration: 8920, Train loss: -0.5544, rewards: 0.1625\n",
      "Iteration: 8930, Train loss: -0.4936, rewards: 0.0881\n",
      "Iteration: 8940, Train loss: -0.3883, rewards: 0.0594\n",
      "Iteration: 8950, Train loss: -0.4535, rewards: 0.1319\n",
      "Iteration: 8960, Train loss: -0.6016, rewards: 0.1706\n",
      "Iteration: 8970, Train loss: -0.6092, rewards: 0.1469\n",
      "Iteration: 8980, Train loss: -0.5586, rewards: 0.1281\n",
      "Iteration: 8990, Train loss: -0.6979, rewards: 0.1419\n",
      "Iteration: 9000, Train loss: -0.7395, rewards: 0.1437\n",
      "Eval:\n",
      "Hits@1: 0.1547, Hits@3: 0.2497, Hits@10: 0.3544, MRR: 0.2210\n",
      "------------------------------------------------------------\n",
      "Iteration: 9010, Train loss: -0.5933, rewards: 0.1275\n",
      "Iteration: 9020, Train loss: -0.5651, rewards: 0.1437\n",
      "Iteration: 9030, Train loss: -0.7790, rewards: 0.1775\n",
      "Iteration: 9040, Train loss: -0.6134, rewards: 0.0912\n",
      "Iteration: 9050, Train loss: -0.7143, rewards: 0.1619\n",
      "Iteration: 9060, Train loss: -0.5877, rewards: 0.1369\n",
      "Iteration: 9070, Train loss: -0.4597, rewards: 0.1150\n",
      "Iteration: 9080, Train loss: -0.5867, rewards: 0.1344\n",
      "Iteration: 9090, Train loss: -0.5431, rewards: 0.1237\n",
      "Iteration: 9100, Train loss: -0.6907, rewards: 0.1500\n",
      "Eval:\n",
      "Hits@1: 0.1687, Hits@3: 0.2571, Hits@10: 0.3603, MRR: 0.2308\n",
      "------------------------------------------------------------\n",
      "Iteration: 9110, Train loss: -0.6831, rewards: 0.1487\n",
      "Iteration: 9120, Train loss: -0.4994, rewards: 0.1313\n",
      "Iteration: 9130, Train loss: -0.4687, rewards: 0.1037\n",
      "Iteration: 9140, Train loss: -0.6202, rewards: 0.1175\n",
      "Iteration: 9150, Train loss: -0.4993, rewards: 0.1275\n",
      "Iteration: 9160, Train loss: -0.5423, rewards: 0.1631\n",
      "Iteration: 9170, Train loss: -0.6611, rewards: 0.1156\n",
      "Iteration: 9180, Train loss: -0.6198, rewards: 0.1300\n",
      "Iteration: 9190, Train loss: -0.6436, rewards: 0.1831\n",
      "Iteration: 9200, Train loss: -0.5697, rewards: 0.1425\n",
      "Eval:\n",
      "Hits@1: 0.1707, Hits@3: 0.2636, Hits@10: 0.3664, MRR: 0.2343\n",
      "------------------------------------------------------------\n",
      "Iteration: 9210, Train loss: -0.7821, rewards: 0.1663\n",
      "Iteration: 9220, Train loss: -0.7918, rewards: 0.1725\n",
      "Iteration: 9230, Train loss: -0.6294, rewards: 0.1619\n",
      "Iteration: 9240, Train loss: -0.7747, rewards: 0.1981\n",
      "Iteration: 9250, Train loss: -0.5282, rewards: 0.1581\n",
      "Iteration: 9260, Train loss: -0.7536, rewards: 0.1487\n",
      "Iteration: 9270, Train loss: -0.8266, rewards: 0.1837\n",
      "Iteration: 9280, Train loss: -0.7002, rewards: 0.1800\n",
      "Iteration: 9290, Train loss: -0.6423, rewards: 0.2219\n",
      "Iteration: 9300, Train loss: -0.8582, rewards: 0.2194\n",
      "Eval:\n",
      "Hits@1: 0.1781, Hits@3: 0.2733, Hits@10: 0.3676, MRR: 0.2413\n",
      "------------------------------------------------------------\n",
      "Iteration: 9310, Train loss: -0.5947, rewards: 0.1562\n",
      "Iteration: 9320, Train loss: -0.6030, rewards: 0.1419\n",
      "Iteration: 9330, Train loss: -0.5167, rewards: 0.1812\n",
      "Iteration: 9340, Train loss: -0.5873, rewards: 0.1525\n",
      "Iteration: 9350, Train loss: -0.6960, rewards: 0.1688\n",
      "Iteration: 9360, Train loss: -0.7059, rewards: 0.1744\n",
      "Iteration: 9370, Train loss: -0.6064, rewards: 0.1394\n",
      "Iteration: 9380, Train loss: -0.7099, rewards: 0.1688\n",
      "Iteration: 9390, Train loss: -0.6715, rewards: 0.1281\n",
      "Iteration: 9400, Train loss: -0.6802, rewards: 0.1475\n",
      "Eval:\n",
      "Hits@1: 0.1713, Hits@3: 0.2660, Hits@10: 0.3645, MRR: 0.2350\n",
      "------------------------------------------------------------\n",
      "Iteration: 9410, Train loss: -0.5894, rewards: 0.1350\n",
      "Iteration: 9420, Train loss: -0.6587, rewards: 0.1381\n",
      "Iteration: 9430, Train loss: -0.3500, rewards: 0.0506\n",
      "Iteration: 9440, Train loss: -0.6424, rewards: 0.1531\n",
      "Iteration: 9450, Train loss: -0.4490, rewards: 0.1375\n",
      "Iteration: 9460, Train loss: -0.8097, rewards: 0.2412\n",
      "Iteration: 9470, Train loss: -0.5254, rewards: 0.1756\n",
      "Iteration: 9480, Train loss: -0.5101, rewards: 0.0969\n",
      "Iteration: 9490, Train loss: -0.7053, rewards: 0.1737\n",
      "Iteration: 9500, Train loss: -0.7923, rewards: 0.1619\n",
      "Eval:\n",
      "Hits@1: 0.1809, Hits@3: 0.2735, Hits@10: 0.3653, MRR: 0.2421\n",
      "------------------------------------------------------------\n",
      "Iteration: 9510, Train loss: -0.6264, rewards: 0.2000\n",
      "Iteration: 9520, Train loss: -0.9083, rewards: 0.2275\n",
      "Iteration: 9530, Train loss: -0.8553, rewards: 0.1956\n",
      "Iteration: 9540, Train loss: -0.6441, rewards: 0.1850\n",
      "Iteration: 9550, Train loss: -0.5578, rewards: 0.1431\n",
      "Iteration: 9560, Train loss: -0.5209, rewards: 0.0956\n",
      "Iteration: 9570, Train loss: -0.6013, rewards: 0.1106\n",
      "Iteration: 9580, Train loss: -0.5361, rewards: 0.1419\n",
      "Iteration: 9590, Train loss: -0.6267, rewards: 0.1275\n",
      "Iteration: 9600, Train loss: -0.7279, rewards: 0.1781\n",
      "Eval:\n",
      "Hits@1: 0.1836, Hits@3: 0.2661, Hits@10: 0.3579, MRR: 0.2405\n",
      "------------------------------------------------------------\n",
      "Iteration: 9610, Train loss: -0.6085, rewards: 0.1844\n",
      "Iteration: 9620, Train loss: -0.7220, rewards: 0.1625\n",
      "Iteration: 9630, Train loss: -0.5815, rewards: 0.1363\n",
      "Iteration: 9640, Train loss: -0.6406, rewards: 0.1037\n",
      "Iteration: 9650, Train loss: -0.6811, rewards: 0.1831\n",
      "Iteration: 9660, Train loss: -0.5971, rewards: 0.1588\n",
      "Iteration: 9670, Train loss: -0.6309, rewards: 0.1506\n",
      "Iteration: 9680, Train loss: -0.5697, rewards: 0.1194\n",
      "Iteration: 9690, Train loss: -0.5472, rewards: 0.1481\n",
      "Iteration: 9700, Train loss: -0.4799, rewards: 0.1306\n",
      "Eval:\n",
      "Hits@1: 0.1365, Hits@3: 0.2315, Hits@10: 0.3216, MRR: 0.1997\n",
      "------------------------------------------------------------\n",
      "Iteration: 9710, Train loss: -0.6094, rewards: 0.1663\n",
      "Iteration: 9720, Train loss: -0.4759, rewards: 0.1106\n",
      "Iteration: 9730, Train loss: -0.6224, rewards: 0.1862\n",
      "Iteration: 9740, Train loss: -0.4884, rewards: 0.1487\n",
      "Iteration: 9750, Train loss: -0.5629, rewards: 0.1737\n",
      "Iteration: 9760, Train loss: -0.4430, rewards: 0.0706\n",
      "Iteration: 9770, Train loss: -0.6656, rewards: 0.1581\n",
      "Iteration: 9780, Train loss: -0.6068, rewards: 0.1531\n",
      "Iteration: 9790, Train loss: -0.5913, rewards: 0.1412\n",
      "Iteration: 9800, Train loss: -0.3982, rewards: 0.0700\n",
      "Eval:\n",
      "Hits@1: 0.1313, Hits@3: 0.2080, Hits@10: 0.3039, MRR: 0.1857\n",
      "------------------------------------------------------------\n",
      "Iteration: 9810, Train loss: -0.6612, rewards: 0.1444\n",
      "Iteration: 9820, Train loss: -0.4709, rewards: 0.0844\n",
      "Iteration: 9830, Train loss: -0.7213, rewards: 0.1888\n",
      "Iteration: 9840, Train loss: -0.5154, rewards: 0.1225\n",
      "Iteration: 9850, Train loss: -0.5180, rewards: 0.0938\n",
      "Iteration: 9860, Train loss: -0.5775, rewards: 0.1338\n",
      "Iteration: 9870, Train loss: -0.6834, rewards: 0.2219\n",
      "Iteration: 9880, Train loss: -0.4908, rewards: 0.1200\n",
      "Iteration: 9890, Train loss: -0.5834, rewards: 0.1175\n",
      "Iteration: 9900, Train loss: -0.6146, rewards: 0.1094\n",
      "Eval:\n",
      "Hits@1: 0.1531, Hits@3: 0.2596, Hits@10: 0.3567, MRR: 0.2227\n",
      "------------------------------------------------------------\n",
      "Iteration: 9910, Train loss: -0.6279, rewards: 0.1731\n",
      "Iteration: 9920, Train loss: -0.7287, rewards: 0.1950\n",
      "Iteration: 9930, Train loss: -0.7654, rewards: 0.1569\n",
      "Iteration: 9940, Train loss: -0.5358, rewards: 0.1338\n",
      "Iteration: 9950, Train loss: -0.7053, rewards: 0.2119\n",
      "Iteration: 9960, Train loss: -0.6359, rewards: 0.1269\n",
      "Iteration: 9970, Train loss: -0.6513, rewards: 0.1494\n",
      "Iteration: 9980, Train loss: -0.7708, rewards: 0.1787\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 9990, Train loss: -0.6904, rewards: 0.1475\n",
      "Iteration: 10000, Train loss: -0.6512, rewards: 0.2106\n",
      "Eval:\n",
      "Hits@1: 0.1808, Hits@3: 0.2725, Hits@10: 0.3748, MRR: 0.2449\n",
      "------------------------------------------------------------\n",
      "Iteration: 10010, Train loss: -0.5775, rewards: 0.1163\n",
      "Iteration: 10020, Train loss: -0.5912, rewards: 0.1775\n",
      "Iteration: 10030, Train loss: -0.5958, rewards: 0.0725\n",
      "Iteration: 10040, Train loss: -0.5529, rewards: 0.1263\n",
      "Iteration: 10050, Train loss: -0.5632, rewards: 0.1619\n",
      "Iteration: 10060, Train loss: -0.4684, rewards: 0.1806\n",
      "Iteration: 10070, Train loss: -0.6982, rewards: 0.1475\n",
      "Iteration: 10080, Train loss: -0.7784, rewards: 0.1644\n",
      "Iteration: 10090, Train loss: -0.5122, rewards: 0.1244\n",
      "Iteration: 10100, Train loss: -0.5941, rewards: 0.1556\n",
      "Eval:\n",
      "Hits@1: 0.1648, Hits@3: 0.2564, Hits@10: 0.3603, MRR: 0.2286\n",
      "------------------------------------------------------------\n",
      "Iteration: 10110, Train loss: -0.5579, rewards: 0.1206\n",
      "Iteration: 10120, Train loss: -0.7776, rewards: 0.1762\n",
      "Iteration: 10130, Train loss: -0.8797, rewards: 0.1506\n",
      "Iteration: 10140, Train loss: -0.7066, rewards: 0.1619\n",
      "Iteration: 10150, Train loss: -0.5839, rewards: 0.1412\n",
      "Iteration: 10160, Train loss: -0.7266, rewards: 0.1494\n",
      "Iteration: 10170, Train loss: -0.6097, rewards: 0.1319\n",
      "Iteration: 10180, Train loss: -0.7297, rewards: 0.1781\n",
      "Iteration: 10190, Train loss: -0.7564, rewards: 0.1862\n",
      "Iteration: 10200, Train loss: -0.6769, rewards: 0.1388\n",
      "Eval:\n",
      "Hits@1: 0.1557, Hits@3: 0.2501, Hits@10: 0.3569, MRR: 0.2209\n",
      "------------------------------------------------------------\n",
      "Iteration: 10210, Train loss: -0.8960, rewards: 0.1988\n",
      "Iteration: 10220, Train loss: -0.6334, rewards: 0.1400\n",
      "Iteration: 10230, Train loss: -0.6681, rewards: 0.1487\n",
      "Iteration: 10240, Train loss: -0.6077, rewards: 0.1581\n",
      "Iteration: 10250, Train loss: -0.6018, rewards: 0.1275\n",
      "Iteration: 10260, Train loss: -0.6557, rewards: 0.1569\n",
      "Iteration: 10270, Train loss: -0.5858, rewards: 0.1537\n",
      "Iteration: 10280, Train loss: -0.5942, rewards: 0.1163\n",
      "Iteration: 10290, Train loss: -0.6932, rewards: 0.1481\n",
      "Iteration: 10300, Train loss: -0.7424, rewards: 0.1600\n",
      "Eval:\n"
     ]
    }
   ],
   "source": [
    "# Build the trainer from `options`, which is defined in an earlier cell of this notebook.\n",
    "trainer = Trainer(options)\n",
    "# Debug aid: uncomment the alias below so the commented-out scratch cells further down,\n",
    "# which are written against `self.*`, can be run directly in the notebook.\n",
    "#self = trainer\n",
    "# Run training. Per the captured output, this logs train loss / rewards every 10 iterations\n",
    "# and Hits@1/3/10 + MRR eval metrics every 100 iterations.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c2ad0c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the saved agent weights (presumably written during training/eval -- confirm in Trainer).\n",
    "# map_location='cpu' lets a checkpoint saved from a CUDA run deserialize on a CPU-only kernel;\n",
    "# load_state_dict then copies the tensors into the agent's existing parameters on their own device.\n",
    "agent_state = torch.load(options['model_dir'] + 'agent.ckpt', map_location='cpu')\n",
    "trainer.agent.load_state_dict(agent_state)\n",
    "# Switch the module to inference mode before scoring.\n",
    "trainer.agent.eval()\n",
    "\n",
    "# Swap in `test_test_environment` for evaluation (presumably the held-out test split -- confirm\n",
    "# in Trainer). NOTE: this rebinds trainer.test_environment for any later call on this trainer.\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "# Final scoring with beam search; save_model=False so this run does not write a checkpoint.\n",
    "trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc9340fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print_paths = True\n",
    "# with torch.no_grad():\n",
    "\n",
    "#     batch_counter = 0\n",
    "#     paths = defaultdict(list)\n",
    "#     answers = []\n",
    "#     all_final_reward_1 = 0\n",
    "#     all_final_reward_3 = 0\n",
    "#     all_final_reward_5 = 0\n",
    "#     all_final_reward_10 = 0\n",
    "#     all_final_reward_20 = 0\n",
    "#     auc = 0\n",
    "\n",
    "#     total_examples = self.test_environment.total_no_examples\n",
    "\n",
    "#     for entity_episode in self.test_environment.get_episodes():\n",
    "#         batch_counter += 1\n",
    "\n",
    "#         temp_batch_size = entity_episode.no_examples\n",
    "\n",
    "#         self.qr = entity_episode.get_query_relation()\n",
    "#         query_relation = self.qr\n",
    "#         query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "#         # set initial beam probs\n",
    "#         beam_probs = torch.zeros((temp_batch_size * self.test_rollouts, 1)).to(self.device)\n",
    "\n",
    "#         # get initial state for entity agent\n",
    "#         entity_state = entity_episode.get_state()\n",
    "\n",
    "#         head, link, tail = self.get_graph(entity_episode, mode = 'test')\n",
    "#         next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "#         next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "#         current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "\n",
    "#         entity_state_emb = torch.zeros(1, 2, temp_batch_size * self.test_rollouts,\n",
    "#                                        self.agent.m * self.hidden_size).to(self.device)\n",
    "#         prev_relation = (torch.ones(temp_batch_size * self.test_rollouts) * self.relation_vocab[\n",
    "#             'DUMMY_START_RELATION']).long().to(self.device)\n",
    "\n",
    "#         if print_paths:\n",
    "#             self.entity_trajectory = [current_entities]\n",
    "#             self.relation_trajectory = [prev_relation]   \n",
    "\n",
    "#         self.log_probs = np.zeros((temp_batch_size * self.test_rollouts,)) * 1.0\n",
    "#         for i in range(self.path_length):\n",
    "#             break\n",
    "#         break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f512a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# next_relations, next_entities, prev_state = next_relations, next_entities, entity_state_emb\n",
    "# prev_relation, query_embedding, current_entities = prev_relation, query_relation,current_entities\n",
    "# head, link, tail = head, link, tail"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70c5dc8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "    \n",
    "#     n_ent = len(self.agent.entity_embedding.weight)\n",
    "#     n_rel = len(self.agent.relation_embedding.weight)\n",
    "    \n",
    "#     self.agent.ent_embed = self.agent.entity_embedding(torch.arange(n_ent).to(head.device))\n",
    "#     self.agent.rel_embed = self.agent.relation_embedding(torch.arange(n_rel).to(link.device))\n",
    "\n",
    "#     for i in range(2):\n",
    "#         self.agent.ent_embed = self.agent.node_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "#         self.agent.rel_embed = self.agent.rel_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "\n",
    "#     prev_action_embedding = self.agent.action_encoder(prev_relation, current_entities) # (original batch_size * num_rollout, 4*self.agent.embedding_size)\n",
    "\n",
    "#     prev_state = torch.unbind(prev_state, dim=1)\n",
    "#     prev_state = [prev_state[0].squeeze(0), prev_state[1].squeeze(0)]\n",
    "\n",
    "#     new_prev_state = list()\n",
    "\n",
    "#     output, new_state = self.agent.policy_step(prev_action_embedding, prev_state)\n",
    "\n",
    "#     prev_entity = self.agent.ent_embed[current_entities]\n",
    "#     if self.agent.use_entity_embeddings:\n",
    "#         state = torch.cat([output, prev_entity], dim=-1)\n",
    "#     else:\n",
    "#         state = output\n",
    "\n",
    "#     candidate_action_embeddings = self.agent.action_encoder(next_relations, next_entities)\n",
    "#     query_embedding = self.agent.rel_embed[query_embedding]\n",
    "\n",
    "#     state_query_concat = torch.cat([state, query_embedding], dim=-1)\n",
    "\n",
    "#     # MLP for policy\n",
    "\n",
    "#     output = self.agent.policy_mlp(state_query_concat)\n",
    "#     # print(output.size())\n",
    "#     output_expanded = torch.unsqueeze(output, dim=1)  # [original batch_size * num_rollout, 1, 2D], D=self.agent.hidden_size\n",
    "#     # print(output_expanded.size(), candidate_action_embeddings.size())\n",
    "#     prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)\n",
    "\n",
    "#     # Masking PAD actions\n",
    "\n",
    "#     comparison_tensor = torch.ones_like(next_relations).int() * self.agent.rPAD  # matrix to compare\n",
    "#     mask = next_relations == comparison_tensor  # The mask\n",
    "#     dummy_scores = torch.ones_like(prelim_scores) * -99999.0  # the base matrix to choose from if dummy relation\n",
    "#     scores = torch.where(mask, dummy_scores, prelim_scores)  # [original batch_size * num_rollout, max_num_actions]\n",
    "\n",
    "#     # 4 sample action\n",
    "#     action = torch.distributions.categorical.Categorical(logits=scores) # [original batch_size * num_rollout, 1]\n",
    "#     label_action = action.sample() # [original batch_size * num_rollout,]\n",
    "\n",
    "#     # loss\n",
    "#     # 5a.\n",
    "#     loss = torch.nn.CrossEntropyLoss(reduce=False)(scores, label_action)\n",
    "\n",
    "#     # 6. Map back to true id\n",
    "#     chosen_relation = next_relations[torch.arange(len(label_action)), label_action]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70c874d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# next_relations, next_entities, prev_state = next_possible_relations, next_possible_entities, entity_state_emb\n",
    "# prev_relation, query_embedding, current_entities = prev_relation, query_relation,current_entities\n",
    "# head, link, tail = head, link, tail"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc1d1658",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "#     self.agent.ent_embed = self.agent.entity_embedding(torch.arange(head.max().item() + 1).to(head.device))\n",
    "#     self.agent.rel_embed = self.agent.relation_embedding(torch.arange(link.max().item() + 1).to(link.device))\n",
    "\n",
    "#     for i in range(2):\n",
    "#         self.agent.ent_embed = self.agent.node_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "#         self.agent.rel_embed = self.agent.rel_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0aa96cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "#     prev_action_embedding = self.agent.action_encoder(prev_relation, current_entities) # (original batch_size * num_rollout, 4*self.embedding_size)\n",
    "\n",
    "#     prev_state = torch.unbind(prev_state, dim=1)\n",
    "#     prev_state = [prev_state[0].squeeze(0), prev_state[1].squeeze(0)]\n",
    "\n",
    "#     new_prev_state = list()\n",
    "\n",
    "#     output, new_state = self.agent.policy_step(prev_action_embedding, prev_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8763062b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "#     prev_entity = self.agent.ent_embed[current_entities]\n",
    "#     if self.agent.use_entity_embeddings:\n",
    "#         state = torch.cat([output, prev_entity], dim=-1)\n",
    "#     else:\n",
    "#         state = output\n",
    "\n",
    "#     candidate_action_embeddings = self.agent.action_encoder(next_relations, next_entities)\n",
    "#     query_embedding = self.agent.rel_embed[query_embedding]\n",
    "\n",
    "#     state_query_concat = torch.cat([state, query_embedding], dim=-1)\n",
    "\n",
    "#     # MLP for policy\n",
    "\n",
    "#     output = self.agent.policy_mlp(state_query_concat)\n",
    "#     # print(output.size())\n",
    "#     output_expanded = torch.unsqueeze(output, dim=1)  # [original batch_size * num_rollout, 1, 2D], D=self.agent.hidden_size\n",
    "#     # print(output_expanded.size(), candidate_action_embeddings.size())\n",
    "#     prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)\n",
    "\n",
    "#     # Masking PAD actions\n",
    "\n",
    "#     comparison_tensor = torch.ones_like(next_relations).int() * self.agent.rPAD  # matrix to compare\n",
    "#     mask = next_relations == comparison_tensor  # The mask\n",
    "#     dummy_scores = torch.ones_like(prelim_scores) * -99999.0  # the base matrix to choose from if dummy relation\n",
    "#     scores = torch.where(mask, dummy_scores, prelim_scores)  # [original batch_size * num_rollout, max_num_actions]\n",
    "\n",
    "#     # 4 sample action\n",
    "#     action = torch.distributions.categorical.Categorical(logits=scores) # [original batch_size * num_rollout, 1]\n",
    "#     label_action = action.sample() # [original batch_size * num_rollout,]\n",
    "\n",
    "#     # loss\n",
    "#     # 5a.\n",
    "#     loss = torch.nn.CrossEntropyLoss(reduce=False)(scores, label_action)\n",
    "\n",
    "#     # 6. Map back to true id\n",
    "#     chosen_relation = next_relations[torch.arange(len(label_action)), label_action]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
