{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
    "\n",
    "from model.trainer import Trainer\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/cwong/miniconda3/envs/tracy/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from tqdm.notebook import tqdm\n",
    "import json\n",
    "import time\n",
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "from model.agent import Agent\n",
    "from model.options import read_options\n",
    "from model.environment import env\n",
    "import codecs\n",
    "from collections import defaultdict\n",
    "import gc\n",
    "import resource\n",
    "import sys\n",
    "from model.baseline import ReactiveBaseline\n",
    "from scipy.special import logsumexp as lse\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam, SGD, AdamW\n",
    "from copy import deepcopy\n",
    "from model.nell_eval import nell_eval\n",
    "\n",
    "def get_logger(output_dir):\n",
    "    \"\"\"Return the module logger, writing bare messages to stdout and to a log file.\n",
    "\n",
    "    The log file path is ``output_dir + 'train.log'`` (plain string concatenation,\n",
    "    so output_dir presumably needs a trailing path separator).\n",
    "\n",
    "    Safe to call more than once (e.g. when a notebook cell is re-run): handlers\n",
    "    already attached to this logger are closed and removed first, so messages are\n",
    "    not emitted once per previous call and old log-file handles are released.\n",
    "    \"\"\"\n",
    "    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter\n",
    "    filename = output_dir + 'train'\n",
    "    logger = getLogger(__name__)\n",
    "    logger.setLevel(INFO)\n",
    "    # Without this, every call appends another stdout/file handler pair.\n",
    "    for old_handler in list(logger.handlers):\n",
    "        logger.removeHandler(old_handler)\n",
    "        old_handler.close()\n",
    "    formatter = Formatter(\"%(message)s\")\n",
    "    handler1 = StreamHandler()\n",
    "    handler1.setFormatter(formatter)\n",
    "    handler2 = FileHandler(filename=f\"{filename}.log\")\n",
    "    handler2.setFormatter(formatter)\n",
    "    logger.addHandler(handler1)\n",
    "    logger.addHandler(handler2)\n",
    "    return logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fc64eafe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.utils as utils\n",
    "from torch_scatter import scatter\n",
    "\n",
    "class node_aggregation(nn.Module):\n",
    "    \"\"\"One round of attention-weighted message passing from edges into node embeddings.\n",
    "\n",
    "    For every edge (H[i], L[i], T[i]) a value vector is built from the tail-node and\n",
    "    link embeddings; each head node receives the attention-weighted sum of the values\n",
    "    on its edges. The sum is added residually to X, then LayerNorm and a residual MLP\n",
    "    are applied.\n",
    "\n",
    "    Args:\n",
    "        m, hidden_size, layer: stored as attributes; forward does not use them.\n",
    "        embedding_size: node and link embeddings are 2 * embedding_size wide.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m, embedding_size, hidden_size, layer = 2):\n",
    "        super(node_aggregation, self).__init__()\n",
    "\n",
    "        self.m = m\n",
    "        self.embedding_size = embedding_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.layer = layer\n",
    "\n",
    "        self.node_encoder = nn.Linear(2*embedding_size, 2*embedding_size)\n",
    "        self.link_encoder = nn.Linear(2*embedding_size, 2*embedding_size)\n",
    "        self.value_encoder = nn.Linear(2*2*embedding_size, 2*embedding_size)\n",
    "\n",
    "        self.mlp = nn.Sequential(nn.Linear(2*embedding_size, 2*2*embedding_size), nn.Mish(),\n",
    "                                 nn.Linear(2*2*embedding_size, 2*embedding_size))\n",
    "        self.mlp_norm = nn.LayerNorm(2*embedding_size)\n",
    "        self.norm = nn.LayerNorm(2*embedding_size)\n",
    "\n",
    "    def forward(self, H, L, T, X, Z):\n",
    "        \"\"\"Aggregate edge messages into node embeddings.\n",
    "\n",
    "        Args:\n",
    "            H, L, T: head-node, link and tail-node index of each edge (1-D edge\n",
    "                lists in this notebook).\n",
    "            X: node embedding table (indexed by H and T).\n",
    "            Z: link embedding table (indexed by L).\n",
    "\n",
    "        Returns:\n",
    "            A new tensor shaped like X; X itself is not modified.\n",
    "        \"\"\"\n",
    "        head = H.long()\n",
    "        h_x = self.node_encoder(X[H])\n",
    "        l_x = self.link_encoder(Z[L])\n",
    "        v = self.value_encoder(torch.cat([X[T], Z[L]], -1))\n",
    "\n",
    "        # Scaled dot-product attention: divide by sqrt(feature width), not by\n",
    "        # sqrt(L.shape[-1]), which is the number of edges.\n",
    "        score = (l_x*h_x).sum(-1)/np.sqrt(h_x.shape[-1])\n",
    "        # Softmax over each head node's edges; subtracting the per-head max keeps\n",
    "        # exp() from overflowing without changing the normalised weights.\n",
    "        score_max = scatter(score.detach(), head, dim = 0, reduce = 'max')[head]\n",
    "        exp_score = torch.exp(score - score_max)\n",
    "        norm = scatter(exp_score, head, dim = 0, reduce = 'sum')[head]\n",
    "        att = exp_score/norm\n",
    "\n",
    "        # dim_size gives nodes without edges a zero update, so the residual is\n",
    "        # added out of place instead of writing into the caller's X.\n",
    "        x_ = scatter(v*att.unsqueeze(-1), head, dim = 0, dim_size = X.size(0), reduce = 'sum')\n",
    "        x = self.norm(X + x_)\n",
    "\n",
    "        x = self.mlp_norm(x + self.mlp(x))\n",
    "        return x\n",
    "    \n",
    "class relation_aggregation(nn.Module):\n",
    "    \"\"\"Update link (relation) embeddings from the encoded head nodes of their edges.\n",
    "\n",
    "    Each link embedding receives the mean of the encoded head-node embeddings over\n",
    "    all edges carrying that link, added residually, followed by LayerNorm and a\n",
    "    residual MLP.\n",
    "\n",
    "    Args:\n",
    "        m, hidden_size, layer: stored as attributes; forward does not use them.\n",
    "        embedding_size: node and link embeddings are 2 * embedding_size wide.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m, embedding_size, hidden_size, layer = 2):\n",
    "        super(relation_aggregation, self).__init__()\n",
    "\n",
    "        self.m = m\n",
    "        self.embedding_size = embedding_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.layer = layer\n",
    "\n",
    "        self.node_encoder = nn.Sequential(nn.Linear(2*embedding_size, 2*embedding_size), nn.Mish(),\n",
    "                                          nn.Linear(2*embedding_size, 2*embedding_size))\n",
    "        self.mlp = nn.Sequential(nn.Linear(2*embedding_size, 2*2*embedding_size), nn.Mish(),\n",
    "                                 nn.Linear(2*2*embedding_size, 2*embedding_size))\n",
    "        self.mlp_norm = nn.LayerNorm(2*embedding_size)\n",
    "        self.norm = nn.LayerNorm(2*embedding_size)\n",
    "\n",
    "    def forward(self, H, L, T, X, Z):\n",
    "        \"\"\"Aggregate head-node messages into link embeddings.\n",
    "\n",
    "        Args:\n",
    "            H, L: head-node and link index of each edge (1-D edge lists in this notebook).\n",
    "            T: unused; accepted so the signature matches node_aggregation.forward.\n",
    "            X: node embedding table (indexed by H).\n",
    "            Z: link embedding table (rows addressed by the values in L).\n",
    "\n",
    "        Returns:\n",
    "            A new tensor shaped like Z; Z itself is not modified.\n",
    "        \"\"\"\n",
    "        h_x = self.node_encoder(X[H])\n",
    "        # dim_size gives links absent from L a zero update (torch_scatter's 'mean'\n",
    "        # yields 0 for empty groups), so Z is updated out of place.\n",
    "        z_ = scatter(h_x, L.long(), dim = 0, dim_size = Z.size(0), reduce = 'mean')\n",
    "        z = self.norm(Z + z_)\n",
    "\n",
    "        z = self.mlp_norm(z + self.mlp(z))\n",
    "        return z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "64ac7797",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.utils as utils\n",
    "\n",
    "class Policy_step(nn.Module):\n",
    "    \"\"\"Recurrent part of the policy: one LSTMCell update driven by the previous action.\n",
    "\n",
    "    The LSTM input is m * embedding_size wide and its hidden/cell state is\n",
    "    m * hidden_size wide.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m, embedding_size, hidden_size):\n",
    "        super(Policy_step, self).__init__()\n",
    "\n",
    "        in_width = m * embedding_size\n",
    "        state_width = m * hidden_size\n",
    "        # batch_norm is registered (so it is part of state_dict) but forward never uses it.\n",
    "        self.batch_norm = nn.BatchNorm1d(state_width)\n",
    "        self.lstm_cell = nn.LSTMCell(input_size=in_width, hidden_size=state_width)\n",
    "        self.l1 = nn.Linear(in_width, in_width)\n",
    "        self.l2 = nn.Linear(state_width, state_width)\n",
    "        self.l3 = nn.Linear(state_width, state_width)\n",
    "\n",
    "    def forward(self, prev_gnn, prev_action, prev_state):\n",
    "        \"\"\"Advance the LSTM by one step.\n",
    "\n",
    "        Args:\n",
    "            prev_gnn: accepted for interface compatibility; not used.\n",
    "            prev_action: embedding of the previous action, last dim m * embedding_size.\n",
    "            prev_state: (h, c) pair handed to nn.LSTMCell.\n",
    "\n",
    "        Returns:\n",
    "            (output, state): output = relu(l2(h_new)); state holds output and\n",
    "            relu(l3(c_new)) stacked as [[output, cell]] (leading dims 1 and 2).\n",
    "        \"\"\"\n",
    "        lstm_input = torch.relu(self.l1(prev_action))\n",
    "        h_new, c_new = self.lstm_cell(lstm_input, prev_state)\n",
    "        output = torch.relu(self.l2(h_new))\n",
    "        cell = torch.relu(self.l3(c_new))\n",
    "        state = torch.stack([output, cell], dim=0).unsqueeze(0)\n",
    "        return output, state\n",
    "\n",
    "class Policy_mlp(nn.Module):\n",
    "    \"\"\"Two-layer ReLU MLP that projects [state ; query] into action-scoring space.\n",
    "\n",
    "    Input width: m * hidden_size + 4 * embedding_size.\n",
    "    Output width: 2 * m * embedding_size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size, m, embedding_size):\n",
    "        super(Policy_mlp, self).__init__()\n",
    "\n",
    "        self.hidden_size = hidden_size\n",
    "        self.m = m\n",
    "        self.embedding_size = embedding_size\n",
    "        hidden_width = m * hidden_size\n",
    "        self.mlp_l1 = nn.Linear(hidden_width + 4 * embedding_size, hidden_width, bias=True)\n",
    "        self.mlp_l2 = nn.Linear(hidden_width, 2 * m * embedding_size, bias=True)\n",
    "\n",
    "    def forward(self, state_query):\n",
    "        \"\"\"Return relu(mlp_l2(relu(mlp_l1(state_query)))).\"\"\"\n",
    "        return torch.relu(self.mlp_l2(torch.relu(self.mlp_l1(state_query))))\n",
    "\n",
    "class Agent(nn.Module):\n",
    "    \"\"\"Policy network that walks the knowledge graph one relation/entity hop per step.\n",
    "\n",
    "    Each step() call (1) refreshes the entity/relation tables with node_aggregation /\n",
    "    relation_aggregation over the graph edges, (2) advances the Policy_step LSTM with the\n",
    "    previous action's embedding, (3) scores every candidate action against a Policy_mlp\n",
    "    projection of [state ; query] and (4) samples an action from the masked scores.\n",
    "\n",
    "    params must provide: 'relation_vocab' (with 'PAD' and 'DUMMY_START_RELATION'),\n",
    "    'entity_vocab' (with 'PAD'), 'embedding_size', 'hidden_size',\n",
    "    'use_entity_embeddings', 'train_entity_embeddings', 'train_relation_embeddings',\n",
    "    'device', 'num_rollouts', 'test_rollouts', 'LSTM_layers' and 'batch_size'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params):\n",
    "        super(Agent, self).__init__()\n",
    "        self.action_vocab_size = len(params['relation_vocab'])\n",
    "        self.entity_vocab_size = len(params['entity_vocab'])\n",
    "        self.embedding_size = params['embedding_size']\n",
    "        self.hidden_size = params['hidden_size']\n",
    "        self.ePAD = params['entity_vocab']['PAD']\n",
    "        self.rPAD = params['relation_vocab']['PAD']\n",
    "        self.use_entity_embeddings = params['use_entity_embeddings']\n",
    "        self.train_entity_embeddings = params['train_entity_embeddings']\n",
    "        self.train_relation_embeddings = params['train_relation_embeddings']\n",
    "        self.device = params['device']\n",
    "\n",
    "        # Entity table: Xavier-initialised when entity embeddings are used, all-zero\n",
    "        # otherwise; frozen unless train_entity_embeddings is set.\n",
    "        self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size)\n",
    "        if not self.train_entity_embeddings:\n",
    "            self.entity_embedding.requires_grad_(False)\n",
    "        if self.use_entity_embeddings:\n",
    "            torch.nn.init.xavier_uniform_(self.entity_embedding.weight)\n",
    "        else:\n",
    "            torch.nn.init.constant_(self.entity_embedding.weight, 0.0)\n",
    "\n",
    "        # Relation table: always Xavier-initialised; frozen unless train_relation_embeddings.\n",
    "        self.relation_embedding = nn.Embedding(self.action_vocab_size, 2 * self.embedding_size)\n",
    "        if not self.train_relation_embeddings:\n",
    "            self.relation_embedding.requires_grad_(False)\n",
    "        torch.nn.init.xavier_uniform_(self.relation_embedding.weight)\n",
    "\n",
    "        self.num_rollouts = params['num_rollouts']\n",
    "        self.test_rollouts = params['test_rollouts']\n",
    "        self.LSTM_Layers = params['LSTM_layers']\n",
    "        # Effective batch: one row per (example, rollout) pair.\n",
    "        self.batch_size = params['batch_size'] * params['num_rollouts']\n",
    "        self.dummy_start_label = (torch.ones(self.batch_size) * params['relation_vocab']['DUMMY_START_RELATION']).long()\n",
    "        self.entity_embedding_size = self.embedding_size\n",
    "        self.initial_gnn_state = nn.Parameter(torch.zeros((1, self.embedding_size*2*2)))\n",
    "        torch.nn.init.xavier_uniform_(self.initial_gnn_state)\n",
    "\n",
    "        # m is the width multiplier of an action embedding: [relation ; entity]\n",
    "        # (4 * embedding_size) with entity embeddings, relation only (2 *) without.\n",
    "        if self.use_entity_embeddings:\n",
    "            self.m = 4\n",
    "        else:\n",
    "            self.m = 2\n",
    "\n",
    "        self.policy_step = Policy_step(m=self.m, embedding_size=self.embedding_size, hidden_size=self.hidden_size).to(self.device)\n",
    "        self.policy_mlp = Policy_mlp(self.hidden_size, self.m, self.embedding_size).to(self.device)\n",
    "\n",
    "        # One round of graph message passing; step() iterates over these lists.\n",
    "        self.node_conv = nn.ModuleList()\n",
    "        self.rel_conv = nn.ModuleList()\n",
    "        for _ in range(1):\n",
    "            self.node_conv.append(node_aggregation(m=self.m,\n",
    "                                                   embedding_size=self.embedding_size,\n",
    "                                                   hidden_size=self.hidden_size).to(self.device))\n",
    "            self.rel_conv.append(relation_aggregation(m=self.m,\n",
    "                                                      embedding_size=self.embedding_size,\n",
    "                                                      hidden_size=self.hidden_size).to(self.device))\n",
    "\n",
    "        # NOTE(review): gate1_linear / gate2_linear are not used by any method of this\n",
    "        # class; kept so parameter init order and existing checkpoints are unchanged.\n",
    "        self.gate1_linear = nn.Linear(2*self.hidden_size, 3*2*self.hidden_size)\n",
    "        self.gate2_linear = nn.Linear(2*self.hidden_size, 3*2*self.hidden_size)\n",
    "\n",
    "        # Maps [h ; c ; query relation] to the space of [entity ; relation] neighbour embeddings.\n",
    "        self.state_encoder = nn.Sequential(nn.Linear(self.m*self.hidden_size*2 + self.embedding_size*2,\n",
    "                                           self.embedding_size*2*2), nn.Mish(),\n",
    "                                           nn.LayerNorm(self.embedding_size*2*2))\n",
    "\n",
    "    def get_mem_shape(self):\n",
    "        \"\"\"Memory shape: (LSTM layers, 2 for (h, c), unspecified batch, m * hidden_size).\"\"\"\n",
    "        return (self.LSTM_Layers, 2, None, self.m * self.hidden_size)\n",
    "\n",
    "    def action_encoder(self, next_relations, next_entities):\n",
    "        \"\"\"Embed actions as [relation ; entity] (relation only without entity embeddings).\n",
    "\n",
    "        Reads self.rel_embed / self.ent_embed, which step() sets before calling this.\n",
    "        \"\"\"\n",
    "        relation_embedding = self.rel_embed[next_relations]\n",
    "        entity_embedding = self.ent_embed[next_entities]\n",
    "\n",
    "        if self.use_entity_embeddings:\n",
    "            return torch.cat([relation_embedding, entity_embedding], dim=-1)\n",
    "        return relation_embedding\n",
    "\n",
    "    def neighbour_aggregation(self, query_relation, prev_state, next_neighbors):\n",
    "        \"\"\"Summarise each candidate's neighbourhood with state-conditioned attention.\n",
    "\n",
    "        Args:\n",
    "            query_relation: query relation ids.\n",
    "            prev_state: (h, c) LSTM state pair; concatenated on the last dim.\n",
    "            next_neighbors: 4-D integer tensor with [..., 0] = neighbour entity id and\n",
    "                [..., 1] = link relation id (presumably (B, num_actions, num_neighbours, 2)\n",
    "                -- confirm with caller); link id 0 marks padding.\n",
    "\n",
    "        Returns:\n",
    "            Attention-weighted sum over dim 2 of the [entity ; relation] neighbour embeddings.\n",
    "        \"\"\"\n",
    "        link = next_neighbors[:, :, :, 1]\n",
    "        tail = next_neighbors[:, :, :, 0]\n",
    "        mask = (link != 0).float()\n",
    "\n",
    "        t_embed = self.entity_embedding(tail)\n",
    "        r_embed = self.relation_embedding(link)\n",
    "\n",
    "        lstm_state = torch.cat([prev_state[0], prev_state[1]], -1)\n",
    "        query_embedding = self.relation_embedding(query_relation.long())\n",
    "        state = self.state_encoder(torch.cat([lstm_state, query_embedding], -1))\n",
    "        neighbor_embedding = torch.cat([t_embed, r_embed], -1)\n",
    "\n",
    "        att = (state.unsqueeze(1).unsqueeze(1)*neighbor_embedding).sum(-1)/np.sqrt(state.shape[-1])\n",
    "        # Padding neighbours get a huge negative logit so softmax ignores them.\n",
    "        att = F.softmax(att - (1 - mask)*1e8, 2)\n",
    "        update_embedding = (neighbor_embedding*att.unsqueeze(-1)).sum(2)\n",
    "        return update_embedding\n",
    "\n",
    "    def step(self, next_relations, next_entities, prev_state, prev_relation, query_embedding, current_entities,\n",
    "             head, link, tail, next_neighbors, prev_gnn_state):\n",
    "        \"\"\"Take one policy step: score the candidate actions and sample one per row.\n",
    "\n",
    "        Args:\n",
    "            next_relations, next_entities: (B, max_num_actions) candidate relation /\n",
    "                entity ids; relation id self.rPAD marks padded slots.\n",
    "            prev_state: stacked LSTM state, (1, 2, B, m * hidden_size).\n",
    "            prev_relation: relation ids of the previous action.\n",
    "            query_embedding: despite the name, the query relation ids.\n",
    "            current_entities: ids of the entities each row currently sits at.\n",
    "            head, link, tail: graph edge lists for the message-passing layers.\n",
    "            next_neighbors: neighbour table for the candidates (see neighbour_aggregation).\n",
    "            prev_gnn_state: forwarded to self.policy_step.\n",
    "\n",
    "        Returns:\n",
    "            loss: per-row negative log-probability of the sampled action.\n",
    "            new_state: next stacked LSTM state from self.policy_step.\n",
    "            log_probs: (B, max_num_actions) log-softmax of the masked scores.\n",
    "            label_action: sampled index along the candidate axis.\n",
    "            chosen_relation: relation id of the sampled action.\n",
    "        \"\"\"\n",
    "        n_ent = len(self.entity_embedding.weight)\n",
    "        n_rel = len(self.relation_embedding.weight)\n",
    "\n",
    "        # Look the tables up through the Embedding modules (not .weight) so the GNN\n",
    "        # layers get non-leaf tensors. Stored on self because action_encoder reads them.\n",
    "        self.ent_embed = self.entity_embedding(torch.arange(n_ent).to(head.device))\n",
    "        self.rel_embed = self.relation_embedding(torch.arange(n_rel).to(link.device))\n",
    "\n",
    "        for node_layer, rel_layer in zip(self.node_conv, self.rel_conv):\n",
    "            self.ent_embed = node_layer(head, link, tail, self.ent_embed, self.rel_embed)\n",
    "            self.rel_embed = rel_layer(head, link, tail, self.ent_embed, self.rel_embed)\n",
    "\n",
    "        prev_action_embedding = self.action_encoder(prev_relation, current_entities)  # (B, m * embedding_size)\n",
    "\n",
    "        # (1, 2, B, H) -> [h, c], each (B, H)\n",
    "        prev_state = [s.squeeze(0) for s in torch.unbind(prev_state, dim=1)]\n",
    "\n",
    "        output, new_state = self.policy_step(prev_gnn_state, prev_action_embedding, prev_state)\n",
    "\n",
    "        prev_entity = self.ent_embed[current_entities]\n",
    "        if self.use_entity_embeddings:\n",
    "            state = torch.cat([output, prev_entity], dim=-1)\n",
    "        else:\n",
    "            state = output\n",
    "\n",
    "        candidate_action_embeddings = self.action_encoder(next_relations, next_entities)\n",
    "        gnn_embedding = self.neighbour_aggregation(query_embedding, prev_state, next_neighbors)\n",
    "        candidate_action_embeddings = torch.cat([candidate_action_embeddings, gnn_embedding], -1)\n",
    "\n",
    "        query_embedding = self.rel_embed[query_embedding]\n",
    "        state_query_concat = torch.cat([state, query_embedding], dim=-1)\n",
    "\n",
    "        # MLP for policy\n",
    "        output = self.policy_mlp(state_query_concat)\n",
    "        output_expanded = torch.unsqueeze(output, dim=1)  # (B, 1, D)\n",
    "        prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)  # (B, max_num_actions)\n",
    "\n",
    "        # Masking PAD actions: padded slots get a very low logit so they are\n",
    "        # effectively never sampled.\n",
    "        mask = next_relations == self.rPAD\n",
    "        scores = prelim_scores.masked_fill(mask, -99999.0)\n",
    "\n",
    "        # Sample one action per row.\n",
    "        action = torch.distributions.categorical.Categorical(logits=scores)\n",
    "        label_action = action.sample()  # (B,)\n",
    "\n",
    "        # Per-row negative log-likelihood of the sampled action\n",
    "        # (reduction='none' replaces the deprecated reduce=False).\n",
    "        loss = F.cross_entropy(scores, label_action, reduction='none')\n",
    "\n",
    "        # Map the sampled index back to its relation id.\n",
    "        chosen_relation = next_relations[torch.arange(len(label_action)), label_action]\n",
    "\n",
    "        return loss, new_state, F.log_softmax(scores, dim=-1), label_action, chosen_relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a5a4c2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import time\n",
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "\n",
    "#from model.agent import Agent\n",
    "from model.options import read_options\n",
    "from model.environment import env\n",
    "import codecs\n",
    "from collections import defaultdict\n",
    "import gc\n",
    "import resource\n",
    "import sys\n",
    "from model.baseline import ReactiveBaseline\n",
    "from scipy.special import logsumexp as lse\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from model.nell_eval import nell_eval\n",
    "\n",
    "# Route DEBUG-and-above records reaching the root logger to stdout (this includes\n",
    "# third-party library records that propagate to root). basicConfig does nothing\n",
    "# if the root logger already has handlers, so re-running this cell adds none.\n",
    "logger = logging.getLogger()\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)\n",
    "\n",
    "\n",
    "class Trainer(object):\n",
    "    def __init__(self, params):\n",
    "\n",
    "        # transfer parameters to self\n",
    "        for key, val in params.items(): setattr(self, key, val);\n",
    "        self.device = params['device']\n",
    "        print(self.device)\n",
    "        self.agent = Agent(params).to(self.device)\n",
    "        #self.c_agent = ClusterAgent(params).to(self.device)\n",
    "        self.model_dir = params['model_dir']\n",
    "        self.save_path = self.model_dir + \"model\" + '.ckpt'\n",
    "        self.train_environment = env(params, 'train')\n",
    "        self.dev_test_environment = env(params, 'dev')\n",
    "        self.test_test_environment = env(params, 'test')\n",
    "        self.test_environment = self.dev_test_environment\n",
    "        self.rev_relation_vocab = self.train_environment.grapher.rev_relation_vocab\n",
    "        self.rev_entity_vocab = self.train_environment.grapher.rev_entity_vocab\n",
    "        #self.rev_cluster_relation_vocab = self.train_environment.cluster_grapher.rev_cluster_relation_vocab\n",
    "        #self.rev_cluster_vocab = self.train_environment.cluster_grapher.rev_cluster_vocab\n",
    "\n",
    "        self.max_hits_at_10 = 0\n",
    "        self.ePAD = self.entity_vocab['PAD']\n",
    "        self.rPAD = self.relation_vocab['PAD']\n",
    "        self.decaying_beta_init = self.beta\n",
    "        # optimize\n",
    "        self.baseline = ReactiveBaseline(params, self.Lambda)\n",
    "\n",
    "        self.decay_batch = None\n",
    "        self.gamma = params['gamma']\n",
    "        self.grad_clip_norm = params['grad_clip_norm']\n",
    "        self.eval_every = params['eval_every']\n",
    "        self.total_iterations = params['total_iterations']\n",
    "        self.learning_rate = params['learning_rate']\n",
    "        self.pool = params['pool']\n",
    "        self.output_dir = params['output_dir']\n",
    "        self.params = params\n",
    "        \n",
    "        self.positive_reward_rates = []\n",
    "        self.optimizer = optim.Adam(list(self.agent.parameters()),\n",
    "                                    lr=self.learning_rate)\n",
    "        self.two_embeds_sim_criterion = torch.nn.KLDivLoss()\n",
    "\n",
    "    def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_reward, decaying_beta, baseline):\n",
    "\n",
    "        loss = torch.stack(all_loss, dim=1)  # [original batch_size * num_rollout, T]\n",
    "        base_value = baseline.get_baseline_value()\n",
    "\n",
    "        # multiply with rewards\n",
    "        final_reward = cum_discounted_reward - base_value\n",
    "        reward_mean = torch.mean(final_reward)\n",
    "\n",
    "        # Constant added for numerical stability\n",
    "        reward_std = torch.std(final_reward) + 1e-6\n",
    "        final_reward = torch.div(final_reward - reward_mean, reward_std)\n",
    "\n",
    "        loss = torch.mul(loss, final_reward)  # [original batch_size * num_rollout, T]\n",
    "\n",
    "        entropy_loss = decaying_beta * self.entropy_reg_loss(all_logits)\n",
    "\n",
    "        total_loss = torch.mean(loss) - entropy_loss  # scalar\n",
    "\n",
    "        return total_loss\n",
    "    \n",
    "    def entropy_reg_loss(self, all_logits):  # control diversity\n",
    "        all_logits = torch.stack(all_logits, dim=2)  # [original batch_size * num_rollout, max_num_actions, T]\n",
    "        entropy_loss = - torch.mean(torch.sum(torch.mul(torch.exp(all_logits), all_logits), dim=1))  # scalar\n",
    "        return entropy_loss\n",
    "\n",
    "    def calc_cum_discounted_reward(self, rewards):\n",
    "\n",
    "        running_add = torch.zeros([rewards.size(0)]).to(self.device)  # [original batch_size * num_rollout]\n",
    "        cum_disc_reward = torch.zeros([rewards.size(0), self.path_length]).to(\n",
    "            self.device)  # [original batch_size * num_rollout, T]\n",
    "        cum_disc_reward[:,\n",
    "        self.path_length - 1] = rewards  # set the last time step to the reward received at the last state\n",
    "        for t in reversed(range(self.path_length)):\n",
    "            running_add = self.gamma * running_add + cum_disc_reward[:, t]\n",
    "            cum_disc_reward[:, t] = running_add\n",
    "        return cum_disc_reward\n",
    "\n",
    "    def calc_cum_discounted_reward_credit(self, entity_rewards):\n",
    "\n",
    "        num_instances = entity_rewards.size(0)\n",
    "        running_add = torch.zeros([num_instances]).to(self.device)  # [original batch_size * num_rollout]\n",
    "        cum_disc_reward = torch.zeros([num_instances, self.path_length]).to(\n",
    "            self.device)  # [original batch_size * num_rollout, T]\n",
    "        cum_disc_reward[:,\n",
    "        self.path_length - 1] = entity_rewards  # set the last time step to the reward received at the last state\n",
    "\n",
    "        for t in reversed(range(1, self.path_length)):\n",
    "            running_add = self.gamma * running_add + cum_disc_reward[:, t] # approx_credits[t].to(self.device) * cluster_rewards\n",
    "            cum_disc_reward[:, t-1] = running_add\n",
    "\n",
    "        return cum_disc_reward\n",
    "\n",
    "    def get_graph(self, entity_episode, mode = 'train'):\n",
    "        if mode == 'train':\n",
    "            graph = deepcopy(self.train_environment.grapher)\n",
    "        else:\n",
    "            graph = deepcopy(self.test_environment.grapher)\n",
    "        \n",
    "        if mode == 'train':\n",
    "            ent_match = graph.array_store[entity_episode.start_entities,:,0] == np.expand_dims(entity_episode.end_entities, -1)\n",
    "            rel_match = graph.array_store[entity_episode.start_entities,:,1] == np.expand_dims(entity_episode.query_relation, -1)\n",
    "\n",
    "            tmp = graph.array_store[entity_episode.start_entities, :, :]\n",
    "            tmp[rel_match&ent_match] = 0\n",
    "            graph.array_store[entity_episode.start_entities, :, :] = tmp\n",
    "\n",
    "        head = torch.repeat_interleave(torch.arange(len(graph.rev_entity_vocab)), 40)\n",
    "        tail = torch.LongTensor(graph.array_store[:, :, 0].reshape(-1))\n",
    "        link = torch.LongTensor(graph.array_store[:, :, 1].reshape(-1))\n",
    "\n",
    "        head = head[(link > 0)].to(self.device)\n",
    "        tail = tail[(link > 0)].to(self.device)\n",
    "        link = link[(link > 0)].to(self.device)\n",
    "        return head, link, tail\n",
    "    \n",
    "    def train(self):\n",
    "        train_loss = []\n",
    "        train_reward = []\n",
    "\n",
    "        start_time = time.time()\n",
    "        self.batch_counter = 0\n",
    "        current_decay = self.decaying_beta_init\n",
    "        current_decay_count = 0\n",
    "\n",
    "        print('Agent start learning ...')\n",
    "        for entity_episode in self.train_environment.get_episodes():\n",
    "\n",
    "            self.batch_counter += 1\n",
    "\n",
    "            current_decay_count += 1\n",
    "            if current_decay_count == self.decay_batch:\n",
    "                current_decay *= self.decay_rate\n",
    "                current_decay_count = 0\n",
    "\n",
    "            # get initial state for entity agent\n",
    "            head, link, tail = self.get_graph(entity_episode, mode = 'train')\n",
    "            entity_state_emb = torch.zeros(1, 2, self.batch_size * self.num_rollouts,\n",
    "                                           self.agent.m * self.hidden_size).to(self.device)\n",
    "            entity_state = entity_episode.get_state()\n",
    "            next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(\n",
    "                self.device)  # original batch_size * num_rollout, max_num_actions\n",
    "            next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "\n",
    "            # range_arr = torch.arange(self.batch_size * self.num_rollouts).to(self.device)\n",
    "            prev_relation = self.agent.dummy_start_label.to(self.device)  # original batch_size * num_rollout, 1-D, (1...)\n",
    "\n",
    "            query_relation = entity_episode.get_query_relation()\n",
    "            query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "            current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "            prev_gnn_state = self.agent.initial_gnn_state.repeat(len(current_entities), 1)\n",
    "            #prev_entities = current_entities.clone()\n",
    "\n",
    "            all_losses = []\n",
    "            all_logits = []\n",
    "            all_action_id = []\n",
    "            path = [current_entities]\n",
    "\n",
    "            for i in range(self.path_length):\n",
    "                next_neighbors = deepcopy(self.train_environment.grapher.array_store)\n",
    "                in_sample = next_neighbors[entity_episode.start_entities]\n",
    "                end_match = in_sample[:, :, 0] == np.expand_dims(entity_episode.end_entities, -1)\n",
    "                rel_match = in_sample[:, :, 1] == np.expand_dims(entity_episode.query_relation, -1)\n",
    "                in_sample[end_match&rel_match] = 0\n",
    "                next_neighbors[entity_episode.start_entities] = in_sample\n",
    "\n",
    "                next_neighbors = next_neighbors[next_possible_entities.cpu().numpy()]\n",
    "                next_neighbors = torch.LongTensor(next_neighbors).to(current_entities.device)\n",
    "\n",
    "                loss, entity_state_emb, logits, idx, chosen_relation = self.agent.step(\n",
    "                    next_possible_relations,\n",
    "                    next_possible_entities, entity_state_emb,\n",
    "                    prev_relation, query_relation,\n",
    "                    current_entities,\n",
    "                    head, link, tail,\n",
    "                    next_neighbors, prev_gnn_state\n",
    "                )\n",
    "\n",
    "                entity_state = entity_episode(idx.cpu())\n",
    "                next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "                prev_relation = chosen_relation.to(self.device)\n",
    "\n",
    "                all_losses.append(loss)\n",
    "                all_logits.append(logits)\n",
    "                all_action_id.append(idx)\n",
    "                path.append(current_entities)\n",
    "\n",
    "            rewards = entity_episode.get_reward()\n",
    "            rewards = torch.tensor(rewards).to(self.device)\n",
    "\n",
    "            cum_discounted_reward = self.calc_cum_discounted_reward(rewards)\n",
    "            reinforce_loss = self.calc_reinforce_loss(all_losses, all_logits, cum_discounted_reward,\n",
    "                                                        current_decay, self.baseline)\n",
    "\n",
    "            self.baseline.update(torch.mean(cum_discounted_reward))\n",
    "\n",
    "            self.optimizer.zero_grad()\n",
    "            reinforce_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=self.grad_clip_norm, norm_type=2)\n",
    "            self.optimizer.step()\n",
    "\n",
    "            train_loss.append(reinforce_loss.detach().cpu().item())\n",
    "            train_reward.append(rewards.cpu().float().tolist())\n",
    "\n",
    "            if (self.batch_counter > 0)&(self.batch_counter % (self.eval_every//10) == 0):\n",
    "                avg_loss = np.mean(train_loss[-(self.eval_every//10):])\n",
    "                avg_reward = np.mean(sum(train_reward[-(self.eval_every//10):], []))\n",
    "                print('Iteration: {}, Train loss: {:.4f}, rewards: {:.4f}'.format(self.batch_counter, avg_loss, avg_reward))\n",
    "                gc.collect()\n",
    "\n",
    "            if (self.batch_counter > 0)&(self.batch_counter % self.eval_every == 0):\n",
    "                print('Eval:')\n",
    "                self.test(beam = True)\n",
    "                gc.collect()\n",
    "                print('------------------------------------------------------------')\n",
    "\n",
    "            if self.batch_counter > self.total_iterations:\n",
    "                break\n",
    "\n",
    "    def test(self, beam=False, print_paths=False, save_model=True):\n",
    "\n",
    "        with torch.no_grad():\n",
    "\n",
    "            batch_counter = 0\n",
    "            paths = defaultdict(list)\n",
    "            answers = []\n",
    "            all_final_reward_1 = 0\n",
    "            all_final_reward_3 = 0\n",
    "            all_final_reward_5 = 0\n",
    "            all_final_reward_10 = 0\n",
    "            all_final_reward_20 = 0\n",
    "            auc = 0\n",
    "\n",
    "            total_examples = self.test_environment.total_no_examples\n",
    "\n",
    "            for entity_episode in self.test_environment.get_episodes():\n",
    "                batch_counter += 1\n",
    "\n",
    "                temp_batch_size = entity_episode.no_examples\n",
    "\n",
    "                self.qr = entity_episode.get_query_relation()\n",
    "                query_relation = self.qr\n",
    "                query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "                # set initial beam probs\n",
    "                beam_probs = torch.zeros((temp_batch_size * self.test_rollouts, 1)).to(self.device)\n",
    "\n",
    "                # get initial state for entity agent\n",
    "                entity_state = entity_episode.get_state()\n",
    "                \n",
    "                head, link, tail = self.get_graph(entity_episode, mode = 'test')\n",
    "                next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "\n",
    "                entity_state_emb = torch.zeros(1, 2, temp_batch_size * self.test_rollouts,\n",
    "                                               self.agent.m * self.hidden_size).to(self.device)\n",
    "                prev_relation = (torch.ones(temp_batch_size * self.test_rollouts) * self.relation_vocab[\n",
    "                    'DUMMY_START_RELATION']).long().to(self.device)\n",
    "                prev_gnn_state = self.agent.initial_gnn_state.repeat(len(current_entities), 1)\n",
    "\n",
    "                if print_paths:\n",
    "                    self.entity_trajectory = [current_entities]\n",
    "                    self.relation_trajectory = [prev_relation]   \n",
    "\n",
    "                self.log_probs = np.zeros((temp_batch_size * self.test_rollouts,)) * 1.0\n",
    "                for i in range(self.path_length):\n",
    "                    \n",
    "                    next_neighbors = self.test_environment.grapher.array_store[next_entities.cpu().numpy()]\n",
    "                    next_neighbors = torch.LongTensor(next_neighbors).to(current_entities.device)\n",
    "                    \n",
    "                    loss, entity_state_emb, test_scores, test_action_idx, chosen_relation = self.agent.step(\n",
    "                        next_relations,\n",
    "                        next_entities, entity_state_emb,\n",
    "                        prev_relation, query_relation,\n",
    "                        current_entities,\n",
    "                        head, link, tail,\n",
    "                        next_neighbors, prev_gnn_state\n",
    "                    )\n",
    "                    \n",
    "                    #Mimic original implementation on pytorch\n",
    "                    if beam:\n",
    "                        k = self.test_rollouts\n",
    "                        beam_probs = beam_probs.to(self.device)\n",
    "                        new_scores = test_scores + beam_probs\n",
    "                        new_scores = new_scores.cpu()\n",
    "                        if i == 0:\n",
    "                            idx = np.argsort(new_scores)\n",
    "                            idx = idx[:, -k:]\n",
    "                            ranged_idx = np.tile([b for b in range(k)], temp_batch_size)\n",
    "                            idx = idx[np.arange(k * temp_batch_size), ranged_idx]\n",
    "                        else:\n",
    "                            idx = self.top_k(new_scores, k)\n",
    "\n",
    "                        y = idx // self.max_num_actions\n",
    "                        x = idx % self.max_num_actions\n",
    "\n",
    "                        y += np.repeat([b * k for b in range(temp_batch_size)], k)\n",
    "                        entity_state['current_entities'] = entity_state['current_entities'][y]\n",
    "                        entity_state['next_relations'] = entity_state['next_relations'][y, :]\n",
    "                        entity_state['next_entities'] = entity_state['next_entities'][y, :]\n",
    "                        entity_state_emb = entity_state_emb[:, :, y, :]\n",
    "\n",
    "                        test_action_idx = x\n",
    "                        chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]\n",
    "\n",
    "                        beam_probs = new_scores[y, x]\n",
    "                        beam_probs = beam_probs.reshape((-1, 1))\n",
    "\n",
    "#                     #My implementation to fit arbitrary dimension\n",
    "#                     if beam:\n",
    "#                         k = self.test_rollouts\n",
    "#                         beam_probs = beam_probs.to(self.device)\n",
    "#                         new_scores = test_scores + beam_probs\n",
    "#                         new_scores = new_scores.cpu()\n",
    "#                         if i == 0:\n",
    "#                             reshape_score = new_scores.reshape(temp_batch_size, self.test_rollouts, -1)\n",
    "#                             possible_idx = []\n",
    "#                             for x in reshape_score:\n",
    "#                                 possible_idx.append(torch.LongTensor(np.where(x[0].cpu() > -1000)[0]))\n",
    "#                             idx = []\n",
    "#                             for x in possible_idx:\n",
    "#                                 idx.append(torch.cat(([x]*(self.test_rollouts//len(x) + 1)))[:self.test_rollouts])\n",
    "#                             idx = torch.cat(idx, 0)\n",
    "#                         else:\n",
    "#                             idx = self.top_k(new_scores, k)\n",
    "\n",
    "#                         y = idx // self.max_num_actions\n",
    "#                         x = idx % self.max_num_actions\n",
    "\n",
    "#                         y += np.repeat([b * k for b in range(temp_batch_size)], k)\n",
    "#                         entity_state['current_entities'] = entity_state['current_entities'][y]\n",
    "#                         entity_state['next_relations'] = entity_state['next_relations'][y, :]\n",
    "#                         entity_state['next_entities'] = entity_state['next_entities'][y, :]\n",
    "#                         entity_state_emb = entity_state_emb[:, :, y, :]\n",
    "\n",
    "#                         test_action_idx = x\n",
    "#                         chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]\n",
    "#                         beam_probs = new_scores[y, x]\n",
    "#                         beam_probs = beam_probs.reshape((-1, 1))\n",
    "                        \n",
    "                        if print_paths:\n",
    "                            for j in range(i):\n",
    "                                self.entity_trajectory[j] = self.entity_trajectory[j][y]\n",
    "                                self.relation_trajectory[j] = self.relation_trajectory[j][y]\n",
    "\n",
    "                    entity_state = entity_episode(test_action_idx.cpu().numpy())\n",
    "                    next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                    next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                    current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "                    prev_relation = torch.tensor(chosen_relation).long().to(self.device)\n",
    "                    \n",
    "                    if print_paths:\n",
    "                        self.entity_trajectory.append(entity_state['current_entities'])\n",
    "                        self.relation_trajectory.append(chosen_relation)\n",
    "\n",
    "                    test_scores = test_scores.cpu().numpy()\n",
    "                    self.log_probs += test_scores[np.arange(self.log_probs.shape[0]), test_action_idx.cpu().numpy()]\n",
    "\n",
    "                if beam:\n",
    "                    self.log_probs = beam_probs\n",
    "\n",
    "                rewards = entity_episode.get_reward()  # [B*test_rollouts]\n",
    "                reward_reshape = np.reshape(rewards, (temp_batch_size, self.test_rollouts))  # [orig_batch, test_rollouts]\n",
    "                self.log_probs = np.reshape(self.log_probs, (temp_batch_size, self.test_rollouts))\n",
    "                sorted_indx = np.argsort(-self.log_probs)\n",
    "                final_reward_1 = 0\n",
    "                final_reward_3 = 0\n",
    "                final_reward_5 = 0\n",
    "                final_reward_10 = 0\n",
    "                final_reward_20 = 0\n",
    "                AP = 0\n",
    "                ce = entity_episode.state['current_entities'].reshape((temp_batch_size, self.test_rollouts))\n",
    "                se = entity_episode.start_entities.reshape((temp_batch_size, self.test_rollouts))\n",
    "                for b in range(temp_batch_size):\n",
    "                    answer_pos = None\n",
    "                    seen = set()\n",
    "                    pos=0\n",
    "                    if self.pool == 'max':\n",
    "                        for r in sorted_indx[b]:\n",
    "                            if reward_reshape[b,r] == self.positive_reward:\n",
    "                                answer_pos = pos\n",
    "                                break\n",
    "                            if ce[b, r] not in seen:\n",
    "                                seen.add(ce[b, r])\n",
    "                                pos += 1\n",
    "                    if self.pool == 'sum':\n",
    "                        scores = defaultdict(list)\n",
    "                        answer = ''\n",
    "                        for r in sorted_indx[b]:\n",
    "                            scores[ce[b,r]].append(self.log_probs[b,r])\n",
    "                            if reward_reshape[b,r] == self.positive_reward:\n",
    "                                answer = ce[b,r]\n",
    "                        final_scores = defaultdict(float)\n",
    "                        for e in scores:\n",
    "                            final_scores[e] = lse(scores[e])\n",
    "                        sorted_answers = sorted(final_scores, key=final_scores.get, reverse=True)\n",
    "                        if answer in  sorted_answers:\n",
    "                            answer_pos = sorted_answers.index(answer)\n",
    "                        else:\n",
    "                            answer_pos = None\n",
    "\n",
    "\n",
    "                    if answer_pos != None:\n",
    "                        if answer_pos < 20:\n",
    "                            final_reward_20 += 1\n",
    "                            if answer_pos < 10:\n",
    "                                final_reward_10 += 1\n",
    "                                if answer_pos < 5:\n",
    "                                    final_reward_5 += 1\n",
    "                                    if answer_pos < 3:\n",
    "                                        final_reward_3 += 1\n",
    "                                        if answer_pos < 1:\n",
    "                                            final_reward_1 += 1\n",
    "                    if answer_pos == None:\n",
    "                        AP += 0\n",
    "                    else:\n",
    "                        AP += 1.0/((answer_pos+1))\n",
    "                    \n",
    "                    if print_paths:\n",
    "                        qr = self.train_environment.grapher.rev_relation_vocab[self.qr[b * self.test_rollouts]]\n",
    "                        start_e = self.rev_entity_vocab[entity_episode.start_entities[b * self.test_rollouts]]\n",
    "                        end_e = self.rev_entity_vocab[entity_episode.end_entities[b * self.test_rollouts]]\n",
    "                        paths[str(qr)].append(str(start_e) + \"\\t\" + str(end_e) + \"\\n\")\n",
    "                        paths[str(qr)].append(\"Reward:\" + str(1 if answer_pos != None and answer_pos < 10 else 0) + \"\\n\")\n",
    "                        for r in sorted_indx[b]:\n",
    "                            indx = b * self.test_rollouts + r\n",
    "                            if rewards[indx] == self.positive_reward:\n",
    "                                rev = 1\n",
    "                            else:\n",
    "                                rev = -1\n",
    "                            answers.append(self.rev_entity_vocab[se[b,r]]+'\\t'+ self.rev_entity_vocab[ce[b,r]]+'\\t'+ str(self.log_probs[b,r])+'\\n')\n",
    "                            paths[str(qr)].append(\n",
    "                                '\\t'.join([str(self.rev_entity_vocab[e[indx]]) for e in\n",
    "                                           self.entity_trajectory]) + '\\n' + '\\t'.join(\n",
    "                                    [str(self.rev_relation_vocab[re[indx]]) for re in self.relation_trajectory]) + '\\n' + str(\n",
    "                                    rev) + '\\n' + str(\n",
    "                                    self.log_probs[b, r]) + '\\n___' + '\\n')\n",
    "                        paths[str(qr)].append(\"#####################\\n\")\n",
    "\n",
    "                all_final_reward_1 += final_reward_1\n",
    "                all_final_reward_3 += final_reward_3\n",
    "                all_final_reward_5 += final_reward_5\n",
    "                all_final_reward_10 += final_reward_10\n",
    "                all_final_reward_20 += final_reward_20\n",
    "                auc += AP\n",
    "\n",
    "            all_final_reward_1 /= total_examples\n",
    "            all_final_reward_3 /= total_examples\n",
    "            all_final_reward_5 /= total_examples\n",
    "            all_final_reward_10 /= total_examples\n",
    "            all_final_reward_20 /= total_examples\n",
    "            auc /= total_examples\n",
    "            \n",
    "            if save_model:\n",
    "                if all_final_reward_10 >= self.max_hits_at_10:\n",
    "                    self.max_hits_at_10 = all_final_reward_10\n",
    "                    torch.save(self.agent.state_dict(), self.model_dir + \"agent\" + '.ckpt')\n",
    "                    # self.save_path = self.model_dir + \"model\" + '.ckpt'\n",
    "\n",
    "            if print_paths:\n",
    "                logger.info(\"[ printing paths at {} ]\".format(self.output_dir + '/test_beam/'))\n",
    "                for q in paths:\n",
    "                    j = q.replace('/', '-')\n",
    "                    with codecs.open(self.path_logger_file_ + '_' + j, 'a', 'utf-8') as pos_file:\n",
    "                        for p in paths[q]:\n",
    "                            pos_file.write(p)\n",
    "                with open(self.path_logger_file_ + 'answers', 'w') as answer_file:\n",
    "                    for a in answers:\n",
    "                        answer_file.write(a)\n",
    "\n",
    "            with open(self.output_dir + '/scores.txt', 'a') as score_file:\n",
    "                score_file.write(\"Hits@1: {:.4f}\".format(all_final_reward_1))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@3: {:.4f}\".format(all_final_reward_3))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@5: {:.4f}\".format(all_final_reward_5))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@10: {:.4f}\".format(all_final_reward_10))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@20: {:.4f}\".format(all_final_reward_20))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"MRR: {:.4f}\".format(auc))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"------------------------------------\")\n",
    "\n",
    "            print(\"Hits@1: {:.4f}, Hits@3: {:.4f}, Hits@10: {:.4f}, MRR: {:.4f}\".format(all_final_reward_1, \n",
    "                                                                                 all_final_reward_3,\n",
    "                                                                                 all_final_reward_10, auc))\n",
    "            \n",
    "    def top_k(self, scores, k):\n",
    "        scores = scores.reshape(-1, k * self.max_num_actions)  # [B, (k*max_num_actions)]\n",
    "        idx = np.argsort(scores, axis=1)\n",
    "        idx = idx[:, -k:]  # take the last k highest indices # [B , k]\n",
    "        return idx.reshape((-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8bca5f0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "options = {}\n",
    "\n",
    "#basic setting\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "options['relation_vocab'] = json.load(open(options['vocab_dir'] + '/relation_vocab.json'))\n",
    "options['entity_vocab'] = json.load(open(options['vocab_dir'] + '/entity_vocab.json'))\n",
    "options['model_dir'] = './outputs_WN18RR_v7/'\n",
    "options['output_dir'] = './outputs_WN18RR_v7/'\n",
    "\n",
    "#agent setting\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 1\n",
    "options['train_entity_embeddings'] = 0\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 40\n",
    "\n",
    "#hyperparameters\n",
    "options['test_rollouts'] = 40\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 64\n",
    "options['beta'] = 0.05\n",
    "options['Lambda'] = 0.05\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 5e-4\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "options['total_iterations'] = 4000\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8dc1e4ea",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Agent start learning ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/cwong/miniconda3/envs/tracy/lib/python3.9/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/tmp/ipykernel_1595732/1090773659.py:224: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 10, Train loss: -0.3004, rewards: 0.2221\n",
      "Iteration: 20, Train loss: -0.3464, rewards: 0.2994\n",
      "Iteration: 30, Train loss: -0.2261, rewards: 0.3577\n",
      "Iteration: 40, Train loss: -0.3413, rewards: 0.3743\n",
      "Iteration: 50, Train loss: -0.3360, rewards: 0.3600\n",
      "Iteration: 60, Train loss: -0.4042, rewards: 0.3753\n",
      "Iteration: 70, Train loss: -0.4214, rewards: 0.3751\n",
      "Iteration: 80, Train loss: -0.3168, rewards: 0.3656\n",
      "Iteration: 90, Train loss: -0.4103, rewards: 0.3783\n",
      "Iteration: 100, Train loss: -0.4269, rewards: 0.3978\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1595732/1123405777.py:327: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.3883, Hits@3: 0.4446, Hits@10: 0.4957, MRR: 0.4238\n",
      "------------------------------------------------------------\n",
      "Iteration: 110, Train loss: -0.3818, rewards: 0.3802\n",
      "Iteration: 120, Train loss: -0.4457, rewards: 0.3943\n",
      "Iteration: 130, Train loss: -0.4553, rewards: 0.3663\n",
      "Iteration: 140, Train loss: -0.4921, rewards: 0.3981\n",
      "Iteration: 150, Train loss: -0.4978, rewards: 0.3864\n",
      "Iteration: 160, Train loss: -0.4860, rewards: 0.3934\n",
      "Iteration: 170, Train loss: -0.4214, rewards: 0.4292\n",
      "Iteration: 180, Train loss: -0.4448, rewards: 0.3976\n",
      "Iteration: 190, Train loss: -0.5133, rewards: 0.4263\n",
      "Iteration: 200, Train loss: -0.4278, rewards: 0.4339\n",
      "Eval:\n",
      "Hits@1: 0.3665, Hits@3: 0.4436, Hits@10: 0.5020, MRR: 0.4134\n",
      "------------------------------------------------------------\n",
      "Iteration: 210, Train loss: -0.4128, rewards: 0.3700\n",
      "Iteration: 220, Train loss: -0.5215, rewards: 0.4150\n",
      "Iteration: 230, Train loss: -0.4948, rewards: 0.3946\n",
      "Iteration: 240, Train loss: -0.4902, rewards: 0.4209\n",
      "Iteration: 250, Train loss: -0.5110, rewards: 0.4132\n",
      "Iteration: 260, Train loss: -0.4503, rewards: 0.3890\n",
      "Iteration: 270, Train loss: -0.5127, rewards: 0.4045\n",
      "Iteration: 300, Train loss: -0.4119, rewards: 0.4047\n",
      "Eval:\n",
      "Hits@1: 0.3998, Hits@3: 0.4496, Hits@10: 0.5092, MRR: 0.4331\n",
      "------------------------------------------------------------\n",
      "Iteration: 310, Train loss: -0.4279, rewards: 0.3576\n",
      "Iteration: 320, Train loss: -0.5005, rewards: 0.4084\n",
      "Iteration: 330, Train loss: -0.4985, rewards: 0.4149\n",
      "Iteration: 340, Train loss: -0.4261, rewards: 0.4031\n",
      "Iteration: 350, Train loss: -0.4661, rewards: 0.4383\n",
      "Iteration: 370, Train loss: -0.5187, rewards: 0.4054\n",
      "Iteration: 380, Train loss: -0.4893, rewards: 0.3905\n",
      "Iteration: 390, Train loss: -0.4964, rewards: 0.3921\n",
      "Iteration: 400, Train loss: -0.4863, rewards: 0.3929\n",
      "Eval:\n",
      "Hits@1: 0.3958, Hits@3: 0.4483, Hits@10: 0.5043, MRR: 0.4294\n",
      "------------------------------------------------------------\n",
      "Iteration: 410, Train loss: -0.4992, rewards: 0.4044\n",
      "Iteration: 420, Train loss: -0.4663, rewards: 0.3435\n",
      "Iteration: 430, Train loss: -0.5474, rewards: 0.4472\n",
      "Iteration: 440, Train loss: -0.4777, rewards: 0.3909\n",
      "Iteration: 450, Train loss: -0.4485, rewards: 0.4090\n",
      "Iteration: 460, Train loss: -0.4595, rewards: 0.4169\n",
      "Iteration: 470, Train loss: -0.5069, rewards: 0.3995\n",
      "Iteration: 480, Train loss: -0.4809, rewards: 0.4052\n",
      "Iteration: 490, Train loss: -0.3811, rewards: 0.3705\n",
      "Iteration: 500, Train loss: -0.3876, rewards: 0.4113\n",
      "Eval:\n",
      "Hits@1: 0.3629, Hits@3: 0.4489, Hits@10: 0.4980, MRR: 0.4119\n",
      "------------------------------------------------------------\n",
      "Iteration: 510, Train loss: -0.3428, rewards: 0.3845\n",
      "Iteration: 520, Train loss: -0.3852, rewards: 0.4043\n",
      "Iteration: 530, Train loss: -0.4498, rewards: 0.4128\n",
      "Iteration: 540, Train loss: -0.3956, rewards: 0.3726\n",
      "Iteration: 550, Train loss: -0.3405, rewards: 0.4035\n",
      "Iteration: 560, Train loss: -0.3314, rewards: 0.3798\n",
      "Iteration: 570, Train loss: -0.3460, rewards: 0.4084\n",
      "Iteration: 580, Train loss: -0.3879, rewards: 0.3811\n",
      "Iteration: 590, Train loss: -0.3160, rewards: 0.4133\n",
      "Iteration: 600, Train loss: -0.3662, rewards: 0.4053\n",
      "Eval:\n",
      "Hits@1: 0.3998, Hits@3: 0.4473, Hits@10: 0.4941, MRR: 0.4305\n",
      "------------------------------------------------------------\n",
      "Iteration: 610, Train loss: -0.3479, rewards: 0.4031\n",
      "Iteration: 620, Train loss: -0.3992, rewards: 0.3880\n",
      "Iteration: 630, Train loss: -0.4881, rewards: 0.4255\n",
      "Iteration: 640, Train loss: -0.3852, rewards: 0.4255\n",
      "Iteration: 650, Train loss: -0.4180, rewards: 0.3742\n",
      "Iteration: 660, Train loss: -0.3819, rewards: 0.3884\n",
      "Iteration: 670, Train loss: -0.3682, rewards: 0.4370\n",
      "Iteration: 680, Train loss: -0.4242, rewards: 0.4122\n",
      "Iteration: 690, Train loss: -0.4497, rewards: 0.3780\n",
      "Iteration: 700, Train loss: -0.4623, rewards: 0.3946\n",
      "Eval:\n",
      "Hits@1: 0.4051, Hits@3: 0.4450, Hits@10: 0.4924, MRR: 0.4318\n",
      "------------------------------------------------------------\n",
      "Iteration: 710, Train loss: -0.3534, rewards: 0.3983\n",
      "Iteration: 720, Train loss: -0.4039, rewards: 0.4024\n",
      "Iteration: 730, Train loss: -0.4451, rewards: 0.4350\n",
      "Iteration: 740, Train loss: -0.4785, rewards: 0.4158\n",
      "Iteration: 750, Train loss: -0.3986, rewards: 0.3991\n",
      "Iteration: 760, Train loss: -0.4055, rewards: 0.3891\n",
      "Iteration: 770, Train loss: -0.4020, rewards: 0.4507\n",
      "Iteration: 780, Train loss: -0.4015, rewards: 0.4144\n",
      "Iteration: 790, Train loss: -0.4413, rewards: 0.3844\n",
      "Iteration: 800, Train loss: -0.4612, rewards: 0.4050\n",
      "Eval:\n",
      "Hits@1: 0.4130, Hits@3: 0.4604, Hits@10: 0.5096, MRR: 0.4432\n",
      "------------------------------------------------------------\n",
      "Iteration: 810, Train loss: -0.4917, rewards: 0.3995\n",
      "Iteration: 820, Train loss: -0.4220, rewards: 0.4083\n",
      "Iteration: 830, Train loss: -0.4550, rewards: 0.3954\n",
      "Iteration: 840, Train loss: -0.4611, rewards: 0.4241\n",
      "Iteration: 850, Train loss: -0.3826, rewards: 0.4113\n",
      "Iteration: 860, Train loss: -0.4107, rewards: 0.3997\n",
      "Iteration: 870, Train loss: -0.4271, rewards: 0.3999\n",
      "Iteration: 880, Train loss: -0.4427, rewards: 0.3936\n",
      "Iteration: 890, Train loss: -0.4160, rewards: 0.4013\n",
      "Iteration: 900, Train loss: -0.4251, rewards: 0.4274\n",
      "Eval:\n",
      "Hits@1: 0.4140, Hits@3: 0.4562, Hits@10: 0.4947, MRR: 0.4408\n",
      "------------------------------------------------------------\n",
      "Iteration: 910, Train loss: -0.3710, rewards: 0.4045\n",
      "Iteration: 920, Train loss: -0.4145, rewards: 0.4171\n",
      "Iteration: 930, Train loss: -0.3640, rewards: 0.4251\n",
      "Iteration: 940, Train loss: -0.3500, rewards: 0.3995\n",
      "Iteration: 950, Train loss: -0.4346, rewards: 0.3962\n",
      "Iteration: 960, Train loss: -0.4276, rewards: 0.4066\n",
      "Iteration: 970, Train loss: -0.3689, rewards: 0.4039\n",
      "Iteration: 980, Train loss: -0.3486, rewards: 0.4086\n",
      "Iteration: 990, Train loss: -0.4001, rewards: 0.4208\n",
      "Iteration: 1000, Train loss: -0.4157, rewards: 0.4039\n",
      "Eval:\n",
      "Hits@1: 0.4140, Hits@3: 0.4595, Hits@10: 0.5040, MRR: 0.4429\n",
      "------------------------------------------------------------\n",
      "Iteration: 1010, Train loss: -0.3893, rewards: 0.4043\n",
      "Iteration: 1020, Train loss: -0.3953, rewards: 0.4013\n",
      "Iteration: 1030, Train loss: -0.3934, rewards: 0.4173\n",
      "Iteration: 1040, Train loss: -0.3869, rewards: 0.4030\n",
      "Iteration: 1050, Train loss: -0.3974, rewards: 0.4230\n",
      "Iteration: 1060, Train loss: -0.4107, rewards: 0.4159\n",
      "Iteration: 1070, Train loss: -0.4475, rewards: 0.3974\n",
      "Iteration: 1080, Train loss: -0.3989, rewards: 0.4112\n",
      "Iteration: 1090, Train loss: -0.4342, rewards: 0.4049\n",
      "Iteration: 1100, Train loss: -0.4320, rewards: 0.4070\n",
      "Eval:\n",
      "Hits@1: 0.4051, Hits@3: 0.4535, Hits@10: 0.4967, MRR: 0.4356\n",
      "------------------------------------------------------------\n",
      "Iteration: 1110, Train loss: -0.4533, rewards: 0.4537\n",
      "Iteration: 1120, Train loss: -0.3395, rewards: 0.4092\n",
      "Iteration: 1130, Train loss: -0.3680, rewards: 0.4061\n",
      "Iteration: 1140, Train loss: -0.3981, rewards: 0.4192\n",
      "Iteration: 1150, Train loss: -0.4006, rewards: 0.3914\n",
      "Iteration: 1160, Train loss: -0.3863, rewards: 0.3985\n",
      "Iteration: 1170, Train loss: -0.3784, rewards: 0.3910\n",
      "Iteration: 1180, Train loss: -0.3869, rewards: 0.3990\n",
      "Iteration: 1190, Train loss: -0.3863, rewards: 0.3576\n",
      "Iteration: 1200, Train loss: -0.4036, rewards: 0.4317\n",
      "Eval:\n",
      "Hits@1: 0.4087, Hits@3: 0.4604, Hits@10: 0.5066, MRR: 0.4410\n",
      "------------------------------------------------------------\n",
      "Iteration: 1210, Train loss: -0.3539, rewards: 0.3912\n",
      "Iteration: 1220, Train loss: -0.4327, rewards: 0.4166\n",
      "Iteration: 1230, Train loss: -0.3971, rewards: 0.4016\n",
      "Iteration: 1240, Train loss: -0.3551, rewards: 0.4048\n",
      "Iteration: 1250, Train loss: -0.4033, rewards: 0.4295\n",
      "Iteration: 1260, Train loss: -0.4074, rewards: 0.4263\n",
      "Iteration: 1270, Train loss: -0.4170, rewards: 0.4117\n",
      "Iteration: 1280, Train loss: -0.3884, rewards: 0.3922\n",
      "Iteration: 1290, Train loss: -0.4214, rewards: 0.4305\n",
      "Iteration: 1300, Train loss: -0.4187, rewards: 0.4151\n",
      "Eval:\n",
      "Hits@1: 0.4113, Hits@3: 0.4535, Hits@10: 0.4944, MRR: 0.4377\n",
      "------------------------------------------------------------\n",
      "Iteration: 1310, Train loss: -0.3895, rewards: 0.4185\n",
      "Iteration: 1320, Train loss: -0.3970, rewards: 0.3952\n",
      "Iteration: 1330, Train loss: -0.3695, rewards: 0.4102\n",
      "Iteration: 1340, Train loss: -0.3777, rewards: 0.4211\n",
      "Iteration: 1350, Train loss: -0.4358, rewards: 0.4260\n",
      "Iteration: 1360, Train loss: -0.4461, rewards: 0.3804\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1370, Train loss: -0.4531, rewards: 0.4199\n",
      "Iteration: 1380, Train loss: -0.3852, rewards: 0.4302\n",
      "Iteration: 1390, Train loss: -0.3093, rewards: 0.3885\n",
      "Iteration: 1400, Train loss: -0.3493, rewards: 0.4216\n",
      "Eval:\n",
      "Hits@1: 0.4183, Hits@3: 0.4657, Hits@10: 0.5049, MRR: 0.4473\n",
      "------------------------------------------------------------\n",
      "Iteration: 1410, Train loss: -0.3880, rewards: 0.4060\n",
      "Iteration: 1420, Train loss: -0.4159, rewards: 0.4343\n",
      "Iteration: 1430, Train loss: -0.4524, rewards: 0.3791\n",
      "Iteration: 1440, Train loss: -0.4111, rewards: 0.4047\n",
      "Iteration: 1450, Train loss: -0.4437, rewards: 0.3753\n",
      "Iteration: 1460, Train loss: -0.4284, rewards: 0.3815\n",
      "Iteration: 1470, Train loss: -0.4392, rewards: 0.3899\n",
      "Iteration: 1480, Train loss: -0.3913, rewards: 0.4091\n",
      "Iteration: 1490, Train loss: -0.4324, rewards: 0.4327\n",
      "Iteration: 1500, Train loss: -0.3955, rewards: 0.4027\n",
      "Eval:\n",
      "Hits@1: 0.4189, Hits@3: 0.4641, Hits@10: 0.5148, MRR: 0.4485\n",
      "------------------------------------------------------------\n",
      "Iteration: 1510, Train loss: -0.4539, rewards: 0.4131\n",
      "Iteration: 1520, Train loss: -0.4110, rewards: 0.4041\n",
      "Iteration: 1530, Train loss: -0.3854, rewards: 0.4496\n",
      "Iteration: 1540, Train loss: -0.3623, rewards: 0.3737\n",
      "Iteration: 1550, Train loss: -0.3655, rewards: 0.4060\n",
      "Iteration: 1560, Train loss: -0.3102, rewards: 0.4119\n",
      "Iteration: 1570, Train loss: -0.3541, rewards: 0.4088\n",
      "Iteration: 1580, Train loss: -0.4071, rewards: 0.4240\n",
      "Iteration: 1590, Train loss: -0.3971, rewards: 0.4240\n",
      "Iteration: 1600, Train loss: -0.4275, rewards: 0.3958\n",
      "Eval:\n",
      "Hits@1: 0.4100, Hits@3: 0.4568, Hits@10: 0.5053, MRR: 0.4403\n",
      "------------------------------------------------------------\n",
      "Iteration: 1610, Train loss: -0.3744, rewards: 0.4391\n",
      "Iteration: 1620, Train loss: -0.4548, rewards: 0.4503\n",
      "Iteration: 1630, Train loss: -0.4317, rewards: 0.3902\n",
      "Iteration: 1640, Train loss: -0.4250, rewards: 0.4257\n",
      "Iteration: 1650, Train loss: -0.4309, rewards: 0.3935\n",
      "Iteration: 1660, Train loss: -0.4292, rewards: 0.3946\n",
      "Iteration: 1670, Train loss: -0.3955, rewards: 0.4012\n",
      "Iteration: 1680, Train loss: -0.3995, rewards: 0.4115\n",
      "Iteration: 1690, Train loss: -0.3995, rewards: 0.4157\n",
      "Iteration: 1700, Train loss: -0.4033, rewards: 0.4516\n",
      "Eval:\n",
      "Hits@1: 0.4196, Hits@3: 0.4581, Hits@10: 0.5013, MRR: 0.4452\n",
      "------------------------------------------------------------\n",
      "Iteration: 1710, Train loss: -0.3683, rewards: 0.4027\n",
      "Iteration: 1720, Train loss: -0.3933, rewards: 0.3823\n",
      "Iteration: 1730, Train loss: -0.4133, rewards: 0.3816\n",
      "Iteration: 1740, Train loss: -0.4024, rewards: 0.4305\n",
      "Iteration: 1750, Train loss: -0.3873, rewards: 0.4052\n",
      "Iteration: 1760, Train loss: -0.3983, rewards: 0.4343\n",
      "Iteration: 1770, Train loss: -0.3373, rewards: 0.4034\n",
      "Iteration: 1780, Train loss: -0.3941, rewards: 0.4027\n",
      "Iteration: 1790, Train loss: -0.4425, rewards: 0.4008\n",
      "Iteration: 1800, Train loss: -0.4004, rewards: 0.3912\n",
      "Eval:\n",
      "Hits@1: 0.4117, Hits@3: 0.4611, Hits@10: 0.5049, MRR: 0.4423\n",
      "------------------------------------------------------------\n",
      "Iteration: 1810, Train loss: -0.4359, rewards: 0.4106\n",
      "Iteration: 1820, Train loss: -0.4169, rewards: 0.4019\n",
      "Iteration: 1830, Train loss: -0.4007, rewards: 0.4104\n",
      "Iteration: 1840, Train loss: -0.4053, rewards: 0.4111\n",
      "Iteration: 1850, Train loss: -0.3812, rewards: 0.4003\n",
      "Iteration: 1860, Train loss: -0.4279, rewards: 0.4184\n",
      "Iteration: 1870, Train loss: -0.3755, rewards: 0.4361\n",
      "Iteration: 1880, Train loss: -0.4152, rewards: 0.4284\n",
      "Iteration: 1890, Train loss: -0.3816, rewards: 0.4266\n",
      "Iteration: 1900, Train loss: -0.4100, rewards: 0.4134\n",
      "Eval:\n",
      "Hits@1: 0.4199, Hits@3: 0.4647, Hits@10: 0.5066, MRR: 0.4480\n",
      "------------------------------------------------------------\n",
      "Iteration: 1910, Train loss: -0.4411, rewards: 0.4368\n",
      "Iteration: 1920, Train loss: -0.4220, rewards: 0.4026\n",
      "Iteration: 1930, Train loss: -0.3598, rewards: 0.4065\n",
      "Iteration: 1940, Train loss: -0.3687, rewards: 0.3995\n",
      "Iteration: 1950, Train loss: -0.4041, rewards: 0.4266\n",
      "Iteration: 1960, Train loss: -0.4114, rewards: 0.4044\n",
      "Iteration: 1970, Train loss: -0.3776, rewards: 0.4066\n",
      "Iteration: 1980, Train loss: -0.3946, rewards: 0.4370\n",
      "Iteration: 1990, Train loss: -0.4066, rewards: 0.4032\n",
      "Iteration: 2000, Train loss: -0.4325, rewards: 0.4384\n",
      "Eval:\n",
      "Hits@1: 0.4117, Hits@3: 0.4572, Hits@10: 0.5000, MRR: 0.4404\n",
      "------------------------------------------------------------\n",
      "Iteration: 2010, Train loss: -0.3873, rewards: 0.4046\n",
      "Iteration: 2020, Train loss: -0.3987, rewards: 0.4037\n",
      "Iteration: 2030, Train loss: -0.4311, rewards: 0.4062\n",
      "Iteration: 2040, Train loss: -0.4128, rewards: 0.3673\n",
      "Iteration: 2050, Train loss: -0.4144, rewards: 0.4188\n",
      "Iteration: 2060, Train loss: -0.3720, rewards: 0.4003\n",
      "Iteration: 2070, Train loss: -0.3800, rewards: 0.4254\n",
      "Iteration: 2080, Train loss: -0.4369, rewards: 0.3818\n",
      "Iteration: 2090, Train loss: -0.4261, rewards: 0.4097\n",
      "Iteration: 2100, Train loss: -0.3966, rewards: 0.4266\n",
      "Eval:\n",
      "Hits@1: 0.4143, Hits@3: 0.4680, Hits@10: 0.5115, MRR: 0.4468\n",
      "------------------------------------------------------------\n",
      "Iteration: 2110, Train loss: -0.4203, rewards: 0.4396\n",
      "Iteration: 2120, Train loss: -0.4040, rewards: 0.4298\n",
      "Iteration: 2130, Train loss: -0.3751, rewards: 0.4238\n",
      "Iteration: 2140, Train loss: -0.3806, rewards: 0.4469\n",
      "Iteration: 2150, Train loss: -0.4159, rewards: 0.4063\n",
      "Iteration: 2160, Train loss: -0.3697, rewards: 0.4252\n",
      "Iteration: 2170, Train loss: -0.3562, rewards: 0.4216\n",
      "Iteration: 2180, Train loss: -0.3489, rewards: 0.4262\n",
      "Iteration: 2190, Train loss: -0.3968, rewards: 0.4184\n",
      "Iteration: 2200, Train loss: -0.4017, rewards: 0.4037\n",
      "Eval:\n",
      "Hits@1: 0.4196, Hits@3: 0.4707, Hits@10: 0.5112, MRR: 0.4501\n",
      "------------------------------------------------------------\n",
      "Iteration: 2210, Train loss: -0.3667, rewards: 0.4207\n",
      "Iteration: 2220, Train loss: -0.3584, rewards: 0.3767\n",
      "Iteration: 2230, Train loss: -0.3350, rewards: 0.4238\n",
      "Iteration: 2240, Train loss: -0.3643, rewards: 0.4245\n",
      "Iteration: 2250, Train loss: -0.4029, rewards: 0.3995\n",
      "Iteration: 2260, Train loss: -0.4062, rewards: 0.4157\n",
      "Iteration: 2270, Train loss: -0.4105, rewards: 0.4256\n",
      "Iteration: 2280, Train loss: -0.3955, rewards: 0.4123\n",
      "Iteration: 2290, Train loss: -0.4393, rewards: 0.3916\n",
      "Iteration: 2300, Train loss: -0.3885, rewards: 0.4320\n",
      "Eval:\n",
      "Hits@1: 0.4216, Hits@3: 0.4664, Hits@10: 0.5096, MRR: 0.4500\n",
      "------------------------------------------------------------\n",
      "Iteration: 2310, Train loss: -0.4325, rewards: 0.4245\n",
      "Iteration: 2320, Train loss: -0.4159, rewards: 0.4430\n",
      "Iteration: 2330, Train loss: -0.3972, rewards: 0.4179\n",
      "Iteration: 2340, Train loss: -0.3938, rewards: 0.4080\n",
      "Iteration: 2350, Train loss: -0.3674, rewards: 0.4341\n",
      "Iteration: 2360, Train loss: -0.4126, rewards: 0.4152\n",
      "Iteration: 2370, Train loss: -0.3950, rewards: 0.4020\n",
      "Iteration: 2380, Train loss: -0.4007, rewards: 0.4351\n",
      "Iteration: 2390, Train loss: -0.3395, rewards: 0.4073\n",
      "Iteration: 2400, Train loss: -0.3633, rewards: 0.3925\n",
      "Eval:\n",
      "Hits@1: 0.4183, Hits@3: 0.4674, Hits@10: 0.5115, MRR: 0.4482\n",
      "------------------------------------------------------------\n",
      "Iteration: 2410, Train loss: -0.3509, rewards: 0.4157\n",
      "Iteration: 2420, Train loss: -0.3362, rewards: 0.4358\n",
      "Iteration: 2430, Train loss: -0.4059, rewards: 0.4139\n",
      "Iteration: 2440, Train loss: -0.4209, rewards: 0.4042\n",
      "Iteration: 2450, Train loss: -0.3986, rewards: 0.4421\n",
      "Iteration: 2460, Train loss: -0.3948, rewards: 0.3850\n",
      "Iteration: 2470, Train loss: -0.4181, rewards: 0.4361\n",
      "Iteration: 2480, Train loss: -0.3701, rewards: 0.4366\n",
      "Iteration: 2490, Train loss: -0.3339, rewards: 0.4280\n",
      "Iteration: 2500, Train loss: -0.3542, rewards: 0.4365\n",
      "Eval:\n",
      "Hits@1: 0.4156, Hits@3: 0.4598, Hits@10: 0.5096, MRR: 0.4455\n",
      "------------------------------------------------------------\n",
      "Iteration: 2510, Train loss: -0.3631, rewards: 0.4026\n",
      "Iteration: 2520, Train loss: -0.4199, rewards: 0.4275\n",
      "Iteration: 2530, Train loss: -0.3891, rewards: 0.4127\n",
      "Iteration: 2540, Train loss: -0.3833, rewards: 0.4040\n",
      "Iteration: 2550, Train loss: -0.4327, rewards: 0.4084\n",
      "Iteration: 2560, Train loss: -0.4201, rewards: 0.4137\n",
      "Iteration: 2570, Train loss: -0.4141, rewards: 0.4230\n",
      "Iteration: 2580, Train loss: -0.3837, rewards: 0.4284\n",
      "Iteration: 2590, Train loss: -0.3823, rewards: 0.4243\n",
      "Iteration: 2600, Train loss: -0.3511, rewards: 0.4023\n",
      "Eval:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.4192, Hits@3: 0.4690, Hits@10: 0.5089, MRR: 0.4501\n",
      "------------------------------------------------------------\n",
      "Iteration: 2610, Train loss: -0.3910, rewards: 0.4323\n",
      "Iteration: 2620, Train loss: -0.4098, rewards: 0.3827\n",
      "Iteration: 2630, Train loss: -0.4285, rewards: 0.4384\n",
      "Iteration: 2640, Train loss: -0.3845, rewards: 0.4534\n",
      "Iteration: 2650, Train loss: -0.4172, rewards: 0.3952\n",
      "Iteration: 2660, Train loss: -0.3983, rewards: 0.4336\n",
      "Iteration: 2670, Train loss: -0.4221, rewards: 0.4440\n",
      "Iteration: 2680, Train loss: -0.3523, rewards: 0.4192\n",
      "Iteration: 2690, Train loss: -0.3670, rewards: 0.4012\n",
      "Iteration: 2700, Train loss: -0.3682, rewards: 0.3865\n",
      "Eval:\n",
      "Hits@1: 0.4133, Hits@3: 0.4670, Hits@10: 0.5105, MRR: 0.4467\n",
      "------------------------------------------------------------\n",
      "Iteration: 2710, Train loss: -0.3788, rewards: 0.4238\n",
      "Iteration: 2720, Train loss: -0.4218, rewards: 0.4144\n",
      "Iteration: 2730, Train loss: -0.4322, rewards: 0.4162\n",
      "Iteration: 2740, Train loss: -0.3951, rewards: 0.3970\n",
      "Iteration: 2750, Train loss: -0.4269, rewards: 0.3998\n",
      "Iteration: 2760, Train loss: -0.3954, rewards: 0.4182\n",
      "Iteration: 2770, Train loss: -0.3981, rewards: 0.4191\n",
      "Iteration: 2780, Train loss: -0.3990, rewards: 0.4078\n",
      "Iteration: 2790, Train loss: -0.3206, rewards: 0.4572\n",
      "Iteration: 2800, Train loss: -0.3583, rewards: 0.4148\n",
      "Eval:\n",
      "Hits@1: 0.4160, Hits@3: 0.4703, Hits@10: 0.5115, MRR: 0.4488\n",
      "------------------------------------------------------------\n",
      "Iteration: 2810, Train loss: -0.3754, rewards: 0.4246\n",
      "Iteration: 2820, Train loss: -0.3838, rewards: 0.4305\n",
      "Iteration: 2830, Train loss: -0.3706, rewards: 0.4183\n",
      "Iteration: 2840, Train loss: -0.3854, rewards: 0.4071\n",
      "Iteration: 2850, Train loss: -0.3651, rewards: 0.4114\n",
      "Iteration: 2860, Train loss: -0.3768, rewards: 0.4215\n",
      "Iteration: 2870, Train loss: -0.3976, rewards: 0.4315\n",
      "Iteration: 2880, Train loss: -0.3871, rewards: 0.4430\n",
      "Iteration: 2890, Train loss: -0.3610, rewards: 0.3855\n",
      "Iteration: 2900, Train loss: -0.3799, rewards: 0.4139\n",
      "Eval:\n",
      "Hits@1: 0.4117, Hits@3: 0.4661, Hits@10: 0.5099, MRR: 0.4444\n",
      "------------------------------------------------------------\n",
      "Iteration: 2910, Train loss: -0.3449, rewards: 0.4392\n",
      "Iteration: 2920, Train loss: -0.3769, rewards: 0.4320\n",
      "Iteration: 2930, Train loss: -0.3725, rewards: 0.4085\n",
      "Iteration: 2940, Train loss: -0.3347, rewards: 0.4361\n",
      "Iteration: 2950, Train loss: -0.3631, rewards: 0.4459\n",
      "Iteration: 2960, Train loss: -0.3873, rewards: 0.4435\n",
      "Iteration: 2970, Train loss: -0.3137, rewards: 0.4437\n",
      "Iteration: 2980, Train loss: -0.3261, rewards: 0.4280\n",
      "Iteration: 2990, Train loss: -0.3389, rewards: 0.4022\n",
      "Iteration: 3000, Train loss: -0.3728, rewards: 0.4160\n",
      "Eval:\n",
      "Hits@1: 0.4127, Hits@3: 0.4555, Hits@10: 0.5043, MRR: 0.4417\n",
      "------------------------------------------------------------\n",
      "Iteration: 3010, Train loss: -0.3705, rewards: 0.4340\n",
      "Iteration: 3020, Train loss: -0.3798, rewards: 0.4139\n",
      "Iteration: 3030, Train loss: -0.4028, rewards: 0.4352\n",
      "Iteration: 3040, Train loss: -0.3367, rewards: 0.3750\n",
      "Iteration: 3050, Train loss: -0.3746, rewards: 0.4328\n",
      "Iteration: 3060, Train loss: -0.3420, rewards: 0.4526\n",
      "Iteration: 3070, Train loss: -0.4046, rewards: 0.4269\n",
      "Iteration: 3080, Train loss: -0.4133, rewards: 0.4269\n",
      "Iteration: 3090, Train loss: -0.3799, rewards: 0.4029\n",
      "Iteration: 3100, Train loss: -0.3902, rewards: 0.4382\n",
      "Eval:\n",
      "Hits@1: 0.4239, Hits@3: 0.4710, Hits@10: 0.5175, MRR: 0.4544\n",
      "------------------------------------------------------------\n",
      "Iteration: 3110, Train loss: -0.3278, rewards: 0.4687\n",
      "Iteration: 3120, Train loss: -0.3595, rewards: 0.4141\n",
      "Iteration: 3130, Train loss: -0.3854, rewards: 0.4227\n",
      "Iteration: 3140, Train loss: -0.3649, rewards: 0.3981\n",
      "Iteration: 3150, Train loss: -0.3791, rewards: 0.4325\n",
      "Iteration: 3160, Train loss: -0.4172, rewards: 0.4127\n",
      "Iteration: 3170, Train loss: -0.3996, rewards: 0.4020\n",
      "Iteration: 3180, Train loss: -0.3628, rewards: 0.4143\n",
      "Iteration: 3190, Train loss: -0.3804, rewards: 0.4354\n",
      "Iteration: 3200, Train loss: -0.3508, rewards: 0.4113\n",
      "Eval:\n",
      "Hits@1: 0.4199, Hits@3: 0.4644, Hits@10: 0.5059, MRR: 0.4484\n",
      "------------------------------------------------------------\n",
      "Iteration: 3210, Train loss: -0.3391, rewards: 0.4275\n",
      "Iteration: 3220, Train loss: -0.3672, rewards: 0.4037\n",
      "Iteration: 3230, Train loss: -0.3754, rewards: 0.4266\n",
      "Iteration: 3240, Train loss: -0.3978, rewards: 0.4319\n",
      "Iteration: 3250, Train loss: -0.4028, rewards: 0.3852\n",
      "Iteration: 3260, Train loss: -0.4153, rewards: 0.4376\n",
      "Iteration: 3270, Train loss: -0.3997, rewards: 0.3935\n",
      "Iteration: 3280, Train loss: -0.4243, rewards: 0.4328\n",
      "Iteration: 3290, Train loss: -0.3817, rewards: 0.4383\n",
      "Iteration: 3300, Train loss: -0.3480, rewards: 0.4448\n",
      "Eval:\n",
      "Hits@1: 0.4094, Hits@3: 0.4578, Hits@10: 0.5086, MRR: 0.4403\n",
      "------------------------------------------------------------\n",
      "Iteration: 3310, Train loss: -0.3702, rewards: 0.4297\n",
      "Iteration: 3320, Train loss: -0.3730, rewards: 0.4146\n",
      "Iteration: 3330, Train loss: -0.3828, rewards: 0.4806\n",
      "Iteration: 3340, Train loss: -0.3445, rewards: 0.4222\n",
      "Iteration: 3350, Train loss: -0.3651, rewards: 0.4330\n",
      "Iteration: 3360, Train loss: -0.3715, rewards: 0.3904\n",
      "Iteration: 3370, Train loss: -0.3504, rewards: 0.3970\n",
      "Iteration: 3380, Train loss: -0.3436, rewards: 0.4114\n",
      "Iteration: 3390, Train loss: -0.3442, rewards: 0.4205\n",
      "Iteration: 3400, Train loss: -0.3587, rewards: 0.4200\n",
      "Eval:\n",
      "Hits@1: 0.4199, Hits@3: 0.4684, Hits@10: 0.5142, MRR: 0.4504\n",
      "------------------------------------------------------------\n",
      "Iteration: 3410, Train loss: -0.3925, rewards: 0.4001\n",
      "Iteration: 3420, Train loss: -0.3284, rewards: 0.3902\n",
      "Iteration: 3430, Train loss: -0.3737, rewards: 0.4144\n",
      "Iteration: 3440, Train loss: -0.3312, rewards: 0.3866\n",
      "Iteration: 3450, Train loss: -0.3298, rewards: 0.4174\n",
      "Iteration: 3460, Train loss: -0.3150, rewards: 0.4081\n",
      "Iteration: 3470, Train loss: -0.3388, rewards: 0.4002\n",
      "Iteration: 3480, Train loss: -0.3762, rewards: 0.4104\n",
      "Iteration: 3490, Train loss: -0.3390, rewards: 0.4261\n",
      "Iteration: 3500, Train loss: -0.3636, rewards: 0.4182\n",
      "Eval:\n",
      "Hits@1: 0.4173, Hits@3: 0.4693, Hits@10: 0.5138, MRR: 0.4485\n",
      "------------------------------------------------------------\n",
      "Iteration: 3510, Train loss: -0.3760, rewards: 0.4078\n",
      "Iteration: 3520, Train loss: -0.4089, rewards: 0.4174\n",
      "Iteration: 3530, Train loss: -0.4115, rewards: 0.4349\n",
      "Iteration: 3540, Train loss: -0.3948, rewards: 0.3999\n",
      "Iteration: 3550, Train loss: -0.3928, rewards: 0.4272\n",
      "Iteration: 3560, Train loss: -0.4260, rewards: 0.4334\n",
      "Iteration: 3570, Train loss: -0.4031, rewards: 0.4288\n",
      "Iteration: 3580, Train loss: -0.3857, rewards: 0.4351\n",
      "Iteration: 3590, Train loss: -0.3787, rewards: 0.4112\n",
      "Iteration: 3600, Train loss: -0.3442, rewards: 0.3985\n",
      "Eval:\n",
      "Hits@1: 0.4189, Hits@3: 0.4667, Hits@10: 0.5162, MRR: 0.4504\n",
      "------------------------------------------------------------\n",
      "Iteration: 3610, Train loss: -0.3826, rewards: 0.4353\n",
      "Iteration: 3620, Train loss: -0.3722, rewards: 0.4131\n",
      "Iteration: 3630, Train loss: -0.3929, rewards: 0.4000\n",
      "Iteration: 3640, Train loss: -0.3626, rewards: 0.4456\n",
      "Iteration: 3650, Train loss: -0.3796, rewards: 0.4365\n",
      "Iteration: 3660, Train loss: -0.3460, rewards: 0.4504\n",
      "Iteration: 3670, Train loss: -0.3340, rewards: 0.4170\n",
      "Iteration: 3680, Train loss: -0.3249, rewards: 0.4620\n",
      "Iteration: 3690, Train loss: -0.3565, rewards: 0.4185\n",
      "Iteration: 3700, Train loss: -0.4025, rewards: 0.4366\n",
      "Eval:\n",
      "Hits@1: 0.4258, Hits@3: 0.4723, Hits@10: 0.5138, MRR: 0.4550\n",
      "------------------------------------------------------------\n",
      "Iteration: 3710, Train loss: -0.3054, rewards: 0.3680\n",
      "Iteration: 3720, Train loss: -0.3681, rewards: 0.4455\n",
      "Iteration: 3730, Train loss: -0.4184, rewards: 0.4360\n",
      "Iteration: 3740, Train loss: -0.5117, rewards: 0.4323\n",
      "Iteration: 3750, Train loss: -0.4468, rewards: 0.4455\n",
      "Iteration: 3760, Train loss: -0.3465, rewards: 0.4072\n",
      "Iteration: 3770, Train loss: -0.3348, rewards: 0.4591\n",
      "Iteration: 3780, Train loss: -0.3559, rewards: 0.4098\n",
      "Iteration: 3790, Train loss: -0.3948, rewards: 0.4121\n",
      "Iteration: 3800, Train loss: -0.3921, rewards: 0.4521\n",
      "Eval:\n",
      "Hits@1: 0.4176, Hits@3: 0.4634, Hits@10: 0.5059, MRR: 0.4467\n",
      "------------------------------------------------------------\n",
      "Iteration: 3810, Train loss: -0.3556, rewards: 0.4147\n",
      "Iteration: 3820, Train loss: -0.3852, rewards: 0.4097\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 3830, Train loss: -0.4163, rewards: 0.3995\n",
      "Iteration: 3840, Train loss: -0.3708, rewards: 0.4103\n",
      "Iteration: 3850, Train loss: -0.3476, rewards: 0.4245\n",
      "Iteration: 3860, Train loss: -0.3507, rewards: 0.4083\n",
      "Iteration: 3870, Train loss: -0.3852, rewards: 0.4141\n",
      "Iteration: 3880, Train loss: -0.4028, rewards: 0.4156\n",
      "Iteration: 3890, Train loss: -0.3676, rewards: 0.3963\n",
      "Iteration: 3900, Train loss: -0.3861, rewards: 0.4187\n",
      "Eval:\n",
      "Hits@1: 0.4169, Hits@3: 0.4634, Hits@10: 0.5020, MRR: 0.4455\n",
      "------------------------------------------------------------\n",
      "Iteration: 3910, Train loss: -0.3670, rewards: 0.4249\n",
      "Iteration: 3920, Train loss: -0.3642, rewards: 0.4363\n",
      "Iteration: 3930, Train loss: -0.3780, rewards: 0.4145\n",
      "Iteration: 3940, Train loss: -0.3446, rewards: 0.4130\n",
      "Iteration: 3950, Train loss: -0.3904, rewards: 0.4398\n",
      "Iteration: 3960, Train loss: -0.3861, rewards: 0.4289\n",
      "Iteration: 3970, Train loss: -0.3569, rewards: 0.4077\n",
      "Iteration: 3980, Train loss: -0.3795, rewards: 0.4472\n",
      "Iteration: 3990, Train loss: -0.3804, rewards: 0.4050\n",
      "Iteration: 4000, Train loss: -0.3986, rewards: 0.4154\n",
      "Eval:\n",
      "Hits@1: 0.4183, Hits@3: 0.4680, Hits@10: 0.5132, MRR: 0.4493\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(options)\n",
    "#self = trainer\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6c2ad0c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1595732/1090773659.py:224: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n",
      "/tmp/ipykernel_1595732/1123405777.py:327: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.4343, Hits@3: 0.4786, Hits@10: 0.5255, MRR: 0.4636\n"
     ]
    }
   ],
   "source": [
    "trainer.agent.load_state_dict(torch.load(options['model_dir'] + 'agent.ckpt'))\n",
    "trainer.agent.eval()\n",
    "\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cc9340fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print_paths = True\n",
    "# with torch.no_grad():\n",
    "\n",
    "#     batch_counter = 0\n",
    "#     paths = defaultdict(list)\n",
    "#     answers = []\n",
    "#     all_final_reward_1 = 0\n",
    "#     all_final_reward_3 = 0\n",
    "#     all_final_reward_5 = 0\n",
    "#     all_final_reward_10 = 0\n",
    "#     all_final_reward_20 = 0\n",
    "#     auc = 0\n",
    "\n",
    "#     total_examples = self.test_environment.total_no_examples\n",
    "\n",
    "#     for entity_episode in self.test_environment.get_episodes():\n",
    "#         batch_counter += 1\n",
    "\n",
    "#         temp_batch_size = entity_episode.no_examples\n",
    "\n",
    "#         self.qr = entity_episode.get_query_relation()\n",
    "#         query_relation = self.qr\n",
    "#         query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "#         # set initial beam probs\n",
    "#         beam_probs = torch.zeros((temp_batch_size * self.test_rollouts, 1)).to(self.device)\n",
    "\n",
    "#         # get initial state for entity agent\n",
    "#         entity_state = entity_episode.get_state()\n",
    "\n",
    "#         head, link, tail = self.get_graph(entity_episode, mode = 'test')\n",
    "#         next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "#         next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "#         current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "\n",
    "#         entity_state_emb = torch.zeros(1, 2, temp_batch_size * self.test_rollouts,\n",
    "#                                        self.agent.m * self.hidden_size).to(self.device)\n",
    "#         prev_relation = (torch.ones(temp_batch_size * self.test_rollouts) * self.relation_vocab[\n",
    "#             'DUMMY_START_RELATION']).long().to(self.device)\n",
    "\n",
    "#         if print_paths:\n",
    "#             self.entity_trajectory = [current_entities]\n",
    "#             self.relation_trajectory = [prev_relation]   \n",
    "\n",
    "#         self.log_probs = np.zeros((temp_batch_size * self.test_rollouts,)) * 1.0\n",
    "#         for i in range(self.path_length):\n",
    "#             break\n",
    "#         break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8f512a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# next_relations, next_entities, prev_state = next_relations, next_entities, entity_state_emb\n",
    "# prev_relation, query_embedding, current_entities = prev_relation, query_relation,current_entities\n",
    "# head, link, tail = head, link, tail"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "70c5dc8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "    \n",
    "#     n_ent = len(self.agent.entity_embedding.weight)\n",
    "#     n_rel = len(self.agent.relation_embedding.weight)\n",
    "    \n",
    "#     self.agent.ent_embed = self.agent.entity_embedding(torch.arange(n_ent).to(head.device))\n",
    "#     self.agent.rel_embed = self.agent.relation_embedding(torch.arange(n_rel).to(link.device))\n",
    "\n",
    "#     for i in range(2):\n",
    "#         self.agent.ent_embed = self.agent.node_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "#         self.agent.rel_embed = self.agent.rel_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "\n",
    "#     prev_action_embedding = self.agent.action_encoder(prev_relation, current_entities) # (original batch_size * num_rollout, 4*self.agent.embedding_size)\n",
    "\n",
    "#     prev_state = torch.unbind(prev_state, dim=1)\n",
    "#     prev_state = [prev_state[0].squeeze(0), prev_state[1].squeeze(0)]\n",
    "\n",
    "#     new_prev_state = list()\n",
    "\n",
    "#     output, new_state = self.agent.policy_step(prev_action_embedding, prev_state)\n",
    "\n",
    "#     prev_entity = self.agent.ent_embed[current_entities]\n",
    "#     if self.agent.use_entity_embeddings:\n",
    "#         state = torch.cat([output, prev_entity], dim=-1)\n",
    "#     else:\n",
    "#         state = output\n",
    "\n",
    "#     candidate_action_embeddings = self.agent.action_encoder(next_relations, next_entities)\n",
    "#     query_embedding = self.agent.rel_embed[query_embedding]\n",
    "\n",
    "#     state_query_concat = torch.cat([state, query_embedding], dim=-1)\n",
    "\n",
    "#     # MLP for policy#\n",
    "\n",
    "#     output = self.agent.policy_mlp(state_query_concat)\n",
    "#     # print(output.size())\n",
    "#     output_expanded = torch.unsqueeze(output, dim=1)  # [original batch_size * num_rollout, 1, 2D], D=self.agent.hidden_size\n",
    "#     # print(output_expanded.size(), candidate_action_embeddings.size())\n",
    "#     prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)\n",
    "\n",
    "#     # Masking PAD actions\n",
    "\n",
    "#     comparison_tensor = torch.ones_like(next_relations).int() * self.agent.rPAD  # matrix to compare\n",
    "#     mask = next_relations == comparison_tensor  # The mask\n",
    "#     dummy_scores = torch.ones_like(prelim_scores) * -99999.0  # the base matrix to choose from if dummy relation\n",
    "#     scores = torch.where(mask, dummy_scores, prelim_scores)  # [original batch_size * num_rollout, max_num_actions]\n",
    "\n",
    "#     # 4 sample action\n",
    "#     action = torch.distributions.categorical.Categorical(logits=scores) # [original batch_size * num_rollout, 1]\n",
    "#     label_action = action.sample() # [original batch_size * num_rollout,]\n",
    "\n",
    "#     # loss\n",
    "#     # 5a.\n",
    "#     loss = torch.nn.CrossEntropyLoss(reduce=False)(scores, label_action)\n",
    "\n",
    "#     # 6. Map back to true id\n",
    "#     chosen_relation = next_relations[torch.arange(len(label_action)), label_action]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "70c874d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# next_relations, next_entities, prev_state = next_possible_relations, next_possible_entities, entity_state_emb\n",
    "# prev_relation, query_embedding, current_entities = prev_relation, query_relation,current_entities\n",
    "# head, link, tail = head, link, tail"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dc1d1658",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "#     self.agent.ent_embed = self.agent.entity_embedding(torch.arange(head.max().item() + 1).to(head.device))\n",
    "#     self.agent.rel_embed = self.agent.relation_embedding(torch.arange(link.max().item() + 1).to(link.device))\n",
    "\n",
    "#     for i in range(2):\n",
    "#         self.agent.ent_embed = self.agent.node_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "#         self.agent.rel_embed = self.agent.rel_conv[i](head, link, tail, \n",
    "#                                                 self.agent.ent_embed, self.agent.rel_embed)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e0aa96cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "#     prev_action_embedding = self.agent.action_encoder(prev_relation, current_entities) # (original batch_size * num_rollout, 4*self.embedding_size)\n",
    "\n",
    "#     prev_state = torch.unbind(prev_state, dim=1)\n",
    "#     prev_state = [prev_state[0].squeeze(0), prev_state[1].squeeze(0)]\n",
    "\n",
    "#     new_prev_state = list()\n",
    "\n",
    "#     output, new_state = self.agent.policy_step(prev_action_embedding, prev_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8763062b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with torch.no_grad():\n",
    "#     prev_entity = self.agent.ent_embed[current_entities]\n",
    "#     if self.agent.use_entity_embeddings:\n",
    "#         state = torch.cat([output, prev_entity], dim=-1)\n",
    "#     else:\n",
    "#         state = output\n",
    "\n",
    "#     candidate_action_embeddings = self.agent.action_encoder(next_relations, next_entities)\n",
    "#     query_embedding = self.agent.rel_embed[query_embedding]\n",
    "\n",
    "#     state_query_concat = torch.cat([state, query_embedding], dim=-1)\n",
    "\n",
    "#     # MLP for policy#\n",
    "\n",
    "#     output = self.agent.policy_mlp(state_query_concat)\n",
    "#     # print(output.size())\n",
    "#     output_expanded = torch.unsqueeze(output, dim=1)  # [original batch_size * num_rollout, 1, 2D], D=self.agent.hidden_size\n",
    "#     # print(output_expanded.size(), candidate_action_embeddings.size())\n",
    "#     prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)\n",
    "\n",
    "#     # Masking PAD actions\n",
    "\n",
    "#     comparison_tensor = torch.ones_like(next_relations).int() * self.agent.rPAD  # matrix to compare\n",
    "#     mask = next_relations == comparison_tensor  # The mask\n",
    "#     dummy_scores = torch.ones_like(prelim_scores) * -99999.0  # the base matrix to choose from if dummy relation\n",
    "#     scores = torch.where(mask, dummy_scores, prelim_scores)  # [original batch_size * num_rollout, max_num_actions]\n",
    "\n",
    "#     # 4 sample action\n",
    "#     action = torch.distributions.categorical.Categorical(logits=scores) # [original batch_size * num_rollout, 1]\n",
    "#     label_action = action.sample() # [original batch_size * num_rollout,]\n",
    "\n",
    "#     # loss\n",
    "#     # 5a.\n",
    "#     loss = torch.nn.CrossEntropyLoss(reduce=False)(scores, label_action)\n",
    "\n",
    "#     # 6. Map back to true id\n",
    "#     chosen_relation = next_relations[torch.arange(len(label_action)), label_action]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
