import torch
import torch.nn as nn
from .mlp import MLPLayer, LabelerMLP
from .mpgnn_batch import MPLayer


class MPGNN(nn.Module):
    """GraphSAGE/GIN-style message-passing model with per-depth readout heads.

    Node features are pushed through a stack of ``MPLayer`` modules. The raw
    input and the output of every layer are each summed over dim 0 to form a
    readout. Each readout is fed to its own ``MLPLayer`` head and passed
    through dropout, and the head outputs are added together.

    Args:
        layers: (list) Each element is an ``(in_dim, hidden_dim, out_dim)``
            tuple used to build one ``MPLayer``.
        mlp_layers: (list) Each element is an ``(in_dim, hidden_dim, out_dim)``
            tuple used to build one ``MLPLayer`` readout head.
        dropout: (float) Dropout probability applied to every head's output
            while the module is in training mode.

    Note:
        Readouts and heads are paired with ``zip``, so surplus readouts or
        surplus heads are silently ignored. Presumably ``len(mlp_layers)`` is
        meant to equal ``len(layers) + 1`` -- TODO confirm against callers.
    """

    def __init__(self, layers, mlp_layers, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        # Modules are created in list order, so parameter init order and the
        # ModuleList keys ("0", "1", ...) follow the config order.
        self.layers = nn.ModuleList(
            [MPLayer(in_dim, hid_dim, out_dim) for in_dim, hid_dim, out_dim in layers]
        )
        self.mlp_layers = nn.ModuleList(
            [MLPLayer(in_dim, hid_dim, out_dim) for in_dim, hid_dim, out_dim in mlp_layers]
        )

    def forward(self, adj, X):
        """Run message passing, read out every depth, and sum the head outputs.

        Args:
            adj: Adjacency structure forwarded unchanged to every ``MPLayer``.
            X: (torch.Tensor) Node features; readouts sum over dim 0.
                NOTE(review): if X stacks several graphs, this sum pools across
                all of them -- confirm that is intended.

        Returns:
            The sum of the dropped-out head outputs, or the int ``0`` when no
            head is paired with a readout.
        """
        # Readout of the raw input, then one readout after every layer.
        readouts = [torch.sum(X, dim=0)]
        node_repr = X
        for mp_layer in self.layers:
            node_repr = mp_layer(adj, node_repr)
            readouts.append(torch.sum(node_repr, dim=0))
        # Builtin sum() starts from the int 0, matching an empty accumulator.
        return sum(
            torch.dropout(head(readout), self.dropout, train=self.training)
            for readout, head in zip(readouts, self.mlp_layers)
        )


class LabelerMPLayer(nn.Module):
    """A single message-passing layer: aggregate, then a two-layer MLP.

    Computes ``fco(relu(fci(block_adj @ X)))``.

    Args:
        idim: (int) Input feature dimension.
        hdim: (int) Hidden dimension of the MLP.
        odim: (int) Output feature dimension.
    """

    def __init__(self, idim, hdim, odim):
        super().__init__()
        self.fci = nn.Linear(idim, hdim)  # input -> hidden
        self.fco = nn.Linear(hdim, odim)  # hidden -> output

    def forward(self, block_adj, X):
        """Propagate features over ``block_adj`` and transform them.

        Args:
            block_adj: (torch.Tensor) Block-diagonal adjacency matrix of the
                batch. It is passed to ``torch.spmm``, so it must be a 2-D
                tensor (typically sparse), not an ``nx.Graph``.
            X: (torch.Tensor) Batched node features; the row count must match
                the column count of ``block_adj``.

        Returns:
            torch.Tensor: One ``odim``-wide output row per row of
            ``block_adj``.
        """
        aggregated = torch.spmm(block_adj, X)
        return self.fco(self.fci(aggregated).relu())


class LabelerMPGNN(nn.Module):
    """GraphSAGE/GIN-style labeler model with per-depth readout heads.

    Node features are pushed through a stack of ``LabelerMPLayer`` modules.
    The raw input and the output of every layer are each summed over dim 0 to
    form a readout. Each readout is fed to its own ``LabelerMLP`` head, and the
    head outputs are added together. Unlike ``MPGNN``, no dropout is applied.

    Args:
        layers: (list) Each element is an ``(in_dim, hidden_dim, out_dim)``
            tuple used to build one ``LabelerMPLayer``.
        mlp_layers: (list) Each element is an ``(in_dim, hidden_dim, out_dim)``
            tuple used to build one ``LabelerMLP`` readout head.

    Note:
        Readouts and heads are paired with ``zip``, so surplus readouts or
        surplus heads are silently ignored. Presumably ``len(mlp_layers)`` is
        meant to equal ``len(layers) + 1`` -- TODO confirm against callers.
    """

    def __init__(self, layers, mlp_layers):
        super().__init__()
        # Modules are created in list order, so parameter init order and the
        # ModuleList keys ("0", "1", ...) follow the config order.
        self.layers = nn.ModuleList(
            [LabelerMPLayer(in_dim, hid_dim, out_dim) for in_dim, hid_dim, out_dim in layers]
        )
        self.mlp_layers = nn.ModuleList(
            [LabelerMLP(in_dim, hid_dim, out_dim) for in_dim, hid_dim, out_dim in mlp_layers]
        )

    def forward(self, adj, X):
        """Run message passing, read out every depth, and sum the head outputs.

        Args:
            adj: Adjacency passed unchanged to every layer (a tensor usable
                with ``torch.spmm`` when the layers are ``LabelerMPLayer``).
            X: (torch.Tensor) Node features; readouts sum over dim 0.
                NOTE(review): with a block-diagonal batch this sum pools
                across every graph in the batch -- confirm that is intended.

        Returns:
            The sum of the head outputs, or the int ``0`` when no head is
            paired with a readout.
        """
        # Readout of the raw input, then one readout after every layer.
        readouts = [torch.sum(X, dim=0)]
        node_repr = X
        for mp_layer in self.layers:
            node_repr = mp_layer(adj, node_repr)
            readouts.append(torch.sum(node_repr, dim=0))
        # Builtin sum() starts from the int 0, matching an empty accumulator.
        return sum(head(readout) for readout, head in zip(readouts, self.mlp_layers))
