import torch
import torch.nn as nn
import torch.nn.functional as F
from ..._layers import GIN_layer
from ..._utils import debug_print

# Module-level prefix for the debug-checkpoint labels built by GIN_lite (see its `debug_ckpt`).
global_debug_ckpt: str = '_GIN_lite'


class GIN_lite(nn.Module):
    """
    Lightweight Graph Isomorphism Network for graph-level prediction.

    ``num_layers - 1`` GIN_layer modules are stacked on the node features. The node
    representations produced by every GIN layer (and, if ``add_input_score`` is set, the raw
    input features) are pooled per graph, mapped to ``output_dim`` by a dedicated linear head,
    passed through dropout and summed into the final score.
    """

    # Pooling-type names that select averaging; any other value means plain summation.
    _AVERAGE_POOLING = ('average', 'mean')

    def __init__(self, num_layers: int, num_mlp_layers: int, input_dim: int, hidden_dim: int, output_dim: int,
                 *args, **kwargs):
        """
        Inputs:
            num_layers:     [int] number of layers in the neural networks (INCLUDING the input layer)
            num_mlp_layers: [int] number of layers in mlps (EXCLUDING the input layer)
            input_dim:      [int] dimensionality of input features
            hidden_dim:     [int] dimensionality of hidden units at ALL layers
            output_dim:     [int] number of classes for prediction
            *args:          accepted for call-site compatibility; ignored
            **kwargs
                final_dropout:         [float] dropout ratio applied to every per-layer score (default 0.5)
                neighbor_pooling_type: [str] how to aggregate neighbors: 'sum' (default) or
                                       'average' / 'mean'; any other value (e.g. 'max', which is not
                                       implemented) falls back to sum
                graph_pooling_type:    [str] how to aggregate all nodes of a graph: 'sum' (default) or
                                       'average' / 'mean'; any other value falls back to sum
                device:                device the pooling matrix is built on before the first forward
                                       call (default 'cuda'); forward() overwrites it with the device
                                       of its adjacency input
                add_input_score:       [bool] if scores from input layer will be added to final scores
                                       (default True)
        """
        super().__init__()

        self.num_layers = num_layers
        self.num_mlp_layers = num_mlp_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # NOTE(review): `tau` is never read in this file — confirm whether external code relies on it.
        self.tau = 1.

        self.final_dropout = kwargs.get('final_dropout', 0.5)
        self.graph_pooling_type = kwargs.get('graph_pooling_type', 'sum')
        self.neighbor_pooling_type = kwargs.get('neighbor_pooling_type', 'sum')
        self.device = kwargs.get('device', 'cuda')
        self.add_input_score = kwargs.get('add_input_score', True)

        # GIN layer 0 consumes the raw features; the remaining layers map hidden -> hidden.
        self.gins = nn.ModuleList(
            GIN_layer(self.num_mlp_layers, self.input_dim if layer == 0 else self.hidden_dim,
                      self.hidden_dim, self.hidden_dim, is_batch_norm=True)
            for layer in range(self.num_layers - 1)
        )

        # One linear head per pooled representation, in the order forward() collects them:
        # the optional input-feature head first, then one head per GIN layer.
        self.predicts = nn.ModuleList()
        if self.add_input_score:
            self.predicts.append(nn.Linear(self.input_dim, self.output_dim))
        for _ in range(self.num_layers - 1):
            self.predicts.append(nn.Linear(self.hidden_dim, self.output_dim))

        self.debug_ckpt = global_debug_ckpt + '.GIN_lite'

    def _create_batch_pool_sumave(self, num_nodes_, device=None, dtype=None):
        """
        create sum or average pooling sparse matrix over entire nodes in each graph (num graphs x num nodes)
        inputs:
            num_nodes_: [np.ndarray] number of nodes in each graph (in the batch); any 1-D integer
                        sequence accepted by torch.as_tensor works
            device:     device of the result; defaults to self.device
            dtype:      floating dtype of the result; defaults to torch.get_default_dtype()
        outputs:
            batch_pool: [torch.tensor] a block diagonal pooling mat, size = (num_graphs, tot_nodes), a sparse
                        (coalesced COO) torch tensor. Row g holds 1 (sum pooling) or 1/num_nodes_[g]
                        (average pooling) at the columns of graph g's nodes; a graph with no nodes
                        yields an all-zero row.
        """
        device = self.device if device is None else device
        dtype = torch.get_default_dtype() if dtype is None else dtype

        counts = torch.as_tensor(num_nodes_, dtype=torch.long, device=device).reshape(-1)
        num_graphs = counts.numel()
        total_nodes = int(counts.sum())

        # Nodes are laid out graph after graph (block-diagonal batch), so node j belongs to graph rows[j].
        rows = torch.repeat_interleave(torch.arange(num_graphs, device=device), counts)
        cols = torch.arange(total_nodes, device=device)
        if self.graph_pooling_type in self._AVERAGE_POOLING:
            # Empty graphs are repeated zero times, so their infinite reciprocal never reaches the values.
            values = torch.repeat_interleave(counts.to(dtype).reciprocal(), counts)
        else:
            values = torch.ones(total_nodes, dtype=dtype, device=device)

        indices = torch.stack((rows, cols))
        return torch.sparse_coo_tensor(indices, values, (num_graphs, total_nodes)).coalesce()

    def forward(self, batch_adj, batch_fts, num_nodes_):
        """
        Inputs:
            batch_adj   : [torch.tensor] a sparse binary adj, arranging graphs in the batch on the main diagonal
            batch_fts   : [torch.tensor] all node features in the batch
            num_nodes_  : [np.ndarray] number of nodes in each graph (in the batch)
        Outputs:
            score       : [torch.tensor] prediction of graph categories, size = (num_graphs, dim_class);
                          the sum of the dropout-regularised per-layer scores (whether the input-feature
                          score is included is chosen by the constructor's ``add_input_score``).
                          Remains the int 0 if the model has no prediction heads.
            h           : [torch.tensor] node representations from the last GIN layer
                          (``batch_fts`` itself when there are no GIN layers)
        Side effect:
            sets ``self.device`` to ``batch_adj.device``.
        """
        if self.neighbor_pooling_type in self._AVERAGE_POOLING:
            # torch.sparse.softmax ignores unspecified entries (equivalent to -inf), so on a binary
            # adjacency every stored entry of row i becomes 1/deg(i): mean neighbour aggregation.
            batch_adj = torch.sparse.softmax(batch_adj, dim=1)
        self.device = batch_adj.device

        h = batch_fts
        hidden = [h] if self.add_input_score else []
        for gin in self.gins:
            h = gin(batch_adj, h)
            hidden.append(h)

        # Build the pooling matrix in the feature dtype so float64 / half features are not
        # multiplied with a default-dtype matrix inside torch.sparse.mm.
        pool_dtype = batch_fts.dtype if batch_fts.is_floating_point() else None
        batch_pool = self._create_batch_pool_sumave(num_nodes_, device=self.device, dtype=pool_dtype)

        score = 0
        for head, feats in zip(self.predicts, hidden):
            pool = batch_pool
            if feats.is_floating_point() and feats.dtype != pool.dtype:
                # A layer may emit a different dtype than its input (e.g. mixed precision).
                pool = pool.to(feats.dtype)
            graph_feats = torch.sparse.mm(pool, feats)
            score += F.dropout(head(graph_feats), self.final_dropout, training=self.training)

        return score, h


