import torch
import torch_geometric
from torch import nn
from torch.nn import functional as F

from torch_geometric.nn import SAGEConv, global_max_pool

import torch
import torch.nn.functional as F


class GraphSAGE(nn.Module):
    """GraphSAGE graph classifier with an optional fully-adjacent last layer.

    A stack of ``SAGEConv`` layers produces node embeddings. The outputs of
    every layer are concatenated, max-pooled per graph and passed through a
    two-layer MLP that yields ``dim_target`` values per graph.

    Args:
        dim_features: Size of the input node features.
        dim_target: Size of the per-graph output.
        config: Mapping read with ``[]`` for the keys
            - ``'num_layers'``: number of ``SAGEConv`` layers.
            - ``'dim_embedding'``: hidden size of every layer.
            - ``'aggregation'``: neighbourhood aggregation handed to
              ``SAGEConv`` (``'mean'`` or ``'max'``).
            - ``'last_layer_fa'``: if truthy, the last conv layer ignores the
              input edges and instead connects every pair of nodes that
              belong to the same graph.
    """

    def __init__(self, dim_features, dim_target, config):
        super().__init__()

        num_layers = config['num_layers']
        dim_embedding = config['dim_embedding']
        self.aggregation = config['aggregation']  # can be mean or max
        self.last_layer_fa = config['last_layer_fa']
        if self.last_layer_fa:
            print('Using LastLayerFA')

        if self.aggregation == 'max':
            # One projection shared by all layers; applied (with ReLU) to the
            # output of every conv layer when max aggregation is selected.
            self.fc_max = nn.Linear(dim_embedding, dim_embedding)

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            dim_input = dim_features if i == 0 else dim_embedding
            # The aggregation must be given to the constructor. PyTorch
            # Geometric resolves ``aggr`` into ``aggr_module`` in
            # ``MessagePassing.__init__``. Assigning ``conv.aggr`` after
            # construction therefore leaves the default ('mean') aggregation
            # in place.
            conv = SAGEConv(dim_input, dim_embedding, aggr=self.aggregation)
            self.layers.append(conv)

        # For graph classification: the readout sees every layer's output.
        self.fc1 = nn.Linear(num_layers * dim_embedding, dim_embedding)
        self.fc2 = nn.Linear(dim_embedding, dim_target)

    def forward(self, data):
        """Compute per-graph outputs.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (graph id of every node, or ``None`` for a single
                graph). Presumably a PyG ``Data``/``Batch``; confirm against
                callers.

        Returns:
            Tensor with one row of ``dim_target`` values per graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if batch is None:
            # Single, un-batched graph: assign every node to graph 0 so the
            # fully-adjacent mask and the per-graph pooling below still work.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        last_index = len(self.layers) - 1
        x_all = []

        for i, layer in enumerate(self.layers):
            edges = edge_index
            if self.last_layer_fa and i == last_index:
                # Connect every pair of nodes sharing a graph id (self-loops
                # included). This materialises a dense (N, N) mask over all
                # nodes in the batch, so memory grows quadratically with N.
                block_map = torch.eq(batch.unsqueeze(0), batch.unsqueeze(-1)).int()
                edges, _ = torch_geometric.utils.dense_to_sparse(block_map)
            x = layer(x, edges)
            if self.aggregation == 'max':
                x = torch.relu(self.fc_max(x))
            x_all.append(x)

        # Jumping-knowledge style readout: concatenate all layers' features,
        # then take the per-graph max.
        x = torch.cat(x_all, dim=1)
        x = global_max_pool(x, batch)

        x = F.relu(self.fc1(x))
        return self.fc2(x)
