"""
How Powerful are Graph Neural Networks
https://arxiv.org/abs/1810.00826
https://openreview.net/forum?id=ryGs6iA5Km
Author's implementation: https://github.com/weihua916/powerful-gnns
"""

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn

from dgl.nn.mxnet.conv import GINConv
from dgl.nn.mxnet.glob import AvgPooling, MaxPooling, SumPooling


class ApplyNodeFunc(nn.Block):
    """Node-update block: apply the wrapped MLP, then BatchNorm, then ReLU.

    Parameters
    ----------
    mlp: block
        Block applied first to the node features. It must expose an
        ``output_dim`` attribute, which is used to size the batch norm.
    """

    def __init__(self, mlp):
        super().__init__()
        with self.name_scope():
            self.mlp = mlp
            # The batch norm works on whatever width the MLP emits.
            self.bn = nn.BatchNorm(in_channels=mlp.output_dim)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))`` for the node features ``h``."""
        transformed = self.mlp(h)
        normalized = self.bn(transformed)
        return nd.relu(normalized)


class MLP(nn.Block):
    """Multi-layer perceptron whose last layer is linear (no activation)."""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """Build the dense layers and the batch norms between them.

        Parameters
        ----------
        num_layers: int
            The number of dense layers; must be at least 1. With exactly
            one layer the block is a plain linear model.
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of every hidden layer
        output_dim: int
            The dimensionality of the output of the last dense layer

        Raises
        ------
        ValueError
            If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        self.linear_or_not = True
        self.num_layers = num_layers
        self.output_dim = output_dim

        with self.name_scope():
            if num_layers < 1:
                raise ValueError("number of layers should be positive!")

            if num_layers == 1:
                # Degenerate case: a single dense layer, no batch norm.
                self.linear = nn.Dense(output_dim, in_units=input_dim)
                return

            self.linear_or_not = False
            self.linears = nn.Sequential()
            self.batch_norms = nn.Sequential()

            # Layer widths: input -> hidden -> ... -> hidden -> output.
            widths = (
                [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
            )
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                self.linears.add(nn.Dense(fan_out, in_units=fan_in))

            # One batch norm after every dense layer except the last.
            for _ in range(num_layers - 1):
                self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim))

    def forward(self, x):
        """Apply the MLP to ``x``.

        Every dense layer but the last is followed by batch norm and ReLU;
        the output of the last dense layer is returned as is.
        """
        if self.linear_or_not:
            return self.linear(x)

        hidden = x
        for idx in range(self.num_layers - 1):
            dense = self.linears[idx]
            norm = self.batch_norms[idx]
            hidden = nd.relu(norm(dense(hidden)))
        return self.linears[-1](hidden)


class GIN(nn.Block):
    """Graph Isomorphism Network (GIN) for graph-level prediction.

    The input features and the output of every GIN layer are each pooled
    per graph, passed through their own linear prediction layer, and the
    resulting scores are summed.
    """

    def __init__(
        self,
        num_layers,
        num_mlp_layers,
        input_dim,
        hidden_dim,
        output_dim,
        final_dropout,
        learn_eps,
        graph_pooling_type,
        neighbor_pooling_type,
    ):
        """model parameters setting

        Parameters
        ----------
        num_layers: int
            The number of layers INCLUDING the input layer; must be at
            least 1. ``num_layers - 1`` GINConv layers are created.
        num_mlp_layers: int
            The number of linear layers in each GINConv's MLP
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        final_dropout: float
            dropout ratio applied to the output of every linear
            prediction layer
        learn_eps: boolean
            If True, learn epsilon to distinguish center nodes from neighbors
            If False, aggregate neighbors and center nodes altogether.
        graph_pooling_type: str
            how to aggregate entire nodes in a graph ("sum", "mean" or "max")
        neighbor_pooling_type: str
            how to aggregate neighbors; passed unchanged to GINConv as its
            aggregator type (sum, mean, or max)

        Raises
        ------
        ValueError
            If ``num_layers`` is smaller than 1.
        NotImplementedError
            If ``graph_pooling_type`` is not "sum", "mean" or "max".
        """
        super(GIN, self).__init__()

        # Fail at construction time, like MLP does. Without this check a
        # model with no layers is built and only fails later in forward()
        # with an IndexError on the empty list of prediction layers.
        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        self.num_layers = num_layers
        self.learn_eps = learn_eps

        with self.name_scope():
            # GINConv layers and the batch norm applied after each of them.
            self.ginlayers = nn.Sequential()
            self.batch_norms = nn.Sequential()

            for i in range(self.num_layers - 1):
                # Only the first GIN layer sees the raw input width.
                mlp_input_dim = input_dim if i == 0 else hidden_dim
                mlp = MLP(
                    num_mlp_layers, mlp_input_dim, hidden_dim, hidden_dim
                )

                self.ginlayers.add(
                    GINConv(
                        ApplyNodeFunc(mlp),
                        neighbor_pooling_type,
                        0,  # third GINConv argument (initial epsilon in DGL)
                        self.learn_eps,
                    )
                )
                self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim))

            # One prediction layer per pooled representation: index 0 takes
            # the raw input features, the others take GIN layer outputs.
            self.linears_prediction = nn.Sequential()

            for i in range(num_layers):
                pred_input_dim = input_dim if i == 0 else hidden_dim
                self.linears_prediction.add(
                    nn.Dense(output_dim, in_units=pred_input_dim)
                )

            self.drop = nn.Dropout(final_dropout)

            if graph_pooling_type == "sum":
                self.pool = SumPooling()
            elif graph_pooling_type == "mean":
                self.pool = AvgPooling()
            elif graph_pooling_type == "max":
                self.pool = MaxPooling()
            else:
                raise NotImplementedError(
                    "unsupported graph_pooling_type {!r}; expected "
                    "'sum', 'mean' or 'max'".format(graph_pooling_type)
                )

    def forward(self, g, h):
        """Compute the prediction scores for the graph(s) in ``g``.

        Parameters
        ----------
        g:
            Graph handed to every GINConv layer and to the pooling layer
            (presumably a batched DGLGraph; confirm against the caller).
        h:
            Input node features; the first prediction layer expects
            ``input_dim`` features per node.

        Returns
        -------
        The sum, over the input features and every GIN layer output, of
        ``drop(linear_i(pool(g, h_i)))``.
        """
        # Representation 0 is the raw input; one more is added per GIN layer.
        hidden_rep = [h]

        for i in range(self.num_layers - 1):
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = nd.relu(h)
            hidden_rep.append(h)

        score_over_layer = 0
        # perform pooling over all nodes in each graph in every layer
        for i, layer_h in enumerate(hidden_rep):
            pooled_h = self.pool(g, layer_h)
            score_over_layer = score_over_layer + self.drop(
                self.linears_prediction[i](pooled_h)
            )

        return score_over_layer
