from __future__ import absolute_import
from __future__ import division
from tqdm import tqdm
import json
import time
import os
import logging
import numpy as np
import sys
from copy import deepcopy
sys.path.append('./code')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.utils as utils
from torch_scatter import scatter

#from model.agent import Agent
from model.options import read_options
from model.environment import env
import codecs
from collections import defaultdict
import gc
import resource
import sys
from model.baseline import ReactiveBaseline
from scipy.special import logsumexp as lse
import torch
import torch.optim as optim
from model.nell_eval import nell_eval

def get_logger(output_dir):
    """Return this module's logger, logging bare messages to the console and to a file.

    The log file is ``output_dir + 'train.log'``. ``output_dir`` is concatenated
    directly (no separator is inserted), so it is expected to end with a path
    separator.

    Handlers are attached at most once per destination. Calling this function
    again returns the same logger without duplicating every log line.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    filename = output_dir + 'train'
    log_path = os.path.abspath(f"{filename}.log")
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # FileHandler subclasses StreamHandler, so the console handler is matched by exact type.
    if not any(type(h) is StreamHandler for h in logger.handlers):
        console_handler = StreamHandler()
        console_handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(console_handler)
    if not any(isinstance(h, FileHandler) and h.baseFilename == log_path
               for h in logger.handlers):
        file_handler = FileHandler(filename=log_path)
        file_handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(file_handler)
    return logger

# Configure the root logger to print DEBUG and above to stdout, and keep a handle to it.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
logger = logging.getLogger()

class node_aggregation(nn.Module):
    """One round of relation-aware attention message passing over entity features.

    Each edge ``H[e] --L[e]--> T[e]`` produces a message from the tail entity and
    relation features. Messages are weighted by a softmax over the edges that
    share the same head entity and summed into that head. This is followed by a
    residual LayerNorm and a residual MLP block.

    All features are ``2 * embedding_size`` wide. ``m``, ``hidden_size`` and
    ``layer`` are stored but not used by this layer.
    """

    def __init__(self, m, embedding_size, hidden_size, layer = 2):
        super(node_aggregation, self).__init__()

        self.m = m
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.layer = layer

        self.query_encoder = nn.Linear(2*embedding_size, 2*embedding_size)
        self.key_encoder = nn.Linear(2*embedding_size, 2*embedding_size)
        # Values are built from [tail features, relation features].
        self.value_encoder = nn.Linear(2*2*embedding_size, 2*embedding_size)

        self.mlp = nn.Sequential(nn.Linear(2*embedding_size, 2*2*embedding_size), nn.Mish(),
                                 nn.Linear(2*2*embedding_size, 2*embedding_size))
        self.mlp_norm = nn.LayerNorm(2*embedding_size)
        self.norm = nn.LayerNorm(2*embedding_size)

    def forward(self, H, L, T, X, Z):
        """Return updated entity features.

        Args:
            H, L, T: 1-D index tensors of equal length: head entity, relation and
                tail entity of every edge.
            X: entity feature table, one row per entity.
            Z: relation feature table, one row per relation.

        Returns:
            A new tensor shaped like ``X``. ``X`` itself is not modified.
        """
        head = H.long()
        q = self.query_encoder(X[H] + Z[L])
        k = self.key_encoder(X[T])
        v = self.value_encoder(torch.cat([X[T], Z[L]], -1))

        # Scaled dot-product attention: scale by the feature width, not by the
        # number of edges (L.shape[-1] is the edge count).
        score = (q * k).sum(-1) / np.sqrt(q.shape[-1])

        # Softmax over the outgoing edges of each head entity. Subtracting the
        # per-head max does not change the result but keeps exp() from overflowing.
        score_max = scatter(score.detach(), head, dim=0, reduce='max')[head]
        weight = torch.exp(score - score_max)
        denom = scatter(weight, head, dim=0, reduce='sum')[head]
        att = weight / denom

        # dim_size covers every entity row (rows without edges receive 0). The sum
        # can then be added out of place instead of writing into the caller's X.
        x_ = scatter(v * att.unsqueeze(-1), head, dim=0, dim_size=X.size(0), reduce='sum')
        x = self.norm(X + x_)

        x = self.mlp_norm(x + self.mlp(x))
        return x
    
class relation_aggregation(nn.Module):
    """Update relation features from the head entities of their edges.

    Each relation row receives the mean of the encoded head-entity features over
    its edges. This is followed by a residual LayerNorm and a residual MLP block.

    All features are ``2 * embedding_size`` wide. ``m``, ``hidden_size`` and
    ``layer`` are stored but not used by this layer.
    """

    def __init__(self, m, embedding_size, hidden_size, layer = 2):
        super(relation_aggregation, self).__init__()

        self.m = m
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.layer = layer

        self.node_encoder = nn.Sequential(nn.Linear(2*embedding_size, 2*embedding_size), nn.Mish(),
                                          nn.Linear(2*embedding_size, 2*embedding_size))
        self.mlp = nn.Sequential(nn.Linear(2*embedding_size, 2*2*embedding_size), nn.Mish(),
                                 nn.Linear(2*2*embedding_size, 2*embedding_size))
        self.mlp_norm = nn.LayerNorm(2*embedding_size)
        self.norm = nn.LayerNorm(2*embedding_size)

    def forward(self, H, L, T, X, Z):
        """Return updated relation features.

        Args:
            H, L, T: 1-D index tensors of equal length: head entity, relation and
                tail entity of every edge (T is unused here).
            X: entity feature table, one row per entity.
            Z: relation feature table, one row per relation.

        Returns:
            A new tensor shaped like ``Z``. ``Z`` itself is not modified.
        """
        h_x = self.node_encoder(X[H])
        # Mean over each relation's edges. dim_size covers every relation row
        # (relations without edges receive 0), so the update is added out of place.
        z_ = scatter(h_x, L.long(), dim=0, dim_size=Z.size(0), reduce='mean')
        z = self.norm(Z + z_)

        z = self.mlp_norm(z + self.mlp(z))
        return z

class Policy_step(nn.Module):
    """One LSTM-cell step of the policy's path-history encoder.

    The previous action embedding goes through a ReLU-activated linear layer and
    then an ``LSTMCell``. The cell's output and cell state are each re-projected
    by a ReLU-activated linear layer.

    Args:
        m: width multiplier applied to both ``embedding_size`` and ``hidden_size``.
        embedding_size: base width of the action embedding (input is ``m * embedding_size``).
        hidden_size: base width of the recurrent state (state is ``m * hidden_size``).
    """

    def __init__(self, m, embedding_size, hidden_size):
        super(Policy_step, self).__init__()

        in_width = m * embedding_size
        state_width = m * hidden_size
        # Registered but never used in forward(); kept so checkpoint keys stay the same.
        self.batch_norm = nn.BatchNorm1d(state_width)
        self.lstm_cell = nn.LSTMCell(input_size=in_width, hidden_size=state_width)
        self.l1 = nn.Linear(in_width, in_width)
        self.l2 = nn.Linear(state_width, state_width)
        self.l3 = nn.Linear(state_width, state_width)

    def forward(self, prev_gnn, prev_action, prev_state):
        """Advance the history encoder by one step.

        Args:
            prev_gnn: accepted for interface compatibility and ignored.
            prev_action: action embedding, last dimension ``m * embedding_size``.
            prev_state: ``(h, c)`` pair for the LSTM cell.

        Returns:
            ``(output, state)``. ``output`` is the projected LSTM output. ``state``
            stacks the projected output and projected cell state into shape
            ``(1, 2, batch, m * hidden_size)``.
        """
        lstm_input = torch.relu(self.l1(prev_action))
        h_next, c_next = self.lstm_cell(lstm_input, prev_state)
        projected_h = torch.relu(self.l2(h_next))
        projected_c = torch.relu(self.l3(c_next))
        packed_state = torch.stack([projected_h, projected_c], dim=0).unsqueeze(0)
        return projected_h, packed_state

class Policy_mlp(nn.Module):
    """Two-layer ReLU MLP that maps the policy state to an action-scoring vector.

    Input width is ``m * hidden_size + 4 * embedding_size`` and output width is
    ``2 * m * embedding_size``. A ReLU follows both layers.
    """

    def __init__(self, hidden_size, m, embedding_size):
        super(Policy_mlp, self).__init__()

        self.hidden_size = hidden_size
        self.m = m
        self.embedding_size = embedding_size
        hidden_width = m * hidden_size
        self.mlp_l1 = nn.Linear(hidden_width + 4 * embedding_size, hidden_width, bias=True)
        self.mlp_l2 = nn.Linear(hidden_width, 2 * m * embedding_size, bias=True)

    def forward(self, state_query):
        """Return ``relu(mlp_l2(relu(mlp_l1(state_query))))``."""
        return torch.relu(self.mlp_l2(torch.relu(self.mlp_l1(state_query))))

class Agent(nn.Module):
    """Policy network: GNN-refined embeddings, LSTM path history and an MLP action scorer.

    ``mgat`` must be called before ``step``. It runs the node/relation
    aggregation layers over the given edge list and caches the results in
    ``self.ent_embed`` / ``self.rel_embed``, which ``action_encoder`` and
    ``step`` read.
    """

    def __init__(self, params):
        super(Agent, self).__init__()
        self.action_vocab_size = len(params['relation_vocab'])
        self.entity_vocab_size = len(params['entity_vocab'])
        self.embedding_size = params['embedding_size']
        self.hidden_size = params['hidden_size']
        self.ePAD = params['entity_vocab']['PAD']
        self.rPAD = params['relation_vocab']['PAD']
        self.use_entity_embeddings = params['use_entity_embeddings']
        self.train_entity_embeddings = params['train_entity_embeddings']
        self.train_relation_embeddings = params['train_relation_embeddings']
        self.device = params['device']
        self.gnn_layer = params['gnn_layer']

        # Entity table. It is frozen unless train_entity_embeddings is set. It is
        # xavier-initialised when entity embeddings are used, and all zeros otherwise.
        self.entity_embedding = nn.Embedding(self.entity_vocab_size, 2 * self.embedding_size)
        if not self.train_entity_embeddings:
            self.entity_embedding.requires_grad_(False)
        if self.use_entity_embeddings:
            torch.nn.init.xavier_uniform_(self.entity_embedding.weight)
        else:
            torch.nn.init.constant_(self.entity_embedding.weight, 0.0)

        # Relation table: frozen unless train_relation_embeddings, always xavier-initialised.
        self.relation_embedding = nn.Embedding(self.action_vocab_size, 2 * self.embedding_size)
        if not self.train_relation_embeddings:
            self.relation_embedding.requires_grad_(False)
        torch.nn.init.xavier_uniform_(self.relation_embedding.weight)

        self.num_rollouts = params['num_rollouts']
        self.test_rollouts = params['test_rollouts']
        self.LSTM_Layers = params['LSTM_layers']
        self.batch_size = params['batch_size'] * params['num_rollouts']
        # Relation id used as the "previous relation" at the first hop of every rollout.
        self.dummy_start_label = (torch.ones(self.batch_size) * params['relation_vocab']['DUMMY_START_RELATION']).long()
        self.entity_embedding_size = self.embedding_size
        self.initial_gnn_state = nn.Parameter(torch.zeros((1, self.embedding_size*2*2)))
        torch.nn.init.xavier_uniform_(self.initial_gnn_state)

        # Width multiplier. Action embeddings are relation+entity (4 * embedding_size)
        # or relation only (2 * embedding_size); see action_encoder.
        self.m = 4 if self.use_entity_embeddings else 2

        self.policy_step = Policy_step(m=self.m, embedding_size=self.embedding_size, hidden_size=self.hidden_size).to(self.device)
        self.policy_mlp = Policy_mlp(self.hidden_size, self.m, self.embedding_size).to(self.device)

        self.node_conv = nn.ModuleList()
        self.rel_conv = nn.ModuleList()
        for _ in range(self.gnn_layer):
            self.node_conv.append(node_aggregation(m=self.m,
                                                   embedding_size=self.embedding_size,
                                                   hidden_size=self.hidden_size).to(self.device))
            self.rel_conv.append(relation_aggregation(m=self.m,
                                                      embedding_size=self.embedding_size,
                                                      hidden_size=self.hidden_size).to(self.device))

        # Not used by any Agent method; kept so the parameter/checkpoint layout is unchanged.
        self.gate1_linear = nn.Linear(2*self.hidden_size, 3*2*self.hidden_size)
        self.gate2_linear = nn.Linear(2*self.hidden_size, 3*2*self.hidden_size)

        # Encodes [LSTM h, LSTM c, query relation embedding] into the attention
        # query used by neighbour_aggregation.
        self.state_encoder = nn.Sequential(nn.Linear(self.m*self.hidden_size*2 + self.embedding_size*2,
                                           self.embedding_size*2*2), nn.Mish(),
                                           nn.LayerNorm(self.embedding_size*2*2))

    def get_mem_shape(self):
        """Return the recurrent-state shape ``(LSTM_Layers, 2, None, m * hidden_size)``; None = any batch."""
        return (self.LSTM_Layers, 2, None, self.m * self.hidden_size)

    def action_encoder(self, next_relations, next_entities):
        """Embed actions using the GNN-refined tables cached by ``mgat``.

        Returns the relation embedding. When ``use_entity_embeddings`` is set, the
        entity embedding is concatenated to it on the last axis.
        """
        relation_embedding = self.rel_embed[next_relations]
        entity_embedding = self.ent_embed[next_entities]

        if self.use_entity_embeddings:
            action_embedding = torch.cat([relation_embedding, entity_embedding], dim=-1)
        else:
            action_embedding = relation_embedding

        return action_embedding

    def neighbour_aggregation(self, query_relation, prev_state, next_neighbors):
        """Attention-pool the neighbourhood of every candidate entity.

        Args:
            query_relation: query relation ids, one per rollout.
            prev_state: ``[h, c]`` LSTM state pair.
            next_neighbors: integer tensor whose last axis is ``(tail, relation)``.
                It is indexed as rank 4, presumably (batch, num_actions,
                max_neighbors, 2) as built by the trainer (TODO confirm).
                Relation 0 marks an empty slot.

        Returns:
            Per-candidate pooled ``[tail, relation]`` embeddings. This uses the raw
            embedding tables, not the GNN-refined ones.
        """
        link = next_neighbors[:, :, :, 1]
        tail = next_neighbors[:, :, :, 0]
        mask = (link != 0).float()

        t_embed = self.entity_embedding(tail)
        r_embed = self.relation_embedding(link)

        lstm_state = torch.cat([prev_state[0], prev_state[1]], -1)
        query_embedding = self.relation_embedding(query_relation.long())
        state = self.state_encoder(torch.cat([lstm_state, query_embedding], -1))
        neighbor_embedding = torch.cat([t_embed, r_embed], -1)

        att = (state.unsqueeze(1).unsqueeze(1)*neighbor_embedding).sum(-1)/np.sqrt(state.shape[-1])
        # Empty slots get a huge negative logit so the softmax ignores them.
        att = F.softmax(att - (1 - mask)*1e8, 2)
        update_embedding = (neighbor_embedding*att.unsqueeze(-1)).sum(2)
        return update_embedding

    def mgat(self, head, link, tail):
        """Run the GNN layers over an edge list and cache ``self.ent_embed`` / ``self.rel_embed``.

        Args:
            head, link, tail: 1-D index tensors describing every edge.
        """
        n_ent = len(self.entity_embedding.weight)
        n_rel = len(self.relation_embedding.weight)

        self.ent_embed = self.entity_embedding(torch.arange(n_ent).to(head.device))
        self.rel_embed = self.relation_embedding(torch.arange(n_rel).to(link.device))

        for i in range(self.gnn_layer):
            self.ent_embed = self.node_conv[i](head, link, tail,
                                               self.ent_embed, self.rel_embed)
            self.rel_embed = self.rel_conv[i](head, link, tail,
                                              self.ent_embed, self.rel_embed)

    def step(self, next_relations, next_entities, prev_state, prev_relation, query_embedding, current_entities, next_neighbors, prev_gnn_state):
        """Score the candidate actions for one hop and sample one action per rollout.

        Requires ``mgat`` to have been called for the current graph.

        Args:
            next_relations, next_entities: ``(batch, max_num_actions)`` candidate
                relation / entity ids. Relation ``rPAD`` marks an unavailable slot.
            prev_state: packed recurrent state ``(1, 2, batch, m * hidden_size)``.
            prev_relation: relation id taken at the previous hop, one per rollout.
            query_embedding: query relation *ids*, one per rollout; they are looked up here.
            current_entities: entity id each rollout is currently on.
            next_neighbors: neighbour table of the candidates; see ``neighbour_aggregation``.
            prev_gnn_state: forwarded to ``policy_step``, which ignores it.

        Returns:
            ``(loss, new_state, log_probs, label_action, chosen_relation)``:
            - the per-rollout cross-entropy of the sampled action;
            - the new packed recurrent state;
            - the log-softmax action scores;
            - the sampled slot indices;
            - the relation ids in those slots.
        """
        prev_action_embedding = self.action_encoder(prev_relation, current_entities)

        # Unpack (1, 2, B, H) into the [h, c] pair expected by the LSTM cell.
        prev_state = torch.unbind(prev_state, dim=1)
        prev_state = [prev_state[0].squeeze(0), prev_state[1].squeeze(0)]

        output, new_state = self.policy_step(prev_gnn_state, prev_action_embedding, prev_state)

        prev_entity = self.ent_embed[current_entities]
        if self.use_entity_embeddings:
            state = torch.cat([output, prev_entity], dim=-1)
        else:
            state = output

        candidate_action_embeddings = self.action_encoder(next_relations, next_entities)
        gnn_embedding = self.neighbour_aggregation(query_embedding, prev_state, next_neighbors)
        candidate_action_embeddings = torch.cat([candidate_action_embeddings, gnn_embedding], -1)

        query_embedding = self.rel_embed[query_embedding]
        state_query_concat = torch.cat([state, query_embedding], dim=-1)

        # MLP for the policy; the score of each candidate is a dot product with its embedding.
        output = self.policy_mlp(state_query_concat)
        output_expanded = torch.unsqueeze(output, dim=1)
        prelim_scores = torch.sum(candidate_action_embeddings * output_expanded, dim=2)

        # Give PAD actions a large negative score so they are practically never sampled.
        mask = next_relations == self.rPAD
        scores = prelim_scores.masked_fill(mask, -99999.0)

        # Sample one action per rollout.
        action = torch.distributions.categorical.Categorical(logits=scores)
        label_action = action.sample()

        # reduction='none' replaces the deprecated reduce=False: one loss value per rollout.
        loss = F.cross_entropy(scores, label_action, reduction='none')

        # Map the sampled slot back to its relation id.
        chosen_relation = next_relations[torch.arange(len(label_action)), label_action]

        # Explicit dim (the implicit-dim form is deprecated; same result for 2-D scores).
        return loss, new_state, F.log_softmax(scores, dim=-1), label_action, chosen_relation

class Trainer(object):
    def __init__(self, params):
        """Build the agent, the train/dev/test environments, the baseline and the optimizer.

        Args:
            params: configuration dict. Every key is also copied onto ``self`` as an
                attribute; e.g. ``self.beta``, ``self.Lambda``, ``self.entity_vocab``
                and ``self.relation_vocab`` are read below without being set here.

        NOTE(review): ``params['batch_size']`` is overwritten in place with
        ``params['eval_batch_size']`` after the training environment is built.
        The caller's dict and ``self.params`` (the same object) therefore hold
        the eval batch size afterwards. ``self.batch_size``, copied by the
        setattr loop, keeps the training value.
        """

        # transfer parameters to self
        for key, val in params.items(): setattr(self, key, val);
        self.device = params['device']
        print(self.device)
        self.agent = Agent(params).to(self.device)
        #self.c_agent = ClusterAgent(params).to(self.device)
        self.model_dir = params['model_dir']
        self.save_path = self.model_dir + "model" + '.ckpt'
        self.train_environment = env(params, 'train')

        # Order matters: the dev/test environments below are built with the eval batch size.
        params['batch_size'] = params['eval_batch_size']
        self.dev_test_environment = env(params, 'dev')
        self.test_test_environment = env(params, 'test')
        # get_graph(mode != 'train') and test() read self.test_environment; it defaults to the dev split.
        self.test_environment = self.dev_test_environment
        self.rev_relation_vocab = self.train_environment.grapher.rev_relation_vocab
        self.rev_entity_vocab = self.train_environment.grapher.rev_entity_vocab
        #self.rev_cluster_relation_vocab = self.train_environment.cluster_grapher.rev_cluster_relation_vocab
        #self.rev_cluster_vocab = self.train_environment.cluster_grapher.rev_cluster_vocab

        self.max_hits_at_10 = 0
        self.ePAD = self.entity_vocab['PAD']
        self.rPAD = self.relation_vocab['PAD']
        # Starting entropy-bonus weight; train() decays a copy of it.
        self.decaying_beta_init = self.beta
        # optimize
        self.baseline = ReactiveBaseline(params, self.Lambda)

        # NOTE(review): this overrides any 'decay_batch' copied from params. train() only
        # decays beta when its counter equals decay_batch, which never happens with None.
        # Confirm this is intended.
        self.decay_batch = None
        self.gamma = params['gamma']
        self.grad_clip_norm = params['grad_clip_norm']
        self.eval_every = params['eval_every']
        self.total_iterations = params['total_iterations']
        self.learning_rate = params['learning_rate']
        self.pool = params['pool']
        self.output_dir = params['output_dir']
        self.params = params

        self.positive_reward_rates = []
        self.optimizer = optim.Adam(list(self.agent.parameters()),
                                    lr=self.learning_rate)
        # NOTE(review): not used by train() or get_graph(); possibly dead — confirm.
        self.two_embeds_sim_criterion = torch.nn.KLDivLoss()

    def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_reward, decaying_beta, baseline):

        loss = torch.stack(all_loss, dim=1)  # [original batch_size * num_rollout, T]
        base_value = baseline.get_baseline_value()

        # multiply with rewards
        final_reward = cum_discounted_reward - base_value
        reward_mean = torch.mean(final_reward)

        # Constant added for numerical stability
        reward_std = torch.std(final_reward) + 1e-6
        final_reward = torch.div(final_reward - reward_mean, reward_std)

        loss = torch.mul(loss, final_reward)  # [original batch_size * num_rollout, T]

        entropy_loss = decaying_beta * self.entropy_reg_loss(all_logits)

        total_loss = torch.mean(loss) - entropy_loss  # scalar

        return total_loss
    
    def entropy_reg_loss(self, all_logits):  # control diversity
        all_logits = torch.stack(all_logits, dim=2)  # [original batch_size * num_rollout, max_num_actions, T]
        entropy_loss = - torch.mean(torch.sum(torch.mul(torch.exp(all_logits), all_logits), dim=1))  # scalar
        return entropy_loss

    def calc_cum_discounted_reward(self, rewards):

        running_add = torch.zeros([rewards.size(0)]).to(self.device)  # [original batch_size * num_rollout]
        cum_disc_reward = torch.zeros([rewards.size(0), self.path_length]).to(
            self.device)  # [original batch_size * num_rollout, T]
        cum_disc_reward[:,
        self.path_length - 1] = rewards  # set the last time step to the reward received at the last state
        for t in reversed(range(self.path_length)):
            running_add = self.gamma * running_add + cum_disc_reward[:, t]
            cum_disc_reward[:, t] = running_add
        return cum_disc_reward

    def calc_cum_discounted_reward_credit(self, entity_rewards):
        """Per-step returns for a terminal reward ("credit" variant; not called by train()).

        Args:
            entity_rewards: 1-D tensor, one terminal reward per rollout.

        Returns:
            Float tensor ``(len(entity_rewards), self.path_length)`` on ``self.device``.

        NOTE(review): unlike ``calc_cum_discounted_reward``, this does not produce
        ``gamma ** k`` discounting. Each iteration reads ``cum_disc_reward[:, t]``,
        which the previous iteration has just overwritten with ``running_add``.
        For T = 4 and reward r the result is ``[(1+g)**2 r, (1+g) r, r, r]``, so
        returns grow towards the start of the path. Confirm this is intended
        before using it.
        """

        num_instances = entity_rewards.size(0)
        running_add = torch.zeros([num_instances]).to(self.device)  # [original batch_size * num_rollout]
        cum_disc_reward = torch.zeros([num_instances, self.path_length]).to(
            self.device)  # [original batch_size * num_rollout, T]
        cum_disc_reward[:,
        self.path_length - 1] = entity_rewards  # set the last time step to the reward received at the last state

        for t in reversed(range(1, self.path_length)):
            # For t < T-1, cum_disc_reward[:, t] already holds running_add from the
            # previous iteration, so the accumulated value is added in a second time.
            running_add = self.gamma * running_add + cum_disc_reward[:, t] # approx_credits[t].to(self.device) * cluster_rewards
            cum_disc_reward[:, t-1] = running_add

        return cum_disc_reward

    def get_graph(self, entity_episode, mode = 'train'):
        if mode == 'train':
            graph = deepcopy(self.train_environment.grapher)
        else:
            graph = deepcopy(self.test_environment.grapher)
        
        if mode == 'train':
            ent_match = graph.array_store[entity_episode.start_entities,:,0] == np.expand_dims(entity_episode.end_entities, -1)
            rel_match = graph.array_store[entity_episode.start_entities,:,1] == np.expand_dims(entity_episode.query_relation, -1)

            tmp = graph.array_store[entity_episode.start_entities, :, :]
            tmp[rel_match&ent_match] = 0
            graph.array_store[entity_episode.start_entities, :, :] = tmp

        head = torch.repeat_interleave(torch.arange(len(graph.rev_entity_vocab)), self.params['max_num_actions'])
        tail = torch.LongTensor(graph.array_store[:, :, 0].reshape(-1))
        link = torch.LongTensor(graph.array_store[:, :, 1].reshape(-1))

        head = head[(link > 0)].to(self.device)
        tail = tail[(link > 0)].to(self.device)
        link = link[(link > 0)].to(self.device)
        return head, link, tail
    
    def train(self):
        """Run REINFORCE training over the batches of ``self.train_environment``.

        For each batch:
        - build the graph with answer edges removed and run the GNN once;
        - roll out ``self.path_length`` policy steps;
        - compute discounted returns;
        - take one optimizer step on the baseline-subtracted policy loss minus
          the entropy bonus.

        Running averages are printed every ``eval_every // 10`` batches (at least
        every batch). ``self.test(beam=True)`` runs every ``eval_every`` batches.
        Training stops once ``total_iterations`` batches are exceeded.
        """
        train_loss = []
        train_reward = []

        self.batch_counter = 0
        current_decay = self.decaying_beta_init
        current_decay_count = 0
        # eval_every < 10 would otherwise give a logging interval of 0 (modulo by zero).
        log_every = max(1, self.eval_every // 10)

        print('Agent start learning ...')
        for entity_episode in self.train_environment.get_episodes():

            self.batch_counter += 1

            current_decay_count += 1
            if current_decay_count == self.decay_batch:
                current_decay *= self.decay_rate
                current_decay_count = 0

            # get initial state for entity agent
            head, link, tail = self.get_graph(entity_episode, mode = 'train')
            entity_state_emb = torch.zeros(1, 2, self.batch_size * self.num_rollouts,
                                           self.agent.m * self.hidden_size).to(self.device)
            entity_state = entity_episode.get_state()
            next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(
                self.device)  # original batch_size * num_rollout, max_num_actions
            next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)

            prev_relation = self.agent.dummy_start_label.to(self.device)  # original batch_size * num_rollout

            query_relation = torch.tensor(entity_episode.get_query_relation()).long().to(self.device)
            current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)
            prev_gnn_state = self.agent.initial_gnn_state.repeat(len(current_entities), 1)

            # Neighbour table with every query's answer edge removed. It depends only
            # on the episode's start/end/query, so build it once per batch rather than
            # copying the whole array_store at every step.
            neighbor_store = self.train_environment.grapher.array_store.copy()
            starts = np.asarray(entity_episode.start_entities)
            end_match = neighbor_store[starts, :, 0] == np.expand_dims(entity_episode.end_entities, -1)
            rel_match = neighbor_store[starts, :, 1] == np.expand_dims(entity_episode.query_relation, -1)
            # Zero matching slots directly. Writing back a fancy-indexed copy lets the
            # last duplicate start entity overwrite the edits made for earlier ones.
            rows, cols = np.nonzero(end_match & rel_match)
            neighbor_store[starts[rows], cols] = 0

            all_losses = []
            all_logits = []

            self.agent.mgat(head, link, tail)
            for _ in range(self.path_length):
                next_neighbors = torch.LongTensor(
                    neighbor_store[next_possible_entities.cpu().numpy()]
                ).to(current_entities.device)

                loss, entity_state_emb, logits, idx, chosen_relation = self.agent.step(
                    next_possible_relations,
                    next_possible_entities, entity_state_emb,
                    prev_relation, query_relation,
                    current_entities,
                    next_neighbors, prev_gnn_state
                )

                entity_state = entity_episode(idx.cpu())
                next_possible_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)
                next_possible_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)
                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)
                prev_relation = chosen_relation.to(self.device)

                all_losses.append(loss)
                all_logits.append(logits)

            rewards = torch.tensor(entity_episode.get_reward()).to(self.device)

            cum_discounted_reward = self.calc_cum_discounted_reward(rewards)
            reinforce_loss = self.calc_reinforce_loss(all_losses, all_logits, cum_discounted_reward,
                                                      current_decay, self.baseline)

            self.baseline.update(torch.mean(cum_discounted_reward))

            self.optimizer.zero_grad()
            reinforce_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=self.grad_clip_norm, norm_type=2)
            self.optimizer.step()

            train_loss.append(reinforce_loss.detach().cpu().item())
            train_reward.append(rewards.cpu().float().tolist())

            if self.batch_counter % log_every == 0:
                avg_loss = np.mean(train_loss[-log_every:])
                avg_reward = np.mean([r for batch in train_reward[-log_every:] for r in batch])
                print('Iteration: {}, Train loss: {:.4f}, rewards: {:.4f}'.format(self.batch_counter, avg_loss, avg_reward))
                gc.collect()

            if self.batch_counter % self.eval_every == 0:
                print('Eval:')
                self.test(beam = True)
                gc.collect()
                torch.cuda.empty_cache()
                print('------------------------------------------------------------')

            if self.batch_counter > self.total_iterations:
                break

    def test(self, beam=False, print_paths=False, save_model=True):

        with torch.no_grad():

            batch_counter = 0
            paths = defaultdict(list)
            answers = []
            all_final_reward_1 = 0
            all_final_reward_3 = 0
            all_final_reward_5 = 0
            all_final_reward_10 = 0
            all_final_reward_20 = 0
            auc = 0

            total_examples = self.test_environment.total_no_examples

            for entity_episode in self.test_environment.get_episodes():
                batch_counter += 1

                temp_batch_size = entity_episode.no_examples

                self.qr = entity_episode.get_query_relation()
                query_relation = self.qr
                query_relation = torch.tensor(query_relation).long().to(self.device)
                # set initial beam probs
                beam_probs = torch.zeros((temp_batch_size * self.test_rollouts, 1)).to(self.device)

                # get initial state for entity agent
                entity_state = entity_episode.get_state()
                
                head, link, tail = self.get_graph(entity_episode, mode = 'test')
                next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)
                next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)
                current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)

                entity_state_emb = torch.zeros(1, 2, temp_batch_size * self.test_rollouts,
                                               self.agent.m * self.hidden_size).to(self.device)
                prev_relation = (torch.ones(temp_batch_size * self.test_rollouts) * self.relation_vocab[
                    'DUMMY_START_RELATION']).long().to(self.device)
                prev_gnn_state = self.agent.initial_gnn_state.repeat(len(current_entities), 1)

                if print_paths:
                    self.entity_trajectory = [current_entities]
                    self.relation_trajectory = [prev_relation]   
                
                self.agent.mgat(head, link, tail)
                self.log_probs = np.zeros((temp_batch_size * self.test_rollouts,)) * 1.0
                for i in range(self.path_length):
                    
                    next_neighbors = self.test_environment.grapher.array_store[next_entities.cpu().numpy()]
                    next_neighbors = torch.LongTensor(next_neighbors).to(current_entities.device)
                    
                    loss, entity_state_emb, test_scores, test_action_idx, chosen_relation = self.agent.step(
                        next_relations,
                        next_entities, entity_state_emb,
                        prev_relation, query_relation,
                        current_entities,
                        next_neighbors, prev_gnn_state
                    )
                    
                    #Mimic original implementation on pytorch
                    if beam:
                        k = self.test_rollouts
                        beam_probs = beam_probs.to(self.device)
                        new_scores = test_scores + beam_probs
                        new_scores = new_scores.cpu()
                        if i == 0:
                            idx = np.argsort(new_scores)
                            idx = idx[:, -k:]
                            ranged_idx = np.tile([b for b in range(k)], temp_batch_size)
                            idx = idx[np.arange(k * temp_batch_size), ranged_idx]
                        else:
                            idx = self.top_k(new_scores, k)

                        y = idx // self.max_num_actions
                        x = idx % self.max_num_actions

                        y += np.repeat([b * k for b in range(temp_batch_size)], k)
                        entity_state['current_entities'] = entity_state['current_entities'][y]
                        entity_state['next_relations'] = entity_state['next_relations'][y, :]
                        entity_state['next_entities'] = entity_state['next_entities'][y, :]
                        entity_state_emb = entity_state_emb[:, :, y, :]

                        test_action_idx = x
                        chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]

                        beam_probs = new_scores[y, x]
                        beam_probs = beam_probs.reshape((-1, 1))

#                     #My implementation to fit arbitrary dimension
#                     if beam:
#                         k = self.test_rollouts
#                         beam_probs = beam_probs.to(self.device)
#                         new_scores = test_scores + beam_probs
#                         new_scores = new_scores.cpu()
#                         if i == 0:
#                             reshape_score = new_scores.reshape(temp_batch_size, self.test_rollouts, -1)
#                             possible_idx = []
#                             for x in reshape_score:
#                                 possible_idx.append(torch.LongTensor(np.where(x[0].cpu() > -1000)[0]))
#                             idx = []
#                             for x in possible_idx:
#                                 idx.append(torch.cat(([x]*(self.test_rollouts//len(x) + 1)))[:self.test_rollouts])
#                             idx = torch.cat(idx, 0)
#                         else:
#                             idx = self.top_k(new_scores, k)

#                         y = idx // self.max_num_actions
#                         x = idx % self.max_num_actions

#                         y += np.repeat([b * k for b in range(temp_batch_size)], k)
#                         entity_state['current_entities'] = entity_state['current_entities'][y]
#                         entity_state['next_relations'] = entity_state['next_relations'][y, :]
#                         entity_state['next_entities'] = entity_state['next_entities'][y, :]
#                         entity_state_emb = entity_state_emb[:, :, y, :]

#                         test_action_idx = x
#                         chosen_relation = entity_state['next_relations'][np.arange(temp_batch_size * k), x]
#                         beam_probs = new_scores[y, x]
#                         beam_probs = beam_probs.reshape((-1, 1))
                        
                        if print_paths:
                            for j in range(i):
                                self.entity_trajectory[j] = self.entity_trajectory[j][y]
                                self.relation_trajectory[j] = self.relation_trajectory[j][y]

                    entity_state = entity_episode(test_action_idx.cpu().numpy())
                    next_relations = torch.tensor(entity_state['next_relations']).long().to(self.device)
                    next_entities = torch.tensor(entity_state['next_entities']).long().to(self.device)
                    current_entities = torch.tensor(entity_state['current_entities']).long().to(self.device)
                    prev_relation = torch.tensor(chosen_relation).long().to(self.device)
                    
                    if print_paths:
                        self.entity_trajectory.append(entity_state['current_entities'])
                        self.relation_trajectory.append(chosen_relation)

                    test_scores = test_scores.cpu().numpy()
                    self.log_probs += test_scores[np.arange(self.log_probs.shape[0]), test_action_idx.cpu().numpy()]

                if beam:
                    self.log_probs = beam_probs

                rewards = entity_episode.get_reward()  # [B*test_rollouts]
                reward_reshape = np.reshape(rewards, (temp_batch_size, self.test_rollouts))  # [orig_batch, test_rollouts]
                self.log_probs = np.reshape(self.log_probs, (temp_batch_size, self.test_rollouts))
                sorted_indx = np.argsort(-self.log_probs)
                final_reward_1 = 0
                final_reward_3 = 0
                final_reward_5 = 0
                final_reward_10 = 0
                final_reward_20 = 0
                AP = 0
                ce = entity_episode.state['current_entities'].reshape((temp_batch_size, self.test_rollouts))
                se = entity_episode.start_entities.reshape((temp_batch_size, self.test_rollouts))
                for b in range(temp_batch_size):
                    answer_pos = None
                    seen = set()
                    pos=0
                    if self.pool == 'max':
                        for r in sorted_indx[b]:
                            if reward_reshape[b,r] == self.positive_reward:
                                answer_pos = pos
                                break
                            if ce[b, r] not in seen:
                                seen.add(ce[b, r])
                                pos += 1
                    if self.pool == 'sum':
                        scores = defaultdict(list)
                        answer = ''
                        for r in sorted_indx[b]:
                            scores[ce[b,r]].append(self.log_probs[b,r])
                            if reward_reshape[b,r] == self.positive_reward:
                                answer = ce[b,r]
                        final_scores = defaultdict(float)
                        for e in scores:
                            final_scores[e] = lse(scores[e])
                        sorted_answers = sorted(final_scores, key=final_scores.get, reverse=True)
                        if answer in  sorted_answers:
                            answer_pos = sorted_answers.index(answer)
                        else:
                            answer_pos = None


                    if answer_pos != None:
                        if answer_pos < 20:
                            final_reward_20 += 1
                            if answer_pos < 10:
                                final_reward_10 += 1
                                if answer_pos < 5:
                                    final_reward_5 += 1
                                    if answer_pos < 3:
                                        final_reward_3 += 1
                                        if answer_pos < 1:
                                            final_reward_1 += 1
                    if answer_pos == None:
                        AP += 0
                    else:
                        AP += 1.0/((answer_pos+1))
                    
                    if print_paths:
                        qr = self.train_environment.grapher.rev_relation_vocab[self.qr[b * self.test_rollouts]]
                        start_e = self.rev_entity_vocab[entity_episode.start_entities[b * self.test_rollouts]]
                        end_e = self.rev_entity_vocab[entity_episode.end_entities[b * self.test_rollouts]]
                        paths[str(qr)].append(str(start_e) + "\t" + str(end_e) + "\n")
                        paths[str(qr)].append("Reward:" + str(1 if answer_pos != None and answer_pos < 10 else 0) + "\n")
                        for r in sorted_indx[b]:
                            indx = b * self.test_rollouts + r
                            if rewards[indx] == self.positive_reward:
                                rev = 1
                            else:
                                rev = -1
                            answers.append(self.rev_entity_vocab[se[b,r].item()]+'\t'+ self.rev_entity_vocab[ce[b,r].item()]+'\t'+ str(self.log_probs[b,r])+'\n')
                            paths[str(qr)].append(
                                '\t'.join([str(self.rev_entity_vocab[e[indx].item()]) for e in
                                           self.entity_trajectory]) + '\n' + '\t'.join(
                                    [str(self.rev_relation_vocab[re[indx].item()]) for re in self.relation_trajectory]) + '\n' + str(
                                    rev) + '\n' + str(
                                    self.log_probs[b, r]) + '\n___' + '\n')
                        paths[str(qr)].append("#####################\n")

                all_final_reward_1 += final_reward_1
                all_final_reward_3 += final_reward_3
                all_final_reward_5 += final_reward_5
                all_final_reward_10 += final_reward_10
                all_final_reward_20 += final_reward_20
                auc += AP

            all_final_reward_1 /= total_examples
            all_final_reward_3 /= total_examples
            all_final_reward_5 /= total_examples
            all_final_reward_10 /= total_examples
            all_final_reward_20 /= total_examples
            auc /= total_examples
            
            if save_model:
                if all_final_reward_10 >= self.max_hits_at_10:
                    self.max_hits_at_10 = all_final_reward_10
                    torch.save(self.agent.state_dict(), self.model_dir + "agent" + '.ckpt')
                    # self.save_path = self.model_dir + "model" + '.ckpt'

            if print_paths:
                logger.info("[ printing paths at {} ]".format(self.output_dir + '/test_beam/'))
                for q in paths:
                    j = q.replace('/', '-')
                    with codecs.open(self.path_logger_file_ + '_' + j, 'a', 'utf-8') as pos_file:
                        for p in paths[q]:
                            pos_file.write(p)
                with open(self.path_logger_file_ + 'answers', 'w') as answer_file:
                    for a in answers:
                        answer_file.write(a)

            with open(self.output_dir + '/scores.txt', 'a') as score_file:
                score_file.write("Hits@1: {:.4f}".format(all_final_reward_1))
                score_file.write("\n")
                score_file.write("Hits@3: {:.4f}".format(all_final_reward_3))
                score_file.write("\n")
                score_file.write("Hits@5: {:.4f}".format(all_final_reward_5))
                score_file.write("\n")
                score_file.write("Hits@10: {:.4f}".format(all_final_reward_10))
                score_file.write("\n")
                score_file.write("Hits@20: {:.4f}".format(all_final_reward_20))
                score_file.write("\n")
                score_file.write("MRR: {:.4f}".format(auc))
                score_file.write("\n")
                score_file.write("------------------------------------")
            
            msg = "Hits@1: {:.4f}, Hits@3: {:.4f}, Hits@10: {:.4f}, MRR: {:.4f}".format(all_final_reward_1, 
                                                                                             all_final_reward_3,
                                                                                             all_final_reward_10, auc)
            print(msg)
            return msg
            
    def top_k(self, scores, k):
        """Select the ``k`` highest-scoring column indices in each beam group.

        ``scores`` is regrouped into rows of ``k * self.max_num_actions``
        consecutive entries. For every row, the column positions of its ``k``
        largest entries are kept. Within a row, the kept positions are ordered
        from lowest to highest score, because they are the tail of an
        ascending argsort.

        Args:
            scores: array-like with a ``reshape`` method whose total size is a
                multiple of ``k * self.max_num_actions``.
            k: number of entries to keep per row (the beam width).

        Returns:
            A flat array of ``rows * k`` column indices, each in
            ``[0, k * self.max_num_actions)``.
        """
        row_width = k * self.max_num_actions
        grouped = scores.reshape(-1, row_width)  # [B, k * max_num_actions]
        ascending_order = np.argsort(grouped, axis=1)
        # The best k entries sit at the end of an ascending sort. Slicing from
        # row_width - k matches slicing with -k for any valid k >= 1.
        best_k = ascending_order[:, row_width - k:]  # [B, k]
        return best_k.reshape(-1)