import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network built from ``GCNConv`` layers.

    The first convolution is followed by a ReLU; the second convolution's
    output is returned without an activation.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Number of features produced by the first convolution.
        output_dim: Number of features produced by the second convolution.
    """

    # Truthy ``pooling`` values accepted by ``forward``; any falsy value
    # (the default ``False``, ``None``, ``''``) means "no pooling".
    _POOLING_MODES = ('mean', 'target')

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data, edge_weight=None, pooling=False):
        """Run both convolutions and optionally reduce the node outputs.

        Args:
            data: Graph object exposing ``x`` and ``edge_index``. ``batch``
                and ``edge_weight`` are optional; ``target_node_index`` is
                read only when ``pooling == 'target'``.
            edge_weight: Optional per-edge weights. When ``None``, falls back
                to ``data.edge_weight`` if present; if that is also missing
                the convolutions receive ``None`` (unweighted).
            pooling: ``'mean'`` to average node outputs per graph via
                ``global_mean_pool``, ``'target'`` to index the outputs with
                ``data.target_node_index``, or a falsy value (default) to
                return the output for every node.

        Returns:
            The second convolution's node outputs, their per-graph mean, or
            the rows selected by ``data.target_node_index``.

        Raises:
            ValueError: If ``pooling`` is truthy but not one of
                ``'mean'`` / ``'target'``.
        """
        # Fail fast on a misspelled mode instead of silently returning
        # un-pooled node outputs.
        if pooling and pooling not in self._POOLING_MODES:
            raise ValueError(
                f"Unknown pooling mode {pooling!r}; expected one of "
                f"{self._POOLING_MODES} or a falsy value for no pooling."
            )

        x, edge_index = data.x, data.edge_index
        # ``batch`` / ``edge_weight`` are optional on graph objects; default
        # to None rather than assuming the attributes exist.
        batch = getattr(data, 'batch', None)
        if edge_weight is None:
            edge_weight = getattr(data, 'edge_weight', None)

        x1 = F.relu(self.conv1(x, edge_index, edge_weight))
        x2 = self.conv2(x1, edge_index, edge_weight)

        if pooling == 'mean':
            # An unbatched graph has batch=None; global_mean_pool accepts
            # None and pools all nodes as one graph, whereas None.long()
            # would raise AttributeError.
            return global_mean_pool(
                x2, batch.long() if batch is not None else None
            )
        if pooling == 'target':
            # NOTE(review): assumes ``target_node_index`` is set on ``data``
            # by the caller / dataset — confirm.
            return x2[data.target_node_index]

        return x2