import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP_layer(nn.Module):
    '''
    Multi-layer perceptron: a stack of linear layers with ReLU (and optional
    batch normalization) after every layer except the last one.

    The final linear layer is returned raw: no activation and no batch norm
    is applied to it.
    '''

    def __init__(self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int, *args, **kwargs):
        '''
        Inputs:
            num_layers:     [int] number of linear layers (EXCLUDING the input layer); must be >= 1
            input_dim:      [int] dimensionality of input features
            hidden_dim:     [int] dimensionality of hidden units at ALL hidden layers
                            (unused when num_layers == 1)
            output_dim:     [int] dimensionality of the output of the last linear layer
            *args:          accepted for call-site compatibility and ignored
            **kwargs
                is_batch_norm: [bool] if True, apply BatchNorm1d after every hidden
                               linear layer (before the ReLU). Defaults to False.
                Any other keyword argument is accepted and ignored.
        Raises:
            ValueError: if num_layers < 1
        '''
        super().__init__()

        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let an invalid depth build a broken network.
        if num_layers < 1:
            raise ValueError(
                f'num_layers in MLP_layer should be at least 1, got {num_layers}'
            )
        self.is_batch_norm = kwargs.get('is_batch_norm', False)

        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Widths at every layer boundary: input, (num_layers - 1) hidden, output.
        # With num_layers == 1 this collapses to a single input -> output layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

        # One batch norm per hidden layer; the list stays empty when disabled.
        self.batch_norms = nn.ModuleList()
        if self.is_batch_norm:
            for _ in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    def forward(self, node_features: torch.Tensor) -> torch.Tensor:
        '''
        Inputs:
            node_features: [torch.tensor] N*input_dim matrix
        Outputs:
            [torch.tensor] N*output_dim matrix (output of the last linear layer,
            without activation)
        '''
        h = node_features
        # Hidden layers: linear -> (optional batch norm) -> ReLU.
        for layer in range(self.num_layers - 1):
            h = self.linears[layer](h)
            if self.is_batch_norm:
                h = self.batch_norms[layer](h)
            h = F.relu(h)

        # Final layer has no non-linearity; the result is placed on the
        # input's device.
        return self.linears[self.num_layers - 1](h).to(node_features.device)






