import torch
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from utils.config import cfg

from models.conv import GNN_node

class GNN(torch.nn.Module):
    """Graph-level prediction model: node encoder -> graph readout -> linear head.

    All hyper-parameters come from the global ``cfg.gnn`` section:

    * ``num_layers`` -- must be at least 2 (validated here).
    * ``pool`` -- one of ``"sum"``, ``"mean"`` or ``"max"``; selects the
      torch_geometric global pooling used to turn node embeddings into one
      embedding per graph.
    * ``emb_dim`` / ``output_dim`` -- input and output sizes of the final
      linear prediction layer.

    Raises:
        ValueError: if ``cfg.gnn.num_layers < 2`` or ``cfg.gnn.pool`` is not
            a supported pooling name.
    """

    def __init__(self):
        super().__init__()

        num_layers = cfg.gnn.num_layers
        if num_layers < 2:
            raise ValueError(
                f"Number of GNN layers must be greater than 1 (got {num_layers})."
            )

        # Node-level encoder (the virtual-node variant was removed).
        # NOTE(review): GNN_node takes no arguments here, so it presumably
        # reads its own settings from cfg.gnn -- confirm it produces node
        # embeddings of size cfg.gnn.emb_dim, which graph_pred_linear expects.
        self.gnn_node = GNN_node()

        # Readout: supported cfg.gnn.pool values -> torch_geometric pooling fn.
        pooling_by_name = {
            "sum": global_add_pool,
            "mean": global_mean_pool,
            "max": global_max_pool,
        }
        pool_name = cfg.gnn.pool
        try:
            self.pool = pooling_by_name[pool_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable config values, so every invalid
            # setting still surfaces as a ValueError.
            raise ValueError(
                f"Invalid graph pooling type {pool_name!r}; "
                f"expected one of {sorted(pooling_by_name)}."
            ) from None

        # Registered after gnn_node so state_dict key order is unchanged.
        self.graph_pred_linear = torch.nn.Linear(cfg.gnn.emb_dim, cfg.gnn.output_dim)

    def forward(self, batched_data):
        """Predict one output vector per graph in ``batched_data``.

        Args:
            batched_data: Whatever ``GNN_node`` accepts (presumably a
                torch_geometric ``Batch`` -- confirm with callers).

        Returns:
            ``graph_pred_linear`` applied to the pooled graph embeddings; the
            last dimension is ``cfg.gnn.output_dim`` and there is presumably
            one row per graph in the batch.
        """
        # GNN_node yields node embeddings plus the per-node batch vector that
        # the global pooling functions use to group nodes by graph.
        h_node, batch_indicator = self.gnn_node(batched_data)
        h_graph = self.pool(h_node, batch_indicator)
        return self.graph_pred_linear(h_graph)

if __name__ == "__main__":
    # Smoke check: build the model from the active config so configuration
    # errors (bad layer count, unknown pooling type) surface immediately.
    GNN()
