import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import sys
from ginpool.mlp import MLP

sys.path.append("..")
from src.global_pooling_layers import TemporalGlobalPoolingLayer

class GraphCNN(nn.Module):
    """GIN-style graph classifier with a configurable global ("special") readout.

    Node features pass through ``num_layers - 1`` neighbourhood-aggregation
    layers (aggregate -> MLP -> BatchNorm -> ReLU). A linear projection of the
    input features plus every layer's output are stacked along a layer axis,
    padded per graph, reduced to one vector per graph according to
    ``special_pooling_type`` and fed to a linear prediction head.
    """

    def __init__(self, args):
        """Build the model from a namespace-like ``args``.

        Attributes read from ``args``: final_dropout, num_layers,
        graph_pooling_type, neighbor_pooling_type ("max", "average"; any other
        value means sum), learn_eps, num_mlp_layers, input_dim, hidden_dim,
        special_pooling_type, output_dim.

        NOTE: when ``special_pooling_type == "temporal"`` this sets
        ``args.in_dim`` and ``args.out_dim`` on the caller's ``args`` object
        before constructing TemporalGlobalPoolingLayer.
        """
        super(GraphCNN, self).__init__()

        self.final_dropout = args.final_dropout
        self.num_layers = args.num_layers
        self.graph_pooling_type = args.graph_pooling_type  # stored; not read in this class
        self.neighbor_pooling_type = args.neighbor_pooling_type
        self.learn_eps = args.learn_eps

        # One epsilon per aggregation layer; only used by next_layer_eps.
        self.eps = nn.Parameter(torch.zeros(self.num_layers - 1))

        # MLP of each aggregation layer.
        self.mlps = torch.nn.ModuleList()
        # BatchNorm applied to each MLP output.
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            in_dim = args.input_dim if layer == 0 else args.hidden_dim
            self.mlps.append(MLP(args.num_mlp_layers, in_dim, args.hidden_dim, args.hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(args.hidden_dim))

        # Global readout over the stacked per-layer node embeddings.
        self.hidden_dim = args.hidden_dim
        self.special_pooling_type = args.special_pooling_type
        if self.special_pooling_type == "temporal":
            # Mutates the caller's args so the pooling layer maps hidden_dim -> hidden_dim.
            args.in_dim = args.hidden_dim
            args.out_dim = args.hidden_dim
            self.global_pooling_layer = TemporalGlobalPoolingLayer(args)
        else:
            # No module is built here; "supra" needs one assigned to
            # self.global_pooling_layer before forward is called.
            self.global_pooling_layer = None
        self.prediction = nn.Linear(args.hidden_dim, args.output_dim)
        # Projects raw input features to hidden_dim so they can be stacked
        # with the aggregation layers' outputs in forward.
        self.linear_layer = nn.Linear(args.input_dim, args.hidden_dim)

    def __preprocess_neighbors_maxpool(self, batch_graph):
        """Build a padded neighbour-index table for the concatenated batch.

        One row per entry of each graph's ``neighbors`` list, holding the
        neighbours' global (batch-offset) indices padded with ``-1`` up to the
        batch's largest ``max_neighbor``. ``-1`` selects the dummy row that
        ``maxpool`` appends after the real nodes. When ``learn_eps`` is False
        the node's own index is appended so the centre node is max-pooled
        together with its neighbours.

        Returns:
            torch.LongTensor of padded neighbour indices.
        """
        max_deg = max(graph.max_neighbor for graph in batch_graph)

        padded_neighbor_list = []
        offset = 0  # number of nodes in the graphs preceding the current one
        for graph in batch_graph:
            for j, neighbors in enumerate(graph.neighbors):
                pad = [n + offset for n in neighbors]
                pad.extend([-1] * (max_deg - len(pad)))
                if not self.learn_eps:
                    pad.append(j + offset)
                padded_neighbor_list.append(pad)
            offset += len(graph.g)

        return torch.LongTensor(padded_neighbor_list)

    def __preprocess_neighbors_sumavepool(self, batch_graph):
        """Build the block-diagonal sparse adjacency matrix of the batch.

        Each graph's ``edge_mat`` (a 2 x E index tensor) is shifted by the
        number of nodes preceding it and concatenated; every edge gets weight
        1. When ``learn_eps`` is False self-loops are added so centre nodes
        are summed together with their neighbours.

        Returns:
            Sparse COO tensor of shape (N_total, N_total) on ``edge_mat``'s device.
        """
        edge_mat_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            edge_mat_list.append(graph.edge_mat + start_idx[i])
        Adj_block_idx = torch.cat(edge_mat_list, 1)
        device = Adj_block_idx.device
        Adj_block_elem = torch.ones(Adj_block_idx.shape[1], device=device)
        num_node = start_idx[-1]

        if not self.learn_eps:
            self_loop_edge = torch.arange(num_node, dtype=torch.long, device=device).repeat(2, 1)
            elem = torch.ones(num_node, device=device)
            Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
            Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)

        Adj_block = torch.sparse_coo_tensor(
            indices=Adj_block_idx,
            values=Adj_block_elem,
            size=torch.Size([num_node, num_node]),
            device=device
        )

        return Adj_block

    def maxpool(self, h, padded_neighbor_list):
        """Max-pool ``h`` over the rows of ``padded_neighbor_list``.

        Padding index ``-1`` selects an appended dummy row equal to the
        column-wise minimum of ``h``, so it can never exceed a real neighbour.
        """
        dummy = torch.min(h, dim=0)[0]
        h_with_dummy = torch.cat([h, dummy.reshape((1, -1))])
        pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim=1)[0]
        return pooled_rep

    def _aggregate_neighbors(self, h, padded_neighbor_list=None, Adj_block=None):
        """Pool neighbour representations according to ``neighbor_pooling_type``.

        "max" uses ``padded_neighbor_list``; otherwise a sparse sum with
        ``Adj_block`` is taken and, for "average", divided by node degree.
        """
        if self.neighbor_pooling_type == "max":
            return self.maxpool(h, padded_neighbor_list)

        pooled = torch.spmm(Adj_block, h)
        if self.neighbor_pooling_type == "average":
            # Allocate on Adj_block's device/dtype (a bare torch.ones lives on
            # the CPU and breaks CUDA batches) and clamp so isolated nodes
            # (degree 0) yield 0 instead of NaN.
            ones = torch.ones((Adj_block.shape[0], 1), device=Adj_block.device, dtype=Adj_block.dtype)
            degree = torch.spmm(Adj_block, ones).clamp(min=1)
            pooled = pooled / degree
        return pooled

    def next_layer_eps(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        """One aggregation layer pooling neighbours and the centre node separately.

        The centre node is added back with weight ``1 + eps[layer]`` before
        the layer's MLP, BatchNorm and ReLU.
        """
        pooled = self._aggregate_neighbors(h, padded_neighbor_list, Adj_block)
        pooled = pooled + (1 + self.eps[layer]) * h
        pooled_rep = self.mlps[layer](pooled)
        h = self.batch_norms[layer](pooled_rep)
        return F.relu(h)

    def next_layer(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        """One aggregation layer pooling neighbours and the centre node together.

        The centre node is already part of the aggregation (self-loop in
        ``Adj_block`` or own index in ``padded_neighbor_list``).
        """
        pooled = self._aggregate_neighbors(h, padded_neighbor_list, Adj_block)
        pooled_rep = self.mlps[layer](pooled)
        h = self.batch_norms[layer](pooled_rep)
        return F.relu(h)

    def forward(self, batch_graph):
        """Compute scores for a batch of graphs.

        Args:
            batch_graph: sequence of graph objects exposing ``node_features``,
                ``g`` (its ``len`` is the node count), ``edge_mat`` and, for
                "max" neighbour pooling, ``neighbors`` / ``max_neighbor``.

        Returns:
            ``self.prediction`` applied to one representation per graph, with
            dropout (``final_dropout``) applied to the output in training mode.

        Raises:
            ValueError: unknown ``special_pooling_type``, or "supra" with no
                module assigned to ``global_pooling_layer``.
        """
        X_concat = torch.cat([graph.node_features for graph in batch_graph], 0)

        padded_neighbor_list = None
        Adj_block = None
        if self.neighbor_pooling_type == "max":
            padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
        else:
            Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)

        # Entry 0 is a projection of the raw input so all entries share hidden_dim.
        hidden_rep = [self.linear_layer(X_concat)]
        h = X_concat
        step = self.next_layer_eps if self.learn_eps else self.next_layer
        for layer in range(self.num_layers - 1):
            h = step(h, layer, padded_neighbor_list=padded_neighbor_list, Adj_block=Adj_block)
            hidden_rep.append(h)

        hidden_rep = torch.stack(hidden_rep, 1)  # (N_total, T, D)

        graph_sizes = [graph.node_features.size(0) for graph in batch_graph]
        node_embeddings = torch.split(hidden_rep, graph_sizes, dim=0)
        padded_node_embeddings = pad_sequence(node_embeddings, batch_first=True, padding_value=0)  # (B, N_max, T, D)

        lengths = torch.tensor(graph_sizes, device=hidden_rep.device)
        # True for real nodes, False for padding: (B, N_max)
        mask = torch.arange(padded_node_embeddings.size(1), device=lengths.device) < lengths.unsqueeze(1)

        if self.special_pooling_type == "supra":
            if self.global_pooling_layer is None:
                raise ValueError(
                    "special pooling type 'supra' requires a module assigned to global_pooling_layer"
                )
            edge_index_list = [graph.edge_mat for graph in batch_graph]
            graph_representation = self.global_pooling_layer(padded_node_embeddings, edge_index_list, mask)
        elif self.special_pooling_type == "temporal":
            graph_representation = self.global_pooling_layer(padded_node_embeddings, mask)
        elif self.special_pooling_type == "mean":
            last = padded_node_embeddings[:, :, -1]  # final-layer embeddings, (B, N_max, D)
            node_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)  # avoid 0/0 for empty graphs
            graph_representation = (last * mask.unsqueeze(-1)).sum(dim=1) / node_counts
        elif self.special_pooling_type == "max":
            last = padded_node_embeddings[:, :, -1]  # final-layer embeddings, (B, N_max, D)
            # Fill padding with -inf so it can never win the max; zero padding
            # would beat all-negative embeddings.
            pooled = last.masked_fill(~mask.unsqueeze(-1), float("-inf")).max(dim=1).values
            # Graphs without nodes keep a zero representation.
            graph_representation = pooled.masked_fill(~mask.any(dim=1, keepdim=True), 0.0)
        else:
            raise ValueError(f"Invalid special pooling type: {self.special_pooling_type}")

        score = F.dropout(self.prediction(graph_representation), self.final_dropout, training=self.training)

        return score
