{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
    "\n",
    "from model.trainer import Trainer\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from tqdm.notebook import tqdm\n",
    "import json\n",
    "import time\n",
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "from model.agent import Agent\n",
    "from model.options import read_options\n",
    "from model.environment import env\n",
    "import codecs\n",
    "from collections import defaultdict\n",
    "import gc\n",
    "import resource\n",
    "import sys\n",
    "from model.baseline import ReactiveBaseline\n",
    "from scipy.special import logsumexp as lse\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam, SGD, AdamW\n",
    "from copy import deepcopy\n",
    "from model.nell_eval import nell_eval\n",
    "\n",
    "def get_logger(output_dir):\n",
    "    filename=output_dir +'train'\n",
    "    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter\n",
    "    logger = getLogger(__name__)\n",
    "    logger.setLevel(INFO)\n",
    "    handler1 = StreamHandler()\n",
    "    handler1.setFormatter(Formatter(\"%(message)s\"))\n",
    "    handler2 = FileHandler(filename=f\"{filename}.log\")\n",
    "    handler2.setFormatter(Formatter(\"%(message)s\"))\n",
    "    logger.addHandler(handler1)\n",
    "    logger.addHandler(handler2)\n",
    "    return logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "865c2f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.utils as utils\n",
    "\n",
    "class Agent(nn.Module):\n",
    "\n",
    "    def __init__(self, params):\n",
    "        super(Agent, self).__init__()\n",
    "        \n",
    "        #agent setting\n",
    "        self.action_vocab_size = len(params['relation_vocab'])\n",
    "        self.entity_vocab_size = len(params['entity_vocab'])\n",
    "        self.embedding_size = params['embedding_size']\n",
    "        self.hidden_size = params['hidden_size']\n",
    "        self.ePAD = params['entity_vocab']['PAD']\n",
    "        self.rPAD = params['relation_vocab']['PAD']\n",
    "        self.use_entity_embeddings = params['use_entity_embeddings']\n",
    "        self.train_entity_embeddings = params['train_entity_embeddings']\n",
    "        self.train_relation_embeddings = params['train_relation_embeddings']\n",
    "        self.device = params['device']\n",
    "    \n",
    "        self.num_rollouts = params['num_rollouts']\n",
    "        self.test_rollouts = params['test_rollouts']\n",
    "        self.LSTM_Layers = params['LSTM_layers']\n",
    "        self.batch_size = params['batch_size'] * params['num_rollouts']\n",
    "        self.dummy_start_label = (torch.ones(self.batch_size) * params['relation_vocab']['DUMMY_START_RELATION']).long()\n",
    "        self.entity_embedding_size = self.embedding_size\n",
    "        \n",
    "        #entity and relation embedding\n",
    "        if self.use_entity_embeddings:\n",
    "            if self.train_entity_embeddings:\n",
    "                self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size)\n",
    "            else:\n",
    "                self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size).requires_grad_(\n",
    "                    False)\n",
    "            torch.nn.init.xavier_uniform_(self.entity_embedding.weight)\n",
    "        else:\n",
    "            if self.train_entity_embeddings:\n",
    "                self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size)\n",
    "            else:\n",
    "                self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size).requires_grad_(\n",
    "                    False)\n",
    "            torch.nn.init.constant_(self.entity_embedding.weight, 0.0)\n",
    "\n",
    "        if self.train_relation_embeddings:\n",
    "            self.relation_embedding = nn.Embedding(self.action_vocab_size, 2 * self.embedding_size)\n",
    "        else:\n",
    "            self.relation_embedding = nn.Embedding(self.action_vocab_size, 2 * self.embedding_size).requires_grad_(\n",
    "                False)\n",
    "        torch.nn.init.xavier_uniform_(self.relation_embedding.weight)\n",
    "\n",
    "        #operators\n",
    "        \n",
    "        self.embedding_encoder = nn.Sequential(nn.Linear(3*2*self.embedding_size, 2*self.embedding_size), nn.Mish(),\n",
    "                                    nn.Linear(2*self.embedding_size,  2*self.embedding_size), \n",
    "                                               nn.LayerNorm(2*self.embedding_size))\n",
    "        \n",
    "        self.n_head = 4\n",
    "        self.sequence_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model = 2*self.embedding_size,\n",
    "                                                                                 nhead = self.n_head,\n",
    "                                                                                 dim_feedforward = 2*self.embedding_size*4,\n",
    "                                                                                 dropout = 0,\n",
    "                                                                                 batch_first = True), \n",
    "                                                      num_layers = 2)   \n",
    "        \n",
    "        self.output_layer = nn.Linear(2*self.embedding_size, 1)\n",
    "        \n",
    "    \n",
    "    def get_scores(self, next_relations, next_entities, path_relation, path_entities, query_relation):\n",
    "        l = len(path_relation)\n",
    "        n = next_relations.shape[-1]\n",
    "\n",
    "        hist_relation = torch.cat([x.unsqueeze(1) for x in path_relation], -1)\n",
    "        hist_entities = torch.cat([x.unsqueeze(1) for x in path_entities], -1)\n",
    "\n",
    "        hist_relation_embed = self.relation_embedding(hist_relation.long())\n",
    "        hist_entities_embed = self.entity_embedding(hist_entities.long())\n",
    "\n",
    "        candidate_relation_embed = self.relation_embedding(next_relations.long())\n",
    "        candidate_entities_embed = self.entity_embedding(next_entities.long())\n",
    "\n",
    "        query_embed = self.relation_embedding(query_relation.long()).unsqueeze(1)    \n",
    "\n",
    "\n",
    "        candidate_relation_embed = self.relation_embedding(next_relations.long())\n",
    "        candidate_entities_embed = self.entity_embedding(next_entities.long())\n",
    "\n",
    "        query_embed = self.relation_embedding(query_relation.long()).unsqueeze(1)\n",
    "\n",
    "        hist_hidden = self.embedding_encoder(torch.cat([hist_entities_embed, \n",
    "                                                              hist_relation_embed,\n",
    "                                                              query_embed.repeat(1, l, 1)], -1))\n",
    "        candidate_hidden = self.embedding_encoder(torch.cat([candidate_entities_embed, \n",
    "                                                                   candidate_relation_embed,\n",
    "                                                                   query_embed.repeat(1, n, 1)], -1))\n",
    "\n",
    "        seq = torch.cat([hist_hidden, candidate_hidden], 1)\n",
    "\n",
    "        seq_mask = torch.eye(seq.shape[1]).to(seq.device)\n",
    "        seq_mask[:l, :l] = torch.tril(torch.ones(l,l)).to(seq.device)\n",
    "        seq_mask[l:, :l] = 1\n",
    "\n",
    "        action_hidden = self.sequence_encoder(src = seq, mask = ~seq_mask.bool())[:, (-n):]\n",
    "        scores = self.output_layer(action_hidden).squeeze(-1)\n",
    "        return scores\n",
    "\n",
    "    def step(self, next_relations, next_entities, path_relation, path_entities, query_relation):\n",
    "\n",
    "        prelim_scores = self.get_scores(next_relations, next_entities, path_relation, \n",
    "                                       path_entities, query_relation).squeeze(-1)\n",
    "\n",
    "        comparison_tensor = torch.ones_like(next_relations).int() * self.rPAD  # matrix to compare\n",
    "        mask = next_relations == comparison_tensor  # The mask\n",
    "        dummy_scores = torch.ones_like(prelim_scores) * -99999.0  # the base matrix to choose from if dummy relation\n",
    "        scores = torch.where(mask, dummy_scores, prelim_scores)  # [original batch_size * num_rollout, max_num_actions]\n",
    "\n",
    "        action = torch.distributions.categorical.Categorical(logits=scores) # [original batch_size * num_rollout, 1]\n",
    "        label_action = action.sample() # [original batch_size * num_rollout,]\n",
    "\n",
    "        loss = torch.nn.CrossEntropyLoss(reduce=False)(scores, label_action)\n",
    "\n",
    "        chosen_relation = next_relations[torch.arange(len(label_action)), label_action]\n",
    "\n",
    "        return loss, F.log_softmax(scores), label_action, chosen_relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4abbb33",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import time\n",
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "\n",
    "#from model.agent import Agent\n",
    "from model.options import read_options\n",
    "from model.environment import env\n",
    "import codecs\n",
    "from collections import defaultdict\n",
    "import gc\n",
    "import resource\n",
    "import sys\n",
    "from model.baseline import ReactiveBaseline\n",
    "from scipy.special import logsumexp as lse\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from model.nell_eval import nell_eval\n",
    "\n",
    "# Handle on the root logger; basicConfig sends records of level DEBUG and above to stdout.\n",
    "logger = logging.getLogger()\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)\n",
    "\n",
    "\n",
    "class Trainer(object):\n",
    "    def __init__(self, params):\n",
    "        \"\"\"Build the agent, the three environments, the baseline and the optimizer.\n",
    "\n",
    "        Every entry of ``params`` is first mirrored onto the instance as an\n",
    "        attribute (this is where e.g. ``self.beta``, ``self.Lambda``,\n",
    "        ``self.entity_vocab`` and ``self.relation_vocab`` used below come from);\n",
    "        the explicit assignments that follow re-read the keys this class relies on.\n",
    "\n",
    "        Args:\n",
    "            params: Configuration dict shared with ``Agent``, ``env`` and\n",
    "                ``ReactiveBaseline``.\n",
    "        \"\"\"\n",
    "        # transfer parameters to self\n",
    "        for name, value in params.items():\n",
    "            setattr(self, name, value)\n",
    "\n",
    "        self.device = params['device']\n",
    "        print(self.device)\n",
    "        self.agent = Agent(params).to(self.device)\n",
    "        #self.c_agent = ClusterAgent(params).to(self.device)\n",
    "        self.model_dir = params['model_dir']\n",
    "        self.save_path = self.model_dir + \"model\" + '.ckpt'\n",
    "\n",
    "        # One environment per split; evaluation points at the dev split by default.\n",
    "        self.train_environment = env(params, 'train')\n",
    "        self.dev_test_environment = env(params, 'dev')\n",
    "        self.test_test_environment = env(params, 'test')\n",
    "        self.test_environment = self.dev_test_environment\n",
    "\n",
    "        # Reverse vocabularies, taken from the training graph.\n",
    "        train_grapher = self.train_environment.grapher\n",
    "        self.rev_relation_vocab = train_grapher.rev_relation_vocab\n",
    "        self.rev_entity_vocab = train_grapher.rev_entity_vocab\n",
    "        #self.rev_cluster_relation_vocab = self.train_environment.cluster_grapher.rev_cluster_relation_vocab\n",
    "        #self.rev_cluster_vocab = self.train_environment.cluster_grapher.rev_cluster_vocab\n",
    "\n",
    "        self.max_hits_at_10 = 0\n",
    "        self.ePAD = self.entity_vocab['PAD']\n",
    "        self.rPAD = self.relation_vocab['PAD']\n",
    "        self.decaying_beta_init = self.beta\n",
    "        # optimize\n",
    "        self.baseline = ReactiveBaseline(params, self.Lambda)\n",
    "\n",
    "        # NOTE(review): this replaces any 'decay_batch' copied from params above; with\n",
    "        # None, the `current_decay_count == self.decay_batch` check in train() never\n",
    "        # matches, so the entropy weight is never decayed -- confirm this is intended.\n",
    "        self.decay_batch = None\n",
    "        self.gamma = params['gamma']\n",
    "        self.grad_clip_norm = params['grad_clip_norm']\n",
    "        self.eval_every = params['eval_every']\n",
    "        self.total_iterations = params['total_iterations']\n",
    "        self.learning_rate = params['learning_rate']\n",
    "        self.pool = params['pool']\n",
    "        self.output_dir = params['output_dir']\n",
    "\n",
    "        self.positive_reward_rates = []\n",
    "        trainable = list(self.agent.parameters())\n",
    "        self.optimizer = optim.Adam(trainable, lr=self.learning_rate)\n",
    "        self.two_embeds_sim_criterion = torch.nn.KLDivLoss()\n",
    "\n",
    "    def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_reward, decaying_beta, baseline):\n",
    "\n",
    "        loss = torch.stack(all_loss, dim=1)  # [original batch_size * num_rollout, T]\n",
    "        base_value = baseline.get_baseline_value()\n",
    "\n",
    "        # multiply with rewards\n",
    "        final_reward = cum_discounted_reward - base_value\n",
    "        reward_mean = torch.mean(final_reward)\n",
    "\n",
    "        # Constant added for numerical stability\n",
    "        reward_std = torch.std(final_reward) + 1e-6\n",
    "        final_reward = torch.div(final_reward - reward_mean, reward_std)\n",
    "\n",
    "        loss = torch.mul(loss, final_reward)  # [original batch_size * num_rollout, T]\n",
    "\n",
    "        entropy_loss = decaying_beta * self.entropy_reg_loss(all_logits)\n",
    "\n",
    "        total_loss = torch.mean(loss) - entropy_loss  # scalar\n",
    "\n",
    "        return total_loss\n",
    "    \n",
    "    def entropy_reg_loss(self, all_logits):  # control diversity\n",
    "        all_logits = torch.stack(all_logits, dim=2)  # [original batch_size * num_rollout, max_num_actions, T]\n",
    "        entropy_loss = - torch.mean(torch.sum(torch.mul(torch.exp(all_logits), all_logits), dim=1))  # scalar\n",
    "        return entropy_loss\n",
    "\n",
    "    def calc_cum_discounted_reward(self, rewards):\n",
    "\n",
    "        running_add = torch.zeros([rewards.size(0)]).to(self.device)  # [original batch_size * num_rollout]\n",
    "        cum_disc_reward = torch.zeros([rewards.size(0), self.path_length]).to(\n",
    "            self.device)  # [original batch_size * num_rollout, T]\n",
    "        cum_disc_reward[:,\n",
    "        self.path_length - 1] = rewards  # set the last time step to the reward received at the last state\n",
    "        for t in reversed(range(self.path_length)):\n",
    "            running_add = self.gamma * running_add + cum_disc_reward[:, t]\n",
    "            cum_disc_reward[:, t] = running_add\n",
    "        return cum_disc_reward\n",
    "\n",
    "    def calc_cum_discounted_reward_credit(self, entity_rewards):\n",
    "\n",
    "        num_instances = entity_rewards.size(0)\n",
    "        running_add = torch.zeros([num_instances]).to(self.device)  # [original batch_size * num_rollout]\n",
    "        cum_disc_reward = torch.zeros([num_instances, self.path_length]).to(\n",
    "            self.device)  # [original batch_size * num_rollout, T]\n",
    "        cum_disc_reward[:,\n",
    "        self.path_length - 1] = entity_rewards  # set the last time step to the reward received at the last state\n",
    "\n",
    "        for t in reversed(range(1, self.path_length)):\n",
    "            running_add = self.gamma * running_add + cum_disc_reward[:, t] # approx_credits[t].to(self.device) * cluster_rewards\n",
    "            cum_disc_reward[:, t-1] = running_add\n",
    "\n",
    "        return cum_disc_reward\n",
    "    \n",
    "    def train(self):\n",
    "        train_loss = []\n",
    "        train_reward = []\n",
    "\n",
    "        start_time = time.time()\n",
    "        self.batch_counter = 0\n",
    "        current_decay = self.decaying_beta_init\n",
    "        current_decay_count = 0\n",
    "\n",
    "        print('Agent start learning ...')\n",
    "        for entity_episode in self.train_environment.get_episodes():\n",
    "\n",
    "            self.batch_counter += 1\n",
    "\n",
    "            current_decay_count += 1\n",
    "            if current_decay_count == self.decay_batch:\n",
    "                current_decay *= self.decay_rate\n",
    "                current_decay_count = 0\n",
    "\n",
    "            # get initial state for entity agent\n",
    "\n",
    "            entity_state = entity_episode.get_state()\n",
    "            next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(\n",
    "                self.device)\n",
    "            next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "\n",
    "            prev_relation = self.agent.dummy_start_label.to(self.device)\n",
    "\n",
    "            query_relation = entity_episode.get_query_relation()\n",
    "            query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "            current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "\n",
    "            all_losses = []\n",
    "            all_logits = []\n",
    "            all_action_id = []\n",
    "            path_relation = [prev_relation]\n",
    "            path_entities = [current_entities]\n",
    "\n",
    "            for i in range(self.path_length):\n",
    "                loss, logits, idx, chosen_relation = self.agent.step(\n",
    "                    next_possible_relations,\n",
    "                    next_possible_entities, path_relation,\n",
    "                    path_entities, query_relation\n",
    "                )\n",
    "\n",
    "                entity_state = entity_episode(idx.cpu())\n",
    "                next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "                prev_relation = chosen_relation.to(self.device)\n",
    "\n",
    "                all_losses.append(loss)\n",
    "                all_logits.append(logits)\n",
    "                all_action_id.append(idx)\n",
    "                path_relation.append(prev_relation)\n",
    "                path_entities.append(current_entities)\n",
    "\n",
    "            rewards = entity_episode.get_reward()\n",
    "            rewards = torch.tensor(rewards).to(self.device)\n",
    "\n",
    "            cum_discounted_reward = self.calc_cum_discounted_reward(rewards)\n",
    "            reinforce_loss = self.calc_reinforce_loss(all_losses, all_logits, cum_discounted_reward,\n",
    "                                                        current_decay, self.baseline)\n",
    "\n",
    "            self.baseline.update(torch.mean(cum_discounted_reward))\n",
    "\n",
    "            self.optimizer.zero_grad()\n",
    "            reinforce_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=self.grad_clip_norm, norm_type=2)\n",
    "            self.optimizer.step()\n",
    "\n",
    "            train_loss.append(reinforce_loss.detach().cpu().item())\n",
    "            train_reward.append(rewards.cpu().float().tolist())\n",
    "\n",
    "            if (self.batch_counter > 0)&(self.batch_counter % (self.eval_every//10) == 0):\n",
    "                avg_loss = np.mean(train_loss[-(self.eval_every//10):])\n",
    "                avg_reward = np.mean(sum(train_reward[-(self.eval_every//10):], []))\n",
    "                print('Iteration: {}, Train loss: {:.4f}, rewards: {:.4f}'.format(self.batch_counter, avg_loss, avg_reward))\n",
    "                gc.collect()\n",
    "\n",
    "            if (self.batch_counter > 0)&(self.batch_counter % self.eval_every == 0):\n",
    "                print('Eval:')\n",
    "                self.test(beam = True)\n",
    "                gc.collect()\n",
    "                print('------------------------------------------------------------')\n",
    "\n",
    "            if self.batch_counter > self.total_iterations:\n",
    "                break\n",
    "\n",
    "    def test(self, beam=False, print_paths=False, save_model=True):\n",
    "\n",
    "        with torch.no_grad():\n",
    "\n",
    "            batch_counter = 0\n",
    "            paths = defaultdict(list)\n",
    "            answers = []\n",
    "            all_final_reward_1 = 0\n",
    "            all_final_reward_3 = 0\n",
    "            all_final_reward_5 = 0\n",
    "            all_final_reward_10 = 0\n",
    "            all_final_reward_20 = 0\n",
    "            auc = 0\n",
    "\n",
    "            total_examples = self.test_environment.total_no_examples\n",
    "\n",
    "            for entity_episode in self.test_environment.get_episodes():\n",
    "                batch_counter += 1\n",
    "\n",
    "                temp_batch_size = entity_episode.no_examples\n",
    "\n",
    "                self.qr = entity_episode.get_query_relation()\n",
    "                query_relation = self.qr\n",
    "                query_relation = torch.tensor(query_relation).long().to(self.device)\n",
    "                # set initial beam probs\n",
    "                beam_probs = torch.zeros((temp_batch_size * self.test_rollouts, 1)).to(self.device)\n",
    "\n",
    "                # get initial state for entity agent\n",
    "                entity_state = entity_episode.get_state()\n",
    "\n",
    "                next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "\n",
    "                prev_relation = (torch.ones(temp_batch_size * self.test_rollouts) * self.relation_vocab[\n",
    "                    'DUMMY_START_RELATION']).long().to(self.device)\n",
    "\n",
    "                if print_paths:\n",
    "                    self.entity_trajectory = [current_entities]\n",
    "                    self.relation_trajectory = [prev_relation]        \n",
    "\n",
    "                self.log_probs = np.zeros((temp_batch_size * self.test_rollouts,)) * 1.0\n",
    "                \n",
    "                path_relation = [prev_relation]\n",
    "                path_entities = [current_entities]\n",
    "                \n",
    "                for i in range(self.path_length):\n",
    "\n",
    "                    loss, test_scores, test_action_idx, chosen_relation = self.agent.step(\n",
    "                                                                        next_relations,\n",
    "                                                                        next_entities, path_relation,\n",
    "                                                                        path_entities, query_relation\n",
    "                                                                    )\n",
    "                    \n",
    "                    #Mimic original implementation on pytorch\n",
    "                    if beam:\n",
    "                        k = self.test_rollouts\n",
    "                        beam_probs = beam_probs.to(self.device)\n",
    "                        new_scores = test_scores + beam_probs\n",
    "                        new_scores = new_scores.cpu()\n",
    "                        if i == 0:\n",
    "                            idx = np.argsort(new_scores)\n",
    "                            idx = idx[:, -k:]\n",
    "                            ranged_idx = np.tile([b for b in range(k)], temp_batch_size)\n",
    "                            idx = idx[np.arange(k * temp_batch_size), ranged_idx]\n",
    "                        else:\n",
    "                            idx = self.top_k(new_scores, k)\n",
    "\n",
    "                        y = idx // self.max_num_actions\n",
    "                        x = idx % self.max_num_actions\n",
    "\n",
    "                        y += np.repeat([b * k for b in range(temp_batch_size)], k)\n",
    "                        entity_state['current_entities'] = entity_state['current_entities'][y]\n",
    "                        entity_state['next_relations'] = entity_state['next_relations'][y, :]\n",
    "                        entity_state['next_entities'] = entity_state['next_entities'][y, :]\n",
    "\n",
    "                        test_action_idx = x\n",
    "                        chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]\n",
    "\n",
    "                        beam_probs = new_scores[y, x]\n",
    "                        beam_probs = beam_probs.reshape((-1, 1))\n",
    "\n",
    "#                     #My implementation to fit arbitrary dimension\n",
    "#                     if beam:\n",
    "#                         k = self.test_rollouts\n",
    "#                         beam_probs = beam_probs.to(self.device)\n",
    "#                         new_scores = test_scores + beam_probs\n",
    "#                         new_scores = new_scores.cpu()\n",
    "#                         if i == 0:\n",
    "#                             reshape_score = new_scores.reshape(temp_batch_size, self.test_rollouts, -1)\n",
    "#                             possible_idx = []\n",
    "#                             for x in reshape_score:\n",
    "#                                 possible_idx.append(torch.LongTensor(np.where(x[0].cpu() > -1000)[0]))\n",
    "#                             idx = []\n",
    "#                             for x in possible_idx:\n",
    "#                                 idx.append(torch.cat(([x]*(self.test_rollouts//len(x) + 1)))[:self.test_rollouts])\n",
    "#                             idx = torch.cat(idx, 0)\n",
    "#                         else:\n",
    "#                             idx = self.top_k(new_scores, k)\n",
    "\n",
    "#                         y = idx // self.max_num_actions\n",
    "#                         x = idx % self.max_num_actions\n",
    "\n",
    "#                         y += np.repeat([b * k for b in range(temp_batch_size)], k)\n",
    "#                         entity_state['current_entities'] = entity_state['current_entities'][y]\n",
    "#                         entity_state['next_relations'] = entity_state['next_relations'][y, :]\n",
    "#                         entity_state['next_entities'] = entity_state['next_entities'][y, :]\n",
    "#                         entity_state_emb = entity_state_emb[:, :, y, :]\n",
    "\n",
    "#                         test_action_idx = x\n",
    "#                         chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]\n",
    "#                         beam_probs = new_scores[y, x]\n",
    "#                         beam_probs = beam_probs.reshape((-1, 1))\n",
    "                        \n",
    "                        if print_paths:\n",
    "                            for j in range(i):\n",
    "                                self.entity_trajectory[j] = self.entity_trajectory[j][y]\n",
    "                                self.relation_trajectory[j] = self.relation_trajectory[j][y]\n",
    "\n",
    "                    entity_state = entity_episode(test_action_idx.cpu().numpy())\n",
    "                    next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)\n",
    "                    next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)\n",
    "                    current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)\n",
    "                    prev_relation = torch.tensor(chosen_relation).long().to(self.device)\n",
    "                    \n",
    "                    path_relation.append(prev_relation)\n",
    "                    path_entities.append(current_entities)\n",
    "                    \n",
    "                    if print_paths:\n",
    "                        self.entity_trajectory.append(entity_state['current_entities'])\n",
    "                        self.relation_trajectory.append(chosen_relation)\n",
    "\n",
    "                    test_scores = test_scores.cpu().numpy()\n",
    "                    self.log_probs += test_scores[np.arange(self.log_probs.shape[0]), test_action_idx.cpu().numpy()]\n",
    "\n",
    "                if beam:\n",
    "                    self.log_probs = beam_probs\n",
    "\n",
    "                rewards = entity_episode.get_reward()  # [B*test_rollouts]\n",
    "                reward_reshape = np.reshape(rewards, (temp_batch_size, self.test_rollouts))  # [orig_batch, test_rollouts]\n",
    "                self.log_probs = np.reshape(self.log_probs, (temp_batch_size, self.test_rollouts))\n",
    "                sorted_indx = np.argsort(-self.log_probs)\n",
    "                final_reward_1 = 0\n",
    "                final_reward_3 = 0\n",
    "                final_reward_5 = 0\n",
    "                final_reward_10 = 0\n",
    "                final_reward_20 = 0\n",
    "                AP = 0\n",
    "                ce = entity_episode.state['current_entities'].reshape((temp_batch_size, self.test_rollouts))\n",
    "                se = entity_episode.start_entities.reshape((temp_batch_size, self.test_rollouts))\n",
    "                for b in range(temp_batch_size):\n",
    "                    answer_pos = None\n",
    "                    seen = set()\n",
    "                    pos=0\n",
    "                    if self.pool == 'max':\n",
    "                        for r in sorted_indx[b]:\n",
    "                            if reward_reshape[b,r] == self.positive_reward:\n",
    "                                answer_pos = pos\n",
    "                                break\n",
    "                            if ce[b, r] not in seen:\n",
    "                                seen.add(ce[b, r])\n",
    "                                pos += 1\n",
    "                    if self.pool == 'sum':\n",
    "                        scores = defaultdict(list)\n",
    "                        answer = ''\n",
    "                        for r in sorted_indx[b]:\n",
    "                            scores[ce[b,r]].append(self.log_probs[b,r])\n",
    "                            if reward_reshape[b,r] == self.positive_reward:\n",
    "                                answer = ce[b,r]\n",
    "                        final_scores = defaultdict(float)\n",
    "                        for e in scores:\n",
    "                            final_scores[e] = lse(scores[e])\n",
    "                        sorted_answers = sorted(final_scores, key=final_scores.get, reverse=True)\n",
    "                        if answer in  sorted_answers:\n",
    "                            answer_pos = sorted_answers.index(answer)\n",
    "                        else:\n",
    "                            answer_pos = None\n",
    "\n",
    "\n",
    "                    if answer_pos != None:\n",
    "                        if answer_pos < 20:\n",
    "                            final_reward_20 += 1\n",
    "                            if answer_pos < 10:\n",
    "                                final_reward_10 += 1\n",
    "                                if answer_pos < 5:\n",
    "                                    final_reward_5 += 1\n",
    "                                    if answer_pos < 3:\n",
    "                                        final_reward_3 += 1\n",
    "                                        if answer_pos < 1:\n",
    "                                            final_reward_1 += 1\n",
    "                    if answer_pos == None:\n",
    "                        AP += 0\n",
    "                    else:\n",
    "                        AP += 1.0/((answer_pos+1))\n",
    "                    \n",
    "                    if print_paths:\n",
    "                        qr = self.train_environment.grapher.rev_relation_vocab[self.qr[b * self.test_rollouts]]\n",
    "                        start_e = self.rev_entity_vocab[entity_episode.start_entities[b * self.test_rollouts]]\n",
    "                        end_e = self.rev_entity_vocab[entity_episode.end_entities[b * self.test_rollouts]]\n",
    "                        paths[str(qr)].append(str(start_e) + \"\\t\" + str(end_e) + \"\\n\")\n",
    "                        paths[str(qr)].append(\"Reward:\" + str(1 if answer_pos != None and answer_pos < 10 else 0) + \"\\n\")\n",
    "                        for r in sorted_indx[b]:\n",
    "                            indx = b * self.test_rollouts + r\n",
    "                            if rewards[indx] == self.positive_reward:\n",
    "                                rev = 1\n",
    "                            else:\n",
    "                                rev = -1\n",
    "                            answers.append(self.rev_entity_vocab[se[b,r]]+'\\t'+ self.rev_entity_vocab[ce[b,r]]+'\\t'+ str(self.log_probs[b,r])+'\\n')\n",
    "                            paths[str(qr)].append(\n",
    "                                '\\t'.join([str(self.rev_entity_vocab[e[indx]]) for e in\n",
    "                                           self.entity_trajectory]) + '\\n' + '\\t'.join(\n",
    "                                    [str(self.rev_relation_vocab[re[indx]]) for re in self.relation_trajectory]) + '\\n' + str(\n",
    "                                    rev) + '\\n' + str(\n",
    "                                    self.log_probs[b, r]) + '\\n___' + '\\n')\n",
    "                        paths[str(qr)].append(\"#####################\\n\")\n",
    "\n",
    "                all_final_reward_1 += final_reward_1\n",
    "                all_final_reward_3 += final_reward_3\n",
    "                all_final_reward_5 += final_reward_5\n",
    "                all_final_reward_10 += final_reward_10\n",
    "                all_final_reward_20 += final_reward_20\n",
    "                auc += AP\n",
    "\n",
    "            all_final_reward_1 /= total_examples\n",
    "            all_final_reward_3 /= total_examples\n",
    "            all_final_reward_5 /= total_examples\n",
    "            all_final_reward_10 /= total_examples\n",
    "            all_final_reward_20 /= total_examples\n",
    "            auc /= total_examples\n",
    "            \n",
    "            if save_model:\n",
    "                if all_final_reward_10 >= self.max_hits_at_10:\n",
    "                    self.max_hits_at_10 = all_final_reward_10\n",
    "                    torch.save(self.agent.state_dict(), self.model_dir + \"agent\" + '.ckpt')\n",
    "                    # self.save_path = self.model_dir + \"model\" + '.ckpt'\n",
    "\n",
    "            if print_paths:\n",
    "                logger.info(\"[ printing paths at {} ]\".format(self.output_dir + '/test_beam/'))\n",
    "                for q in paths:\n",
    "                    j = q.replace('/', '-')\n",
    "                    with codecs.open(self.path_logger_file_ + '_' + j, 'a', 'utf-8') as pos_file:\n",
    "                        for p in paths[q]:\n",
    "                            pos_file.write(p)\n",
    "                with open(self.path_logger_file_ + 'answers', 'w') as answer_file:\n",
    "                    for a in answers:\n",
    "                        answer_file.write(a)\n",
    "\n",
    "            with open(self.output_dir + '/scores.txt', 'a') as score_file:\n",
    "                score_file.write(\"Hits@1: {:.4f}\".format(all_final_reward_1))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@3: {:.4f}\".format(all_final_reward_3))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@5: {:.4f}\".format(all_final_reward_5))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@10: {:.4f}\".format(all_final_reward_10))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"Hits@20: {:.4f}\".format(all_final_reward_20))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"MRR: {:.4f}\".format(auc))\n",
    "                score_file.write(\"\\n\")\n",
    "                score_file.write(\"------------------------------------\")\n",
    "\n",
    "            print(\"Hits@1: {:.4f}, Hits@3: {:.4f}, Hits@10: {:.4f}, MRR: {:.4f}\".format(all_final_reward_1, \n",
    "                                                                                 all_final_reward_3,\n",
    "                                                                                 all_final_reward_10, auc))\n",
    "            \n",
    "    def top_k(self, scores, k):\n",
    "        \"\"\"Return, flattened, the positions of the k largest scores in each row.\n",
    "\n",
    "        `scores` is reshaped into rows of width k * self.max_num_actions. Each row\n",
    "        is argsorted in ascending order and only its last k positions are kept, so\n",
    "        within a row the returned indices run from lower score to higher score.\n",
    "        \"\"\"\n",
    "        row_width = k * self.max_num_actions\n",
    "        flat_scores = scores.reshape(-1, row_width)  # [B, (k*max_num_actions)]\n",
    "        ascending = np.argsort(flat_scores, axis=1)\n",
    "        best_k = ascending[:, -k:]  # tail of an ascending sort = k best # [B , k]\n",
    "        return best_k.reshape((-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7692b714",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: one flat dict that is handed to Trainer(options).\n",
    "options = {}\n",
    "\n",
    "# basic setting\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/WN18RR/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "# Load both vocabularies inside with-blocks so the file handles are closed\n",
    "# as soon as the JSON has been parsed.\n",
    "with open(options['vocab_dir'] + '/relation_vocab.json') as relation_vocab_file:\n",
    "    options['relation_vocab'] = json.load(relation_vocab_file)\n",
    "with open(options['vocab_dir'] + '/entity_vocab.json') as entity_vocab_file:\n",
    "    options['entity_vocab'] = json.load(entity_vocab_file)\n",
    "# The best agent checkpoint is written to model_dir + 'agent.ckpt' during\n",
    "# evaluation; scores.txt is appended under output_dir.\n",
    "options['model_dir'] = './outputs_v2/'\n",
    "options['output_dir'] = './outputs_v2/'\n",
    "\n",
    "# agent setting\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 0\n",
    "options['train_entity_embeddings'] = 0\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 128\n",
    "\n",
    "# hyperparameters\n",
    "options['test_rollouts'] = 100\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 128\n",
    "options['beta'] = 0.05\n",
    "options['Lambda'] = 0.05\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 1e-3\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "options['total_iterations'] = 1000\n",
    "# The evaluation loop branches on 'max' or 'sum' when ranking answers.\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8dc1e4ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n"
     ]
    }
   ],
   "source": [
    "# Build the trainer from the options dict configured in the previous cell,\n",
    "# then run the training loop.\n",
    "trainer = Trainer(options)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6527dc49",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kelvin/miniconda/envs/GraphRL/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/tmp/ipykernel_67070/2123898640.py:127: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, F.log_softmax(scores), label_action, chosen_relation\n",
      "/tmp/ipykernel_67070/3382145515.py:281: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.3653, Hits@3: 0.4336, Hits@10: 0.5041, MRR: 0.4109\n"
     ]
    }
   ],
   "source": [
    "# Restore the agent weights saved during training and switch to eval mode.\n",
    "# map_location remaps the stored tensors onto options['device'], so a\n",
    "# checkpoint written on GPU still loads when the notebook runs on CPU.\n",
    "checkpoint_path = options['model_dir'] + 'agent.ckpt'\n",
    "trainer.agent.load_state_dict(torch.load(checkpoint_path, map_location=options['device']))\n",
    "trainer.agent.eval()\n",
    "\n",
    "# Point evaluation at trainer.test_test_environment (presumably the test\n",
    "# split; confirm against the Trainer constructor) and take the rollout\n",
    "# count from the shared options instead of repeating the literal.\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "trainer.test_environment.test_rollouts = options['test_rollouts']\n",
    "trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
