import torch
import torch.nn.functional as F

from torch_sparse import SparseTensor

from modules.neighbor_loader import LastNeighborLoader

class WeightedTNCNLinkPred(torch.nn.Module):
    """Temporal link predictor: memory module + GNN + NCN-style link predictor.

    Each ``forward`` call gathers the neighbourhood of the batch's endpoint
    nodes with ``neighbor_loader``, builds sparse adjacency matrices over that
    local subgraph (which ones depends on ``ncn_mode``), embeds the nodes with
    ``memory`` and ``gnn``, and scores the batch edges with ``link_pred``.

    Args:
        memory: Called as ``memory(n_id) -> (z, last_update)``; must also
            provide ``update_state`` and ``reset_parameters``.
        gnn: Called as ``gnn(z, last_update, edge_index, edge_weight, t, msg)``.
        link_pred: Called as ``link_pred(z, adjs, target_edges, ncn_mode)``.
        neighbor_loader: Neighbour sampler called as
            ``neighbor_loader(n_id) -> (n_id, edge_index, e_id)``; must also
            provide ``insert`` and ``cur_e_id``.
        ncn_mode: 0, 1 or 2; selects which adjacency matrices are built.
        hop_num: Number of neighbour-expansion rounds (see ``find_neighbor``).
        device: Device on which index tensors are created.
    """

    def __init__(self, memory, gnn, link_pred, neighbor_loader, ncn_mode, hop_num, device):
        super().__init__()
        self.memory = memory
        self.gnn = gnn
        self.link_pred = link_pred
        self.neighbor_loader = neighbor_loader
        self.ncn_mode = ncn_mode
        self.hop_num = hop_num
        self.device = device

    def forward(self, data, edge_index, edge_weight, assoc):
        """Score the edges in ``edge_index``.

        Args:
            data: Event container; ``data.t`` and ``data.msg`` are indexed by
                the sampled neighbourhood edge ids.
            edge_index: Target edges as global node ids. Row 0 holds sources
                and row 1 holds destinations.
            edge_weight: Passed unchanged to ``gnn``. NOTE(review): it is not
                indexed by the sampled edge ids here. Confirm that the caller
                supplies weights aligned with the sampled neighbourhood edges.
            assoc: Global-to-local node-id lookup tensor. It is modified in
                place: ``assoc[n_id]`` is overwritten with ``0..len(n_id)-1``.

        Returns:
            Whatever ``link_pred`` returns for the batch edges.

        Raises:
            ValueError: If ``ncn_mode`` is not 0, 1 or 2.
        """
        src = edge_index[0]
        dst = edge_index[1]
        device = self.device

        n_id = torch.cat([src, dst]).unique()
        n_id, n_edge_index, n_e_id = WeightedTNCNLinkPred.find_neighbor(
            self.neighbor_loader, n_id, self.hop_num
        )
        # Map global node ids to positions in the local subgraph.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        id_num = n_id.size(0)

        if self.ncn_mode == 0:
            adj_0_1 = WeightedTNCNLinkPred.generate_adj_0_1_hop(id_num, n_edge_index, device)
            adj_1 = WeightedTNCNLinkPred.generate_adj_1_hop(id_num, n_edge_index, device)
            adjs = (adj_0_1, adj_1)
        elif self.ncn_mode == 1:
            # Mode 1 hands link_pred the single SparseTensor itself, not a
            # 1-tuple.
            adjs = WeightedTNCNLinkPred.generate_adj_1_hop(id_num, n_edge_index, device)
        elif self.ncn_mode == 2:
            adj_0_1 = WeightedTNCNLinkPred.generate_adj_0_1_hop(id_num, n_edge_index, device)
            adj_1 = WeightedTNCNLinkPred.generate_adj_1_hop(id_num, n_edge_index, device)
            adj_0_1_2 = WeightedTNCNLinkPred.generate_adj_0_1_2_hop(adj_1)
            adjs = (adj_0_1, adj_1, adj_0_1_2)
        else:
            raise ValueError('Invalid NCN Mode! Mode must be 0, 1, or 2.')

        # Get updated memory of all nodes involved in the computation.
        z, last_update = self.memory(n_id)

        z = self.gnn(
            z,
            last_update,
            n_edge_index,
            edge_weight,
            data.t[n_e_id].to(device),
            data.msg[n_e_id].to(device),
        )

        # Target edges are expressed in local (subgraph) indices.
        target_edges = torch.stack([assoc[src], assoc[dst]])
        return self.link_pred(z, adjs, target_edges, self.ncn_mode)

    def train(self, mode=True):
        """Set training mode on this wrapper and its sub-modules.

        The signature matches ``torch.nn.Module.train``. A parent module's
        ``.train()`` / ``.eval()`` calls ``child.train(mode)``, and
        ``model.train(False)`` must also work.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, for chaining (``model = model.train()``).
        """
        if not mode:
            return self.eval()
        self.training = True
        self.memory.train()
        self.gnn.train()
        self.link_pred.train()
        return self

    def eval(self):
        """Put this wrapper and its sub-modules into evaluation mode.

        Returns:
            ``self``, for chaining.
        """
        self.training = False
        self.memory.eval()
        self.gnn.eval()
        self.link_pred.eval()
        return self

    def reset_parameters(self):
        """Reset the parameters of the memory, GNN and link predictor.

        The neighbour loader is not reset here.
        """
        self.memory.reset_parameters()
        self.gnn.reset_parameters()
        self.link_pred.reset_parameters()

    def insert_neighbor(self, src, dst):
        """Record the edges ``src -> dst`` in the neighbour loader."""
        self.neighbor_loader.insert(src, dst)

    def update_memory(self, src, dst, t, msg, edge_weight=None):
        """Forward an event batch to ``memory.update_state``.

        If ``edge_weight`` is ``None``, a vector of ones with length
        ``len(src)`` is created on ``self.device``.
        """
        if edge_weight is None:
            edge_weight = torch.ones(len(src), device=self.device)
        self.memory.update_state(src, dst, t, msg, edge_weight)

    def get_current_edge_index(self):
        """Return the neighbour loader's ``cur_e_id`` attribute."""
        return self.neighbor_loader.cur_e_id

    @staticmethod
    def _adj_with_loops(loop_nodes, edge_index, device):
        """Build a SparseTensor from self-loops plus ``edge_index`` in both directions.

        Args:
            loop_nodes: 1-D tensor of local node ids that receive a self-loop.
            edge_index: ``(2, E)`` local edge index. Every edge is added both as
                given and reversed.
            device: Target device of the returned SparseTensor.

        No explicit ``sparse_sizes`` is passed and no explicit deduplication is
        done. Repeated (row, col) pairs are handed to SparseTensor as-is.
        """
        loop_edge = torch.stack([loop_nodes, loop_nodes])
        if edge_index.size(1) == 0:
            return SparseTensor.from_edge_index(loop_edge).to_device(device)
        reversed_edges = torch.stack([edge_index[1], edge_index[0]])
        all_edges = torch.cat((loop_edge, edge_index, reversed_edges), dim=-1)
        return SparseTensor.from_edge_index(all_edges).to_device(device)

    @staticmethod
    def generate_adj_1_hop(id_num, edge_index, device):
        """Symmetric adjacency over ``id_num`` local nodes.

        The result contains ``edge_index`` in both directions. Nodes that appear
        in no edge get a self-loop, so they still occupy a row and column.
        """
        loop_nodes = torch.arange(id_num, dtype=torch.int64, device=device)
        loop_nodes = loop_nodes[~torch.isin(loop_nodes, edge_index)]
        return WeightedTNCNLinkPred._adj_with_loops(loop_nodes, edge_index, device)

    @staticmethod
    def generate_adj_0_1_hop(id_num, edge_index, device):
        """Symmetric adjacency over ``id_num`` local nodes with a self-loop on every node."""
        loop_nodes = torch.arange(id_num, dtype=torch.int64, device=device)
        return WeightedTNCNLinkPred._adj_with_loops(loop_nodes, edge_index, device)

    @staticmethod
    def generate_adj_0_1_2_hop(adj):
        """Return the sparse product ``adj @ adj``.

        The call site passes the output of ``generate_adj_1_hop``. Entries of
        the product correspond to length-2 walks in that graph. NOTE(review):
        direct 1-hop pairs appear only where a length-2 walk connects them.
        Confirm this matches the intended "0-1-2 hop" semantics.
        """
        return adj.matmul(adj)

    @staticmethod
    def find_neighbor(neighbor_loader: "LastNeighborLoader", n_id, k=1):
        """Expand ``n_id`` through ``neighbor_loader`` ``k`` times.

        Each round's returned node ids feed the next call. The final call's
        ``(n_id, edge_index, e_id)`` triple is returned. For ``k <= 1`` exactly
        one call is made.
        """
        for _ in range(k - 1):
            n_id, _, _ = neighbor_loader(n_id)
        return neighbor_loader(n_id)

