import logging
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from torch_scatter import scatter_add
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from omegaconf import DictConfig

from src.metrics import MRMetric, MRRMetric, HitsMetric
from .rspmm import generalized_rspmm


logger = logging.getLogger(__name__)

@rank_zero_only
def info(msg):
    """Log ``msg`` at INFO level through this module's logger.

    Decorated with Lightning's ``rank_zero_only``, so in a distributed run
    the message is expected to be emitted by the rank-0 process only.
    """
    logger.info(msg)
    
@rank_zero_only
def error(msg):
    """Log ``msg`` at ERROR level through this module's logger.

    Decorated with Lightning's ``rank_zero_only``, so in a distributed run
    the message is expected to be emitted by the rank-0 process only.
    """
    logger.error(msg)
    

class Node4Reason:
    """Sparse store of the ``(batch, node)`` pairs visited so far.

    Every visited pair is encoded as the integer key
    ``batch * num_node + node``.  ``node_key`` is always the output of
    ``Tensor.unique`` (sorted, duplicate-free), which is what allows the
    lookups below to use ``torch.searchsorted``.  ``batch_index``,
    ``node_index``, ``hidden`` and ``score`` are row-aligned with
    ``node_key``.
    """

    def __init__(self,
                 batch_size: int,
                 num_node: int,
                 x_node: torch.Tensor,
                 device: torch.device):
        """Create an empty store.

        Args:
            batch_size: Number of queries in the batch (stored, not used here).
            num_node: Multiplier used to build ``batch * num_node + node`` keys.
            x_node: 2-D node feature matrix; its width fixes the width of
                ``hidden``.
            device: Device on which the bookkeeping tensors are allocated.
        """
        self.batch_size = batch_size
        self.num_node = num_node
        self.device = device

        # Features padded with an all-zero second half (feature width doubles).
        self.x_node = torch.cat([x_node, torch.zeros_like(x_node)], dim=1)

        self.batch_index = torch.empty(0, dtype=torch.long, device=device)
        self.node_index = torch.empty(0, dtype=torch.long, device=device)
        self.node_key = torch.empty(0, dtype=torch.long, device=device)
        self.hidden = torch.empty(0, x_node.shape[1], dtype=torch.float, device=device)
        self.score = torch.empty(0, dtype=torch.float, device=device)

    def node_attribute(self, name, node_index, batch_index):
        """Look up attribute ``name`` for the given ``(batch, node)`` pairs.

        Args:
            name: Name of a row-aligned attribute (e.g. ``'score'``, ``'hidden'``).
            node_index: 1-D tensor of node ids.
            batch_index: 1-D tensor of batch ids, same length as ``node_index``.

        Returns:
            ``(value, mask)`` where ``mask[i]`` is True iff pair ``i`` is
            stored, and ``value[i]`` is its attribute (zeros where
            ``mask[i]`` is False).
        """
        attribute = getattr(self, name)
        value = torch.zeros(node_index.shape[0], *attribute.shape[1:], dtype=attribute.dtype, device=self.device)

        key = batch_index * self.num_node + node_index
        indice = torch.searchsorted(self.node_key, key)

        # searchsorted returns an insertion position, not a membership test:
        # a key is stored only if the slot is in range AND holds that key.
        in_range = indice < self.node_key.shape[0]
        mask = torch.zeros_like(in_range)
        mask[in_range] = self.node_key[indice[in_range]] == key[in_range]
        value[mask] = attribute[indice[mask]]

        return value, mask

    def update_node(self, node_index, batch_index, input_layer, score_layer, label=False):
        """Insert the given pairs, keeping state of already-stored pairs.

        Newly inserted pairs get an all-ones hidden state when ``label`` is
        truthy and an all-zeros one otherwise; their score is
        ``score_layer(hidden)``.

        Args:
            node_index: 1-D tensor of node ids (duplicates allowed).
            batch_index: 1-D tensor of batch ids, same length.
            input_layer: Unused; kept for interface compatibility.
            score_layer: Callable mapping ``[k, d]`` hidden states to ``[k, 1]``.
            label: Whether the new pairs are initialised to ones.
        """
        input_key = batch_index * self.num_node + node_index
        new_key, inverse = torch.cat([self.node_key, input_key], dim=0).unique(return_inverse=True)
        num_old = self.node_key.shape[0]

        new_batch_index = new_key // self.num_node
        new_node_index = new_key % self.num_node

        new_hidden = torch.empty(new_key.shape[0], *self.hidden.shape[1:], dtype=self.hidden.dtype, device=self.device)
        new_score = torch.empty(new_key.shape[0], dtype=self.score.dtype, device=self.device)

        # Keys that were not stored before this call, and their new rows.
        mask = torch.isin(input_key, self.node_key)
        update_key = torch.unique(input_key[~mask])
        update_inverse = torch.searchsorted(new_key, update_key)

        # The feature gather only provides shape/dtype/device (and validates
        # the node ids); the values themselves are replaced by ones/zeros.
        template = self.x_node[update_key % self.num_node][:, :self.x_node.size(1) // 2]
        update_hidden = torch.ones_like(template) if label else torch.zeros_like(template)

        # Carry over old rows, then fill in the new ones.
        new_hidden[inverse[:num_old]] = self.hidden
        new_hidden[update_inverse] = update_hidden
        new_score[inverse[:num_old]] = self.score
        new_score[update_inverse] = score_layer(update_hidden).squeeze(1)

        self.batch_index = new_batch_index
        self.node_index = new_node_index
        self.node_key = new_key
        self.hidden = new_hidden
        self.score = new_score

    def compute_score(self, score_layer, node_index, batch_index):
        """Recompute ``score`` from ``hidden`` for the given stored pairs.

        The pairs are expected to be stored already; the assertion only
        rejects keys larger than every stored key.
        """
        key = batch_index * self.num_node + node_index
        indice = torch.searchsorted(self.node_key, key)
        assert torch.all(indice < self.node_key.shape[0])
        self.score[indice] = score_layer(self.hidden[indice]).squeeze(1)

    def map_edge_index(self, edge_batch_index, edge_index):
        """Translate edge endpoints into row positions of this store.

        Must be called after ``update_node`` has inserted both endpoints of
        every edge.  Columns 0 and 2 of the returned copy hold row positions
        in ``node_key``; column 1 is left untouched.
        """
        map_edge_index = edge_index.clone()
        for column in (0, 2):
            endpoint_key = edge_batch_index * self.num_node + edge_index[:, column]
            map_index = torch.searchsorted(self.node_key, endpoint_key)
            assert torch.all(map_index < self.node_key.shape[0])
            map_edge_index[:, column] = map_index

        return map_edge_index
    

class Edge4Reason:
    """Batch-replicated lookup of graph edges by their source node.

    The edge list is conceptually copied once per batch element.  Each copy
    of an edge is keyed by ``batch * num_node + source`` (``edge_key``) and
    remembers its row in the original edge list (``edge_indice``).
    """

    def __init__(self,
                 batch_size: int,
                 num_node: int,
                 num_relation: int,
                 edge_index: torch.Tensor,
                 device: torch.device):
        """Build the per-batch edge keys and the per-node edge counts.

        ``edge_index`` is indexed as ``[num_edge, 3]`` with the source node in
        column 0 and the relation in column 1.
        """
        self.batch_size = batch_size
        self.num_node = num_node
        self.num_relation = num_relation
        self.edge_index = edge_index
        self.device = device

        source = edge_index[:, 0]

        # Number of edges leaving every node.
        self.all_num_neighbor = torch.bincount(source, minlength=num_node)

        # One key per (batch element, edge), flattened batch-major.
        batch_column = torch.arange(batch_size).to(self.device).unsqueeze(1)
        key_grid = batch_column * num_node + source.unsqueeze(0)
        self.edge_key = einops.rearrange(key_grid, 'b n -> (b n)')

        # Row of the original edge list behind every flattened key.
        edge_row = torch.arange(edge_index.shape[0]).to(self.device)
        self.edge_indice = einops.repeat(edge_row, 'n -> (b n)', b=batch_size)

    def num_neighbor(self, node_index):
        """Return the number of outgoing edges of each node in ``node_index``."""
        return self.all_num_neighbor[node_index]

    def search_edge(self, source_node_index, batch_index):
        """Collect all edges leaving the given ``(batch, source node)`` pairs.

        Returns:
            ``(edge_batch_index, edge_index)``: the batch id of every matched
            edge, and a copy of the matched rows whose relation column has
            been shifted by ``batch * num_relation``.
        """
        wanted_key = batch_index * self.num_node + source_node_index
        hit = torch.isin(self.edge_key, wanted_key)

        edge_batch_index = self.edge_key[hit] // self.num_node

        # Advanced indexing yields a copy, so self.edge_index stays intact.
        edge_index = self.edge_index[self.edge_indice[hit]]
        edge_index[:, 1] = edge_batch_index * self.num_relation + edge_index[:, 1]

        return edge_batch_index, edge_index
        
        


class GNNReasonLayer(nn.Module):
    """One message-passing layer with query-conditioned relation embeddings.

    Args:
        hidden_dim: Width ``d`` of node and relation representations.
        num_relation: Number of relation types for which ``relation_layer``
            produces an embedding.  Defaults to 474, the value that used to
            be hard-coded.
    """

    def __init__(self, hidden_dim: int, num_relation: int = 474):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_relation = num_relation

        self.mlp_out = nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(self.hidden_dim, self.hidden_dim))
        # Learned per-channel weight of the self term added to the message sum.
        self.alpha = nn.Parameter(torch.empty(1, self.hidden_dim))
        nn.init.normal_(self.alpha)
        self.norm = nn.LayerNorm(self.hidden_dim)

        # Maps one query vector to `num_relation` relation vectors of width d.
        self.relation_layer = nn.Linear(self.hidden_dim, self.hidden_dim * self.num_relation)

    def forward(self, x_node, x_relation, edge_index):
        """Run one round of message passing.

        Args:
            x_node: Node states laid out as ``[n, b * d]``.
            x_relation: Per-query vectors ``[b, d]`` fed to ``relation_layer``.
            edge_index: ``[E, 3]`` tensor; columns 0 / 1 / 2 are used as
                source, relation and target respectively.

        Returns:
            Updated node states, laid out as ``[n, b * d]``.
        """
        x_relation = einops.rearrange(self.relation_layer(x_relation), 'b (r d) -> r (b d)', d=self.hidden_dim)
        # the rspmm cuda kernel from torchdrug
        # https://torchdrug.ai/docs/api/layers.html#torchdrug.layers.functional.generalized_rspmm
        # reduce memory complexity from O(|E|d) to O(|V|d)
        output = generalized_rspmm(einops.rearrange(edge_index[:, [0, 2]], 'n m -> m n'),
                                   edge_index[:, 1],
                                   torch.ones_like(edge_index[:, 0]).float(),  # unit edge weights
                                   relation=x_relation.float(),
                                   input=x_node.float())
        x_node = einops.rearrange(x_node, 'n (b d) -> b n d', d=self.hidden_dim)
        output = einops.rearrange(output, 'n (b d) -> b n d', d=self.hidden_dim)
        x = self.mlp_out(output + self.alpha * x_node)
        x = self.norm(x)
        x = x + x_node  # residual connection
        x = einops.rearrange(x, 'b n d -> n (b d)')
        return x
    

