import torch
import torch.nn as nn
import torch.nn.functional as F
from ..._layers import GIN_layer

class GIN(nn.Module):
    """GIN (Graph Isomorphism Network)-style graph classifier.

    A batch of graphs is merged into one block-diagonal sparse adjacency matrix and passed
    through ``num_layers - 1`` GIN_layer modules. The node representations of every layer
    (including the raw input features) are pooled into per-graph vectors. Each pooled
    vector goes through its own linear readout, and the dropped-out readout scores are summed.
    """

    def __init__(self, num_layers: int, num_mlp_layers: int, input_dim: int, hidden_dim: int, output_dim: int,
                 *args, **kwargs):
        """
        Inputs:
            num_layers:     [int] number of layers in the neural network (INCLUDING the input layer);
                            num_layers - 1 GIN layers and num_layers linear readouts are created
            num_mlp_layers: [int] number of layers in the MLPs (EXCLUDING the input layer)
            input_dim:      [int] dimensionality of input node features
            hidden_dim:     [int] dimensionality of hidden units at ALL layers
            output_dim:     [int] number of classes for prediction
            *args:          accepted for compatibility and ignored
            **kwargs
                final_dropout: dropout ratio applied to every layer's prediction scores (default 0.5)
                learn_eps: if False (default), self-loops are added to the adjacency so center nodes
                    are aggregated together with their neighbors; if True, no self-loops are added.
                neighbor_pooling_type: how to aggregate neighbors: 'sum' (default), 'average',
                    or 'max' (not implemented; forward raises NotImplementedError).
                    Any other value behaves like 'sum'.
                graph_pooling_type: how to aggregate all nodes of a graph: 'sum' (default) or
                    'average'. Any other value behaves like 'sum'.
                device: device on which forward builds batch tensors (default 'cuda' when
                    available, otherwise 'cpu'). The module's own parameters are NOT moved here.
        """
        super(GIN, self).__init__()

        self.num_layers = num_layers
        self.num_mlp_layers = num_mlp_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.final_dropout = kwargs.get('final_dropout', 0.5)
        # NOTE(review): learn_eps only controls whether self-loops are added below; it is not
        # passed to GIN_layer, so whether an epsilon is actually learned depends on GIN_layer.
        self.learn_eps = kwargs.get('learn_eps', False)
        self.graph_pooling_type = kwargs.get('graph_pooling_type', 'sum')
        self.neighbor_pooling_type = kwargs.get('neighbor_pooling_type', 'sum')
        # Fall back to CPU instead of failing later on machines without CUDA.
        self.device = kwargs.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')

        # GIN layers: the first maps input_dim -> hidden_dim, the rest hidden_dim -> hidden_dim.
        # Modules are created in the same order as before so random initialization is unchanged.
        self.gins = nn.ModuleList()
        for layer in range(self.num_layers - 1):
            in_dim = self.input_dim if layer == 0 else self.hidden_dim
            self.gins.append(GIN_layer(self.num_mlp_layers, in_dim, self.hidden_dim, self.hidden_dim,
                                       is_batch_norm=True))

        # One linear readout per representation (input features + each GIN layer output).
        self.predicts = nn.ModuleList()
        for layer in range(self.num_layers):
            in_dim = self.input_dim if layer == 0 else self.hidden_dim
            self.predicts.append(nn.Linear(in_dim, self.output_dim))

    def _create_batch_adj_sumave(self, batch_graph):
        """
        Build the block-diagonal sparse adjacency matrix of the whole batch.

        Inputs:
            batch_graph [list]: a list of graphs. For each graph, len(graph.g) is taken as its
                node count, and graph.edge_mat is a 2 x num_edges index tensor of local node ids.
        Outputs:
            batch_adj [torch.tensor]: sparse float tensor of shape (batch_num_node, batch_num_node)
                on self.device.
                - When learn_eps is False, self-loops are included.
                - When neighbor_pooling_type == 'average', every entry of a row equals
                  1 / (number of entries in that row), so batch_adj @ h averages instead of sums.
                - Rows with no entries stay all-zero.
        """
        edge_blocks = []
        offset = 0
        for graph in batch_graph:
            # Shift local node ids so each graph occupies its own diagonal block.
            edge_blocks.append(graph.edge_mat + offset)
            offset += len(graph.g)
        batch_num_node = offset

        # add self-loops
        if not self.learn_eps:
            self_loops = torch.arange(batch_num_node)
            edge_blocks.append(torch.stack([self_loops, self_loops]))

        if edge_blocks:
            batch_adj_edges = torch.cat(edge_blocks, 1)
        else:
            batch_adj_edges = torch.empty((2, 0), dtype=torch.long)

        if self.neighbor_pooling_type == 'average':
            rows = batch_adj_edges[0]
            # Clamp so rows without entries do not divide by zero; they contribute nothing anyway.
            degree = torch.bincount(rows, minlength=batch_num_node).clamp(min=1)
            batch_adj_elems = 1.0 / degree[rows].float()
        else:
            batch_adj_elems = torch.ones(batch_adj_edges.shape[1])

        batch_adj = torch.sparse_coo_tensor(batch_adj_edges, batch_adj_elems,
                                            (batch_num_node, batch_num_node))
        return batch_adj.to(self.device)

    def _create_batch_pool_sumave(self, batch_graph):
        """
        Create the sum- or average-pooling sparse matrix over all nodes of each graph.

        Inputs:
            batch_graph: [list] a list of graphs; len(graph.g) is taken as each graph's node count
        Outputs:
            graph_pool: [torch.tensor] sparse float tensor of shape (num_graphs, batch_num_nodes)
                on self.device.
                - Row i has a nonzero entry for every node of graph i: 1 for 'sum' pooling,
                  1 / node_count for 'average' pooling.
                - Graphs with zero nodes yield an all-zero row.
        """
        node_counts = torch.tensor([len(graph.g) for graph in batch_graph], dtype=torch.long)
        num_graphs = len(batch_graph)
        batch_num_node = int(node_counts.sum())

        # Node j belongs to graph rows[j]; nodes are laid out graph after graph.
        rows = torch.repeat_interleave(torch.arange(num_graphs), node_counts)
        cols = torch.arange(batch_num_node)

        if self.graph_pooling_type == "average":
            # Divide in float64 then cast, matching Python-float 1/n; clamp avoids 1/0 for empty graphs.
            batch_pool_elems = (1.0 / node_counts.clamp(min=1).double())[rows].float()
        else:
            batch_pool_elems = torch.ones(batch_num_node)

        graph_pool = torch.sparse_coo_tensor(torch.stack([rows, cols]), batch_pool_elems,
                                             (num_graphs, batch_num_node))
        return graph_pool.to(self.device)

    def forward(self, batch_graph):
        """
        Inputs:
            batch_graph [list]: a list of graphs, each with node_features, g and edge_mat
        Outputs:
            batch_predict [torch.tensor]: num_graphs x output_dim scores, summed over the
                readouts of the input features and of every GIN layer
        Raises:
            NotImplementedError: if neighbor_pooling_type is 'max'
        """
        if self.neighbor_pooling_type == 'max':
            # Previously this fell through and crashed with an unbound 'batch_adj'.
            raise NotImplementedError("neighbor_pooling_type='max' is not supported by this GIN implementation")

        batch_node_features = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
        batch_adj = self._create_batch_adj_sumave(batch_graph)

        h_trans = batch_node_features
        h_trans_list = [batch_node_features]
        for gin in self.gins:
            h_trans = gin(batch_adj, h_trans)
            h_trans_list.append(h_trans)

        batch_pool = self._create_batch_pool_sumave(batch_graph)
        batch_predict = 0
        for predict, h_layer in zip(self.predicts, h_trans_list):
            batch_graph_h = torch.spmm(batch_pool, h_layer)
            batch_predict += F.dropout(predict(batch_graph_h), self.final_dropout, training=self.training)

        return batch_predict

