import torch
import torch.nn as nn
from bsgnn.data import process_batch
from .mlp import MLPLayer
from .mpgnn_batch import MPGNN 


class KMPGNN(nn.Module):
    """Ensemble of K Message-Passing models whose predictions are averaged.

    Each of the K members is an independent ``MPGNN`` built from the same
    ``layers`` / ``mlp_layers`` / ``dropout`` configuration. ``forward`` feeds
    the i-th input of each argument to the i-th member and returns the mean
    of the members' outputs.

    Args:
        K: (int) Number of ensemble members; must be at least 1.
        layers: (list) Each element is a (in_dim, hid_dim, out_dim) tuple.
            Passed unchanged to every ``MPGNN`` member.
        mlp_layers: (list) Each element is a (in_dim, hid_dim, out_dim) tuple.
            Passed unchanged to every ``MPGNN`` member.
        dropout: (float) Dropout rate stored on the ensemble and forwarded to
            every member.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """

    def __init__(self, K, layers, mlp_layers, dropout=0.5):
        super().__init__()
        if K < 1:
            # With no members, forward() could only divide by zero.
            raise ValueError(f"K must be at least 1, got {K}")
        self.dropout = dropout
        self.K = K
        self.predictors = nn.ModuleList(
            MPGNN(layers, mlp_layers, dropout) for _ in range(K)
        )

    def forward(self, blk_adjs, batch_Xs, graphs_sizes):
        """Compute the averaged ensemble prediction for one batch.

        The i-th element of every argument is routed to the i-th member, so
        each argument must provide exactly ``K`` elements.

        Args:
            blk_adjs: (sequence) One input adjacency per member (previously
                documented as ``torch.SparseTensor``).
            batch_Xs: (sequence) One node-feature ``torch.Tensor`` per member.
            graphs_sizes: (sequence) One list of per-graph sizes per member.

        Returns:
            The element-wise mean of the K members' outputs.

        Raises:
            ValueError: If any argument does not provide exactly ``K``
                elements. Previously ``zip`` silently truncated to the
                shortest input while still dividing by ``K``, producing a
                mis-scaled average.
        """
        # Materialize once so the length can be checked; iteration order is
        # the same as the plain zip() this replaces.
        blk_adjs = list(blk_adjs)
        batch_Xs = list(batch_Xs)
        graphs_sizes = list(graphs_sizes)
        for arg_name, values in (
            ("blk_adjs", blk_adjs),
            ("batch_Xs", batch_Xs),
            ("graphs_sizes", graphs_sizes),
        ):
            if len(values) != self.K:
                raise ValueError(
                    f"{arg_name} must have exactly K={self.K} elements, "
                    f"got {len(values)}"
                )

        total = sum(
            predictor(adj, x, sizes)
            for predictor, adj, x, sizes in zip(
                self.predictors, blk_adjs, batch_Xs, graphs_sizes
            )
        )
        return total / self.K

