import torch
import torch.nn as nn
import torch.nn.functional as F
from ._MLP_layer import MLP_layer

class GIN_layer(nn.Module):
    """Graph layer that aggregates node features through an adjacency matrix
    and transforms the result with an MLP.

    The forward pass computes ``batch_adj @ batch_node_features``, optionally
    divides by the node degree ('average' pooling), applies the MLP,
    optionally applies batch normalization, and finishes with a ReLU.
    """

    def __init__(self, num_mlp_layers: int, input_dim: int, hidden_dim: int, output_dim: int,
                 *args, **kwargs):
        """
        Inputs:
            num_mlp_layers: [int] number of layers in the MLP (forwarded to MLP_layer)
            input_dim:      [int] dimensionality of input node features
            hidden_dim:     [int] dimensionality of the MLP hidden units
            output_dim:     [int] dimensionality of the output node features
            *args: accepted for signature compatibility; ignored by this class
            **kwargs
                neighbor_pooling_type: [str] 'average' divides the aggregated
                    features by the node degree; any other value (default
                    'sum') leaves the plain sum over the adjacency untouched
                is_batch_norm: [bool] whether to apply a batchnorm layer on the
                    output of the MLP (default False)
                device: stored on the instance as ``self.device`` (default
                    'cuda'); not otherwise used by this class
                Any other keyword argument is ignored.
        """
        super().__init__()

        self.num_mlp_layers = num_mlp_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.neighbor_pooling_type = kwargs.get('neighbor_pooling_type', 'sum')
        self.is_batch_norm = kwargs.get('is_batch_norm', False)
        self.device = kwargs.get('device', 'cuda')

        # The MLP is always built with is_batch_norm=True, independently of
        # this layer's own is_batch_norm flag (which only controls the
        # BatchNorm1d applied to the MLP output in forward()).
        self.mlp = MLP_layer(self.num_mlp_layers, self.input_dim, self.hidden_dim, self.output_dim, is_batch_norm=True)
        # Created unconditionally so the module's parameters / state_dict do
        # not depend on is_batch_norm; only used when is_batch_norm is True.
        self.batch_norm = nn.BatchNorm1d(self.output_dim)

    def forward(self, batch_adj: torch.Tensor, batch_node_features: torch.Tensor) -> torch.Tensor:
        '''
        Inputs:
            batch_adj: [torch.tensor] batch_num_nodes * batch_num_nodes adjacency
                matrix; it is passed to torch.sparse.mm as the first operand
            batch_node_features: [torch.tensor] batch_num_nodes*input_dim matrix
        Outputs:
            batch_node_features: [torch.tensor] batch_num_nodes*output_dim matrix
        '''
        # aggregation: row i becomes the adjacency-weighted sum of node features
        # (whether a node's own features are included depends on whether
        # batch_adj contains self-loops -- not decided here)
        h_aggre = torch.sparse.mm(batch_adj, batch_node_features)
        if self.neighbor_pooling_type == 'average':
            # Build the helper vector with the adjacency's dtype and device so
            # the product also works when the inputs are not CPU float32.
            ones = torch.ones(
                (batch_adj.shape[0], 1),
                dtype=batch_adj.dtype,
                device=batch_adj.device,
            )
            batch_node_degree = torch.sparse.mm(batch_adj, ones)
            # Nodes with zero degree would otherwise yield 0 / 0 = NaN; divide
            # by 1 instead so their aggregated features stay as computed.
            batch_node_degree = batch_node_degree.masked_fill(batch_node_degree == 0, 1)
            h_aggre = h_aggre / batch_node_degree

        # transform
        h_trans = self.mlp(h_aggre)
        if self.is_batch_norm:
            h_trans = self.batch_norm(h_trans)

        return F.relu(h_trans).to(batch_node_features.device)