class GNNReasonModel(nn.Module):
    def __init__(self,
                 hidden_dim: int,
                 num_layer: int,
                 remove_one_hop: bool):
        """Build the reasoning layers and the input / scoring heads.

        Args:
            hidden_dim: Width of node and relation representations.
            num_layer: Number of stacked ``GNNReasonLayer`` modules.
            remove_one_hop: Read by ``_remove_edge`` to choose how removed
                edges are matched.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.remove_one_hop = remove_one_hop

        # Selection ratios read by `_forward`.
        self.node_ratio = 0.1
        self.degree_ratio = 1.0

        width = self.hidden_dim
        half_width = width // 2

        reason_layers = [GNNReasonLayer(hidden_dim) for _ in range(num_layer)]
        self.layers = nn.ModuleList(reason_layers)

        # Both input heads take a 2 * width vector.
        self.input_layer = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )
        # The attribute name (including its spelling) is part of the state dict.
        self.realtion_input_layer = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )
        # Maps a node state to a single scalar.
        self.score_layer = nn.Sequential(
            nn.Linear(width, half_width),
            nn.ReLU(),
            nn.Linear(half_width, 1),
        )

        self.relation_embedding = nn.Embedding(474, width)
        
    @classmethod
    def from_config(cls, config: DictConfig):
        """Build a model from ``config.gnn.model`` and ``config.data``.

        Reads ``hidden_dim`` and ``num_layer`` from ``config.gnn.model`` and
        ``remove_one_hop`` from ``config.data``.
        """
        model_config = config.gnn.model
        return cls(model_config.hidden_dim,
                   model_config.num_layer,
                   config.data.remove_one_hop)
    
    def _remove_edge(self, edge_index, remove_edge_index):   
        h, r, t = edge_index.chunk(3, dim=1)
        _h_remove, _r_remove, _t_remove = remove_edge_index.chunk(3, dim=1)
        max_r = r.max() + 1  
        max_n = max(h.max(), t.max()) + 1
        h_remove = torch.cat([_h_remove, _t_remove], dim=0)  
        r_remove = torch.cat([_r_remove, torch.where(_r_remove >= max_r//2, _r_remove - max_r//2, _r_remove + max_r//2)], dim=0)
        t_remove = torch.cat([_t_remove, _h_remove], dim=0)        
        if self.remove_one_hop:
            key_fn = lambda x, y: x + y * max_n
            mask = ~torch.isin(key_fn(h, t), key_fn(h_remove, t_remove))
        else:
            key_fn = lambda x, y, z: x + (y + z * max_n) * max_n
            mask = ~torch.isin(key_fn(h, t, r), key_fn(h_remove, t_remove, r_remove))
        mask = mask.squeeze(1)
        return edge_index[mask]
    
    def _scatter_topk(self, x, index, k):
        # group sort input x with respect to index
        sorted_x_indices = torch.argsort(x, descending=True)
        indexed_indices = torch.argsort(index[sorted_x_indices], stable=True)
        sorted_x_indices = sorted_x_indices[indexed_indices]
        sorted_x = x[indexed_indices]
        
        _, sizes = torch.unique(index, return_counts=True)
        
        # initialize indices, ((0, 1, 2), (0, 1, 2))
        select_indices = torch.arange(sizes.max()).to(x.device)
        select_indices = einops.repeat(select_indices, 'k -> b k', b=k.size(0))

        # compute offset and offset to indices
        offset = torch.cat([torch.tensor([0]).to(sizes.device), sizes]).cumsum(0)[:-1]
        sizes_need_to_select = torch.minimum(sizes, k)
        select_indices_mask = select_indices < sizes_need_to_select.unsqueeze(1)
        select_indices = select_indices + offset.unsqueeze(1)
        select_indices = select_indices[select_indices_mask]

        return sorted_x[select_indices], sorted_x_indices[select_indices]
    
    def forward(self, h, r, x_node, x_relation, edge_index, remove_edge_index=None):
        """Score every node of the graph for each query ``(h, r)`` in the batch.

        Message passing runs over the full graph.  An unreachable
        adaptive-subgraph code path that used to follow the ``return`` has
        been removed; ``_forward`` holds a subgraph-based variant.

        Args:
            h: ``[b]`` long tensor with the head node of every query.
            r: ``[b]`` long tensor used to index rows of ``x_relation``.
            x_node: ``[n, d]`` node features, with ``2 * d`` equal to the input
                width of ``input_layer``.
            x_relation: Relation feature matrix, indexed by ``r`` along dim 0.
            edge_index: ``[E, 3]`` edge list passed on to every layer.
            remove_edge_index: Optional ``[*, 3]`` triples filtered out of
                ``edge_index`` via ``_remove_edge`` before message passing.

        Returns:
            ``[b, n]`` tensor of node scores.
        """
        if remove_edge_index is not None:
            edge_index = self._remove_edge(edge_index, remove_edge_index)

        batch_size = h.size(0)
        device = h.device

        # Per-query node input: node features, plus a second half that is 1
        # for the query's head node and 0 for every other node.
        x = einops.repeat(x_node, 'n d -> b n d', b=batch_size)
        x = torch.cat([x, torch.zeros_like(x)], dim=2)
        x[torch.arange(batch_size, device=device), h, x.size(2)//2:] = 1
        x = self.input_layer(x)
        x = einops.rearrange(x, 'b n d -> n (b d)')

        # One relation vector per query.
        x_relation = x_relation[r]

        for layer in self.layers:
            x = layer(x, x_relation, edge_index)

        x = einops.rearrange(x, 'n (b d) -> b n d', b=batch_size)
        return self.score_layer(x).squeeze(2)
            
        
    def _forward(self, h, r, x_node, x_relation, edge_index, remove_edge_index=None):
        """Subgraph-based variant of ``forward`` that only expands visited nodes.

        Starting from the head node of each query, every layer:
          1. picks up to ``num_node * node_ratio`` top-scoring visited nodes
             per batch element,
          2. gathers the edges leaving those nodes and keeps, per source
             node, the top edges ranked by the current score of the target,
          3. runs the layer on the re-indexed subgraph and refreshes the
             score of every reached target node.

        Returns:
            ``[batch_size, num_node]`` score tensor; entries that are never
            written keep their initial value 0.

        NOTE(review): ``r`` is never read in this method, and no caller of
        ``_forward`` is visible in this part of the file.
        NOTE(review): ``layer`` receives ``x_relation`` repeated to
        ``(b * num_relation, d)`` and node states of width ``d``, whereas
        ``GNNReasonLayer.forward`` in this file rearranges its inputs as
        ``'b (r d)'`` / ``'n (b d)'`` -- this path may target an older layer
        interface; confirm before use.
        """
        if remove_edge_index is not None:
            edge_index = self._remove_edge(edge_index, remove_edge_index)

        batch_size = h.size(0)
        num_node = x_node.size(0)
        num_relation = x_relation.size(0)
        num_edge = edge_index.size(0)
        # number of outgoing edges of every node
        num_neighbor = scatter_add(torch.ones_like(edge_index[:, 0]), edge_index[:, 0], dim=0, dim_size=num_node)
        batch_index = torch.arange(batch_size).to(h.device)
        source, relation, target = map(lambda x: x.squeeze(1), edge_index.chunk(3, dim=1))

        # edge list replicated once per batch element, flattened batch-major
        expanded_source = einops.repeat(source, 'n -> (b n)', b=batch_size)
        expanded_relation = einops.repeat(relation, 'n -> (b n)', b=batch_size)
        expanded_target = einops.repeat(target, 'n -> (b n)', b=batch_size)
        expanded_batch_index = einops.repeat(batch_index, 'b -> (b n)', n=num_edge)

        # key of every replicated edge: batch * num_node + source node
        encoded_edge = einops.repeat(batch_index, 'b -> (b n)', n=num_edge) * num_node + expanded_source

        # head node input: its features concatenated with an all-ones label half
        x = torch.cat([x_node[h], torch.ones_like(x_node[h])], dim=1)
        x = self.input_layer(x)
        x_relation = einops.repeat(x_relation, 'n d -> (b n) d', b=batch_size).clone()

        # (batch, node) pairs whose representation is currently held in x, row-aligned with x
        node = torch.stack([batch_index, h], dim=1)
        visit_mask = torch.zeros(batch_size, num_node).bool().to(h.device)
        visit_mask[batch_index, h] = True

        # score for each node of each batch element
        score = torch.zeros(batch_size, num_node).to(h.device)
        score[batch_index, h] = self.score_layer(x).squeeze(1)

        for layer in self.layers:
            # ---------------------------------------------------------------------------------------------------------------------------
            # select important nodes for each batch
            # ---------------------------------------------------------------------------------------------------------------------------
            batch_index_of_node, node_index = map(lambda x: x.squeeze(-1), visit_mask.nonzero().chunk(2, dim=1))
            # same node quota for every batch element (truncated by .long())
            num_node_need_to_select = torch.tensor([num_node * self.node_ratio]).to(x.device)
            num_node_need_to_select = einops.repeat(num_node_need_to_select, 'n -> (n m)', m=batch_size).long()

            # top-scoring visited nodes, grouped by batch element
            _, topk_indices = self._scatter_topk(score[batch_index_of_node, node_index], batch_index_of_node, num_node_need_to_select)
            selected_node = visit_mask.nonzero()[topk_indices]

            # number of selected nodes for each batch element
            real_num_node_selected = torch.bincount(selected_node[:, 0], minlength=batch_size)

            # ---------------------------------------------------------------------------------------------------------------------------
            # select important edges for each batch
            # ---------------------------------------------------------------------------------------------------------------------------
            # edge quota: degree_ratio * node quota * (num_edge / num_node)
            num_target_node_need_to_select = (self.degree_ratio * num_node_need_to_select * num_edge / num_node).long()

            # compute number of target nodes for each source node
            num_target_node = num_neighbor[selected_node[:, 1]]

            # number of candidate edges for each batch element
            num_selected_edge = scatter_add(num_target_node, selected_node[:, 0], dim_size=batch_size)

            num_selected_edge_mean = num_selected_edge.float().mean().clamp(min=1)
            chunk_size = max(int(1e6 / num_selected_edge_mean), 1)

            # chunk batch to reduce peak memory usage
            chunked_rel_num_node_selected = real_num_node_selected.split(chunk_size, dim=0)
            # NOTE(review): the next variable is not read anywhere below
            chunked_num_target_node_need_to_select = num_target_node_need_to_select.split(chunk_size, dim=0)
            chunked_selected_node = selected_node.split([chunk.sum() for chunk in chunked_rel_num_node_selected], dim=0)

            selected_edge = list()
            for i in range(len(chunked_selected_node)):
                current_chunked_selected_node = chunked_selected_node[i]

                # obtain initial mask of edges for each batch
                encoded_selected_node = current_chunked_selected_node[:, 0] * num_node + current_chunked_selected_node[:, 1]
                edge_mask = torch.isin(encoded_edge, encoded_selected_node)

                # select initial edge
                _selected_source_node = expanded_source[edge_mask]
                _selected_target_node = expanded_target[edge_mask]
                _selected_relation = expanded_relation[edge_mask]
                _selected_batch_index = expanded_batch_index[edge_mask]

                # select topk edges with respect to source node: edges are grouped by
                # their (batch, source) key and ranked by the score of their target
                _selected_edge_score = score[_selected_batch_index, _selected_target_node]
                _value, _indice = torch.unique(_selected_batch_index * num_node + _selected_source_node, sorted=True, return_inverse=True)
                _, topk_indices = self._scatter_topk(_selected_edge_score, 
                                                     _indice, 
                                                     num_target_node_need_to_select[_value//num_node])

                _selected_edge = torch.stack([
                    _selected_batch_index[topk_indices],
                    _selected_source_node[topk_indices],
                    _selected_relation[topk_indices],
                    _selected_target_node[topk_indices]
                ], dim=1)
                selected_edge.append(_selected_edge)

            selected_edge = torch.cat(selected_edge, dim=0)
            selected_batch_index, selected_source_node, selected_relation, selected_target_node = \
                map(lambda x: x.squeeze(-1), selected_edge.chunk(4,  dim=1))
            # fold the batch id into the node / relation ids
            selected_edge_index = torch.stack([
                selected_batch_index * num_node + selected_source_node,
                selected_batch_index * num_relation + selected_relation,
                selected_batch_index * num_node + selected_target_node,
            ], dim=1)

            # reindex for source and target: `value` holds the sorted unique endpoint keys,
            # `indice` the position of every (source, target, source, target, ...) entry in it
            value, indice = torch.unique(selected_edge_index[:, [0, 2]].view(-1), sorted=True, return_inverse=True)
            previous_node_to_indice = torch.searchsorted(value, node[:, 0] * num_node + node[:, 1])
            # NOTE(review): this only tests that the insertion position is in range, not that
            # the key is actually present in `value` -- confirm this is intended
            previous_x_mask = previous_node_to_indice < value.size(0)
            previous_node_to_indice = previous_node_to_indice[previous_x_mask]
            update_node_to_indice = torch.searchsorted(value, selected_batch_index * num_node + selected_target_node)

            # get input of this layer, first we get input from text representation, then update the previous visited node representation
            layer_input = torch.cat([x_node[value % num_node], torch.zeros_like(x_node[value % num_node])], dim=1)
            layer_input = self.input_layer(layer_input)
            layer_input[previous_node_to_indice] = x[previous_x_mask]

            # scaling input by score
            layer_input = F.sigmoid(score[value // num_node, value % num_node]).unsqueeze(1) * layer_input
            # update edge_index: even entries of `indice` are sources, odd entries are targets
            selected_edge_index[:, 0], selected_edge_index[:, 2] = indice[::2], indice[1::2]
            layer_out = layer(
                layer_input, 
                x_relation, 
                selected_edge_index
            )

            # add the previous representation of already-held nodes to the layer output
            layer_out[previous_node_to_indice] = layer_out[previous_node_to_indice] + x[previous_x_mask]
            x = layer_out
            # refresh scores of the reached targets and mark every subgraph node as visited
            score[selected_batch_index, selected_target_node] = self.score_layer(x[update_node_to_indice]).squeeze(1)
            visit_mask[value // num_node, value % num_node] = True
            node = torch.stack([value // num_node, value % num_node], dim=1)

        return score
    

class CoSTGNNLightningModule(pl.LightningModule):
    """Lightning module that trains and evaluates a ``GNNReasonModel``.

    The module owns the model, the ranking metrics (MR, MRR, Hits@k) and the
    lazily loaded embedding tensors. Embedding / pseudo-fact files are not
    read in ``__init__``; their paths are registered through the
    ``set_*_path`` methods and the tensors are loaded in the ``on_*_start``
    hooks, once ``self.device`` is known.

    Two training modes are selected by ``pretrain``:

    * ``pretrain=True``: each query is trained against its gold tail plus
      sampled negatives.
    * ``pretrain=False``: in addition to the gold tail, one pseudo tail is
      sampled per query from ``self.pseudo_fact`` and trained as a positive.
    """

    # Order of the values returned by ``_compute_metrics``.
    _METRIC_NAMES = ('mr', 'mrr', 'hits1', 'hits3', 'hits10', 'hits50', 'hits100')
    # Metrics that are also shown in the progress bar.
    _PROG_BAR_METRICS = frozenset({'mr', 'mrr', 'hits1', 'hits10'})

    def __init__(self,
                 config: DictConfig,
                 model: GNNReasonModel,
                 pretrain: bool):
        """Store the config/model and create the metric accumulators.

        Args:
            config: Experiment configuration; this class reads
                ``config.gnn.{lr, weight_decay, num_negative_example,
                adversarial_temperature}``.
            model: The reasoning model that produces per-node scores.
            pretrain: Selects the training branch (see class docstring).
        """
        super().__init__()
        self.config = config
        self.model = model
        self.pretrain = pretrain

        # Filled in by the set_*_path methods; consumed by the on_*_start hooks.
        self._node_embedding_path = None
        self._relation_embedding_path = None
        self._inductive_node_embedding_path = None
        self._inductive_relation_embedding_path = None
        self._pseudo_fact_path = None

        self.mr_fn = MRMetric()
        self.mrr_fn = MRRMetric()
        self.hits1_fn = HitsMetric(topk=1)
        self.hits3_fn = HitsMetric(topk=3)
        self.hits10_fn = HitsMetric(topk=10)
        self.hits50_fn = HitsMetric(topk=50)
        self.hits100_fn = HitsMetric(topk=100)

    @classmethod
    def from_config(cls, config: DictConfig, pretrain: bool):
        """Build the ``GNNReasonModel`` from *config* and wrap it."""
        model = GNNReasonModel.from_config(config)
        return cls(config, model, pretrain)

    # ------------------------------------------------------------------
    # Path registration
    # ------------------------------------------------------------------
    def set_node_embedding_path(self, path):
        """Register the file holding the transductive node embeddings."""
        self._node_embedding_path = path

    def set_relation_embedding_path(self, path):
        """Register the file holding the transductive relation embeddings."""
        self._relation_embedding_path = path

    def set_inductive_node_embedding_path(self, path):
        """Register the file holding the inductive node embeddings."""
        self._inductive_node_embedding_path = path

    def set_inductive_relation_embedding_path(self, path):
        """Register the file holding the inductive relation embeddings."""
        self._inductive_relation_embedding_path = path

    def set_pseudo_fact_path(self, path):
        """Register the file holding the pseudo facts used when not pretraining."""
        self._pseudo_fact_path = path

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _load_on_cpu(path):
        """Deserialize *path* with ``torch.load`` without moving storages.

        The ``map_location`` callable returns each storage unchanged, so the
        loaded object stays on the CPU until the caller moves it.

        NOTE(review): ``torch.load`` unpickles the file; the registered paths
        are assumed to point at trusted artifacts.
        """
        return torch.load(path, map_location=lambda storage, loc: storage)

    def _ensure_transductive_embeddings(self):
        """Load node and relation embeddings if ``node_embedding`` is unset."""
        if getattr(self, 'node_embedding', None) is None:
            self.node_embedding = self._load_on_cpu(self._node_embedding_path).to(self.device)
            self.relation_embedding = self._load_on_cpu(self._relation_embedding_path).to(self.device)

    def _ensure_inductive_embeddings(self):
        """Load inductive embeddings if ``inductive_node_embedding`` is unset."""
        if getattr(self, 'inductive_node_embedding', None) is None:
            self.inductive_node_embedding = self._load_on_cpu(
                self._inductive_node_embedding_path
            ).to(self.device)
            self.inductive_relation_embedding = self._load_on_cpu(
                self._inductive_relation_embedding_path
            ).to(self.device)

    def _ensure_eval_embeddings(self, is_inductive):
        """Load whichever embedding pair the evaluation split requires."""
        if is_inductive:
            self._ensure_inductive_embeddings()
        else:
            self._ensure_transductive_embeddings()

    def _sample_negative_index(self, negative_prob):
        """Draw negative node indices row-wise from *negative_prob*.

        The number of draws per row is ``2 ** config.gnn.num_negative_example``
        capped at the number of columns; sampling is with replacement.
        """
        num_sample = min(negative_prob.size(1), 2 ** self.config.gnn.num_negative_example)
        return torch.multinomial(negative_prob, num_sample, replacement=True)

    def _self_adversarial_loss(self, logits):
        """Weighted BCE where column 0 is the positive and the rest negatives.

        Negatives are weighted by a softmax over their own (detached) logits
        divided by ``config.gnn.adversarial_temperature``; the positive has
        weight 1. The per-row weighted mean is averaged over rows.
        """
        target = torch.zeros_like(logits)
        target[:, 0] = 1
        loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        weights = torch.ones_like(logits)
        with torch.no_grad():
            weights[:, 1:] = F.softmax(
                logits[:, 1:].detach() / self.config.gnn.adversarial_temperature, dim=-1
            )
        loss = (loss * weights).sum(1) / weights.sum(1)
        return loss.mean()

    # ------------------------------------------------------------------
    # Optimisation
    # ------------------------------------------------------------------
    def configure_optimizers(self):
        """Create AdamW with two parameter groups and a step LR schedule.

        Trainable parameters whose name contains one of the ``no_decay``
        substrings get ``weight_decay=0``; all other trainable parameters use
        ``config.gnn.weight_decay``. The learning rate is multiplied by 0.1
        at epochs 10 and 15.
        """
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        grouped_optimizer_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(d in n for d in no_decay) and p.requires_grad],
                'weight_decay': 0.0
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(d in n for d in no_decay) and p.requires_grad],
                'weight_decay': self.config.gnn.weight_decay
            }
        ]
        optimizer = torch.optim.AdamW(
            grouped_optimizer_parameters,
            lr=self.config.gnn.lr,
        )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 15], gamma=0.1)
        scheduler = {
            'scheduler': scheduler,
            'interval': 'epoch',
            'frequency': 1
        }

        return [optimizer], [scheduler]

    def on_fit_start(self):
        """Load training tensors whose path was registered and that are unset.

        Unlike the evaluation hooks, each tensor is checked independently and
        is skipped when no path was registered. ``pseudo_fact`` is left on
        the CPU.
        """
        if getattr(self, 'node_embedding', None) is None and self._node_embedding_path is not None:
            self.node_embedding = self._load_on_cpu(self._node_embedding_path).to(self.device)
        if getattr(self, 'relation_embedding', None) is None and self._relation_embedding_path is not None:
            self.relation_embedding = self._load_on_cpu(self._relation_embedding_path).to(self.device)
        if getattr(self, 'pseudo_fact', None) is None and self._pseudo_fact_path is not None:
            self.pseudo_fact = self._load_on_cpu(self._pseudo_fact_path)

    def training_step(self, batch):
        """Compute the self-adversarial ranking loss for one batch.

        Args:
            batch: Object exposing ``h``, ``r``, ``t``, ``edge_index`` and
                ``filter_mask`` (one row per query, one column per node;
                non-zero entries are excluded from negative sampling).

        Returns:
            Scalar loss tensor. ``loss`` and peak CUDA memory (GiB) are also
            logged to the progress bar.
        """
        x_node, x_relation = self.node_embedding, self.relation_embedding
        query_triple = torch.stack([batch.h, batch.r, batch.t], dim=1)

        if self.pretrain:
            # Negatives are drawn *before* the forward pass; keep this order so
            # the RNG stream is consumed exactly as before.
            negative_index = self._sample_negative_index(1 - batch.filter_mask.float())

            scores = self.model(batch.h,
                                batch.r,
                                x_node,
                                x_relation,
                                batch.edge_index,
                                query_triple)

            # Column 0 is the gold tail, remaining columns are negatives.
            index = torch.cat([batch.t.unsqueeze(1), negative_index], dim=1)
            logits = torch.gather(scores, 1, index)
            loss = self._self_adversarial_loss(logits)
        else:
            scores = self.model(batch.h,
                                batch.r,
                                x_node,
                                x_relation,
                                batch.edge_index,
                                query_triple)

            # Build, per query, the set of tails considered positive: every
            # pseudo-fact tail for the same (h, r) plus the gold tail.
            # pseudo_fact columns are read as (h, r, t), matching what
            # predict_step emits.
            num_node = batch.filter_mask.size(1)
            pseudo_target = torch.zeros_like(batch.filter_mask)
            # Loop-invariant: one integer key per pseudo fact.
            pseudo_key = self.pseudo_fact[:, 1] * num_node + self.pseudo_fact[:, 0]
            for i, (h, r) in enumerate(zip(batch.h, batch.r)):
                key = r.item() * num_node + h.item()
                pseudo_target[i, self.pseudo_fact[pseudo_key == key, 2]] = 1
                pseudo_target[i, batch.t[i].item()] = 1

            # One pseudo tail per query, trained alongside the gold tail.
            pseudo_t = torch.multinomial(pseudo_target.float(), 1).to(self.device)
            t = torch.cat([batch.t.unsqueeze(1), pseudo_t], dim=1)
            num_positive = t.size(1)

            # Never sample a pseudo/gold/filtered tail as a negative.
            pseudo_target = pseudo_target.to(self.device).bool() | batch.filter_mask.bool()
            negative_prob = 1 - pseudo_target.float()
            # Each query row is duplicated once per positive (gold, pseudo).
            negative_prob = einops.repeat(negative_prob, 'b n -> (b m) n', m=num_positive)
            negative_index = self._sample_negative_index(negative_prob)

            index = torch.cat([einops.rearrange(t, 'b (m n) -> (b m) n', n=1), negative_index], dim=1)

            scores = einops.repeat(scores, 'b n -> (b m) n', m=num_positive)
            logits = torch.gather(scores, 1, index)
            loss = self._self_adversarial_loss(logits)

        self.log('loss', loss.detach(), prog_bar=True)
        self.log('memory', torch.cuda.max_memory_allocated()/(1024**3), prog_bar=True)
        return loss

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    def _evaluate(self, batch, mode='val'):
        """Score one batch, compute filtered ranks and update all metrics.

        Args:
            batch: Object exposing ``h``, ``r``, ``t``, ``edge_index``,
                ``filter_mask`` and ``negative_target_node``.
            mode: Prefix of the trainer attribute ``{mode}_dataloaders`` whose
                dataset decides between inductive and transductive embeddings.

        Returns:
            ``(all_scores, all_ranks)`` for the batch.
        """
        is_inductive = getattr(self.trainer, f'{mode}_dataloaders').dataset.split_data_dict['inductive']
        if is_inductive:
            x_node, x_relation = self.inductive_node_embedding, self.inductive_relation_embedding
        else:
            x_node, x_relation = self.node_embedding, self.relation_embedding

        all_scores = self.model(batch.h,
                                batch.r,
                                x_node,
                                x_relation,
                                batch.edge_index)

        filter_mask = batch.filter_mask
        if batch.negative_target_node is not None:
            # Explicit candidate list: filter everything except those nodes.
            filter_mask = torch.ones_like(filter_mask)
            filter_mask = filter_mask.scatter(1, batch.negative_target_node, 0).bool()

        # compute ranks for each answer nodes
        answer_nodes = torch.stack([torch.arange(batch.h.size(0), device=self.device), batch.t], dim=1)
        answer_scores = all_scores[answer_nodes[:, 0], answer_nodes[:, 1]]
        expanded_filter_mask = filter_mask[answer_nodes[:, 0]].bool()
        batch_all_scores = all_scores[answer_nodes[:, 0]]
        # rank = 1 + number of unfiltered nodes scoring at least as high as the answer
        all_ranks = torch.sum((batch_all_scores >= answer_scores.unsqueeze(1)) & (~expanded_filter_mask), dim=1) + 1

        self.mr_fn.update(all_ranks)
        self.mrr_fn.update(all_ranks)
        self.hits1_fn.update(all_ranks)
        self.hits3_fn.update(all_ranks)
        self.hits10_fn.update(all_ranks)
        self.hits50_fn.update(all_ranks)
        self.hits100_fn.update(all_ranks)

        return all_scores, all_ranks

    def _compute_metrics(self, mode='valid'):
        """Finalize, reset and log every metric under the ``{mode}_`` prefix.

        Returns:
            Tuple ``(mr, mrr, hits1, hits3, hits10, hits50, hits100)``.
        """
        metric_fns = (
            self.mr_fn,
            self.mrr_fn,
            self.hits1_fn,
            self.hits3_fn,
            self.hits10_fn,
            self.hits50_fn,
            self.hits100_fn,
        )
        # Compute everything first, then reset, so the accumulators are clean
        # for the next epoch regardless of what logging does.
        values = tuple(fn.compute() for fn in metric_fns)
        for fn in metric_fns:
            fn.reset()

        # Each value is logged under its own name (hits50 -> hits50, etc.).
        for name, value in zip(self._METRIC_NAMES, values):
            self.log(f'{mode}_{name}', value, prog_bar=name in self._PROG_BAR_METRICS, sync_dist=True)

        return values

    def on_validation_start(self):
        """Make sure the embeddings needed by the validation split are loaded."""
        is_inductive = self.trainer.val_dataloaders.dataset.split_data_dict['inductive']
        self._ensure_eval_embeddings(is_inductive)

    def validation_step(self, batch):
        """Accumulate ranking metrics for one validation batch."""
        self._evaluate(batch, mode='val')

    def on_validation_epoch_end(self):
        """Log validation metrics and print them as a table on rank zero."""
        from tabulate import tabulate
        values = self._compute_metrics(mode='valid')
        metrics = {name: [value] for name, value in zip(self._METRIC_NAMES, values)}
        info(f'Valid Metircs at Epoch {self.trainer.current_epoch}: \n' + tabulate(metrics, headers='keys', tablefmt='grid'))

    def on_test_start(self):
        """Make sure the embeddings needed by the test split are loaded."""
        is_inductive = self.trainer.test_dataloaders.dataset.split_data_dict['inductive']
        self._ensure_eval_embeddings(is_inductive)

    def test_step(self, batch):
        """Accumulate ranking metrics for one test batch."""
        self._evaluate(batch, mode='test')

    def on_test_epoch_end(self):
        """Log test metrics and print them as a table on rank zero."""
        from tabulate import tabulate
        values = self._compute_metrics(mode='test')
        metrics = {name: [value] for name, value in zip(self._METRIC_NAMES, values)}
        info('Test Metircs: \n' + tabulate(metrics, headers='keys', tablefmt='grid'))

    # ------------------------------------------------------------------
    # Pseudo-fact prediction
    # ------------------------------------------------------------------
    def on_predict_start(self):
        """Load the transductive embeddings if they are not loaded yet."""
        self._ensure_transductive_embeddings()

    def predict_step(self, batch):
        """Sample pseudo facts ``(h, r, t)`` from the model's predictions.

        Args:
            batch: Tuple ``(h, r, edge_index)``.

        Returns:
            Long tensor with one row ``(h, r, t)`` per accepted pseudo fact.
        """
        x_node, x_relation = self.node_embedding, self.relation_embedding

        h, r, edge_index = batch

        all_scores = self.model(h,
                                r,
                                x_node,
                                x_relation,
                                edge_index)

        all_scores = F.softmax(all_scores, dim=-1)
        # A node is kept only if it is drawn by the Bernoulli sample AND its
        # probability exceeds 0.75. Rows of a softmax sum to 1, so at most one
        # node per query can pass the threshold.
        pseudo_target = torch.bernoulli(all_scores).long()
        pseudo_target = pseudo_target * (all_scores > 0.75)

        query_index = torch.stack([h, r], dim=1)
        pseudo_t = pseudo_target.nonzero()  # rows of (query row, node index)
        pseudo_fact = torch.cat([query_index[pseudo_t[:, 0]], pseudo_t[:, 1:]], dim=1)

        return pseudo_fact