import torch
import torch.nn as nn
from bsgnn.data import process_batch
from .mlp import MLPLayer


class MPLayer(nn.Module):
    """One message-passing step followed by a two-layer feature transform.

    Node features are first aggregated through the adjacency product
    ``block_adj @ X`` and then mapped through
    Linear(idim -> hdim) -> BatchNorm1d(hdim) -> ReLU -> Linear(hdim -> odim).

    Args:
        idim: (int) Input feature dimension.
        hdim: (int) Hidden dimension of the transform.
        odim: (int) Output feature dimension.
    """
    def __init__(self, idim, hdim, odim):
        super().__init__()
        # Construction/registration order (fci, fco, bn) determines parameter
        # order and how the init RNG is consumed; keep it stable so seeded
        # runs, checkpoints and optimizer state stay compatible.
        self.fci = nn.Linear(idim, hdim)
        self.fco = nn.Linear(hdim, odim)
        self.bn = nn.BatchNorm1d(hdim)

    def forward(self, block_adj, X):
        """Aggregate neighbour features, then transform them.

        Args:
            block_adj: (torch.Tensor) Block-diagonal adjacency of the batched
                graphs (the matrix argument of ``torch.spmm``, typically
                sparse). Whether self-loops are included depends on how the
                caller built it.
            X: (torch.Tensor) Batched node features, a 2-D (nodes, idim)
                matrix.

        Returns:
            torch.Tensor: Updated node features with ``odim`` columns.
        """
        aggregated = torch.spmm(block_adj, X)
        hidden = self.fci(aggregated)
        hidden = torch.relu(self.bn(hidden))
        return self.fco(hidden)


class MPGNN(nn.Module):
    """Message-passing GNN (GraphSAGE/GIN style) with a multi-level readout.

    Node features are propagated through a stack of ``MPLayer``s. The raw
    input and the output of every message-passing layer are sum-pooled per
    graph. Each pooled level goes through its own ``MLPLayer`` and dropout,
    and the per-level results are summed into the final output.

    Args:
        layers: (iterable) Each element is an ``(idim, hdim, odim)`` tuple
            used to build one ``MPLayer``.
        mlp_layers: (iterable) Each element is an ``(idim, hdim, odim)``
            tuple used to build one readout ``MLPLayer``. Exactly
            ``len(layers) + 1`` are required: one for the pooled raw input
            and one per message-passing layer.
        dropout: (float) Dropout probability applied to each readout output.

    Raises:
        ValueError: If ``len(mlp_layers) != len(layers) + 1``.
    """
    def __init__(self, layers, mlp_layers, dropout=0.5):
        super().__init__()
        # Materialize so generators still work and lengths can be checked.
        layers = list(layers)
        mlp_layers = list(mlp_layers)
        # forward() pairs readout MLPs with representation levels via zip();
        # a count mismatch would silently drop levels (or leave MLPs unused)
        # instead of failing, so reject it before building any submodule.
        if len(mlp_layers) != len(layers) + 1:
            raise ValueError(
                f"MPGNN needs len(layers) + 1 = {len(layers) + 1} readout "
                f"MLPs, got {len(mlp_layers)}")
        self.dropout = dropout
        # Same construction order as before (MP layers, then MLPs) so seeded
        # initialization is unchanged.
        self.layers = nn.ModuleList(
            [MPLayer(idim, hdim, odim) for idim, hdim, odim in layers])
        self.mlp_layers = nn.ModuleList(
            [MLPLayer(idim, hdim, odim) for idim, hdim, odim in mlp_layers])

    @staticmethod
    def batch_sum(batch_X, graph_sizes):
        """Sum-pool node representations graph by graph.

        Args:
            batch_X: (torch.Tensor) Node features of the whole batch, with
                the nodes of each graph stored contiguously along dim 0.
            graph_sizes: (list) Number of nodes in each graph, in batch
                order; must sum to ``batch_X.shape[0]`` (``torch.split``
                raises otherwise).

        Returns:
            torch.Tensor: One row per graph holding the sum over that
            graph's nodes (a zero row for an empty graph).

        Note:
            A sparse-matmul / segment-sum implementation might be faster.
        """
        return torch.stack(
            [chunk.sum(dim=0) for chunk in torch.split(batch_X, graph_sizes)])

    def forward(self, blk_adj, batch_X, graph_sizes):
        """Compute graph-level outputs for one batch.

        Args:
            blk_adj: (torch.Tensor) Block-diagonal adjacency of the batch,
                passed unchanged to every ``MPLayer``.
            batch_X: (torch.Tensor) Node features of the batch.
            graph_sizes: (list) Number of nodes in each graph.

        Returns:
            torch.Tensor: Sum over levels k of
            ``dropout(mlp_layers[k](pooled level k))``.
        """
        # Level 0 is the pooled raw input; each MP layer adds one more level.
        hiddens = [MPGNN.batch_sum(batch_X, graph_sizes)]
        h = batch_X
        for layer in self.layers:
            h = layer(blk_adj, h)
            hiddens.append(MPGNN.batch_sum(h, graph_sizes))
        output = 0
        for pooled, mlp_layer in zip(hiddens, self.mlp_layers):
            output = output + nn.functional.dropout(
                mlp_layer(pooled), p=self.dropout, training=self.training)
        return output
