from torch import nn
from torch_geometric.nn import (
    GCNConv, SAGEConv, GATConv, GINConv,
    global_mean_pool, BatchNorm
)
import torch.nn.functional as F

class GCNNet(nn.Module):
    """Graph-level model: stacked GCNConv layers, mean pooling, linear head.

    Args:
        in_dim: Size of each input node feature vector.
        hid_dim: Output size of every GCNConv layer.
        num_layers: Number of GCNConv layers. The first layer is always
            built, so values below 1 still yield a single layer.
        dropout: Dropout probability applied after each conv layer.
        out_dim: Number of output logits (default 2, the previously
            hard-coded value).
    """

    def __init__(self, in_dim, hid_dim, num_layers, dropout, out_dim=2):
        super().__init__()
        # First layer maps in_dim -> hid_dim; the remaining layers keep hid_dim.
        self.convs = nn.ModuleList(
            [GCNConv(in_dim, hid_dim)]
            + [GCNConv(hid_dim, hid_dim) for _ in range(num_layers - 1)]
        )
        self.lin = nn.Linear(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        """Compute pooled per-graph logits.

        Args:
            x: Node feature matrix whose last dimension is ``in_dim``.
            edge_index: Graph connectivity, passed unchanged to each GCNConv.
            batch: Node-to-graph assignment consumed by ``global_mean_pool``.

        Returns:
            Logits with ``out_dim`` columns, one row per pooled graph.
        """
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # Dropout is only active while the module is in training mode.
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = global_mean_pool(x, batch)
        return self.lin(x)


class GraphSAGENet(nn.Module):
    """Graph-level model: stacked SAGEConv layers, mean pooling, linear head.

    Args:
        in_dim: Size of each input node feature vector.
        hid_dim: Output size of every SAGEConv layer.
        num_layers: Number of SAGEConv layers. The first layer is always
            built, so values below 1 still yield a single layer.
        dropout: Dropout probability applied after each conv layer.
        out_dim: Number of output logits (default 2, the previously
            hard-coded value).
    """

    def __init__(self, in_dim, hid_dim, num_layers, dropout, out_dim=2):
        super().__init__()
        # First layer maps in_dim -> hid_dim; the remaining layers keep hid_dim.
        self.convs = nn.ModuleList(
            [SAGEConv(in_dim, hid_dim)]
            + [SAGEConv(hid_dim, hid_dim) for _ in range(num_layers - 1)]
        )
        self.lin = nn.Linear(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        """Compute pooled per-graph logits.

        Args:
            x: Node feature matrix whose last dimension is ``in_dim``.
            edge_index: Graph connectivity, passed unchanged to each SAGEConv.
            batch: Node-to-graph assignment consumed by ``global_mean_pool``.

        Returns:
            Logits with ``out_dim`` columns, one row per pooled graph.
        """
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # Dropout is only active while the module is in training mode.
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = global_mean_pool(x, batch)
        return self.lin(x)


class GATNet(nn.Module):
    """Graph-level model: stacked GATConv layers, mean pooling, linear head.

    Args:
        in_dim: Size of each input node feature vector.
        hid_dim: Per-head output size of every GATConv layer.
        num_layers: Number of GATConv layers. The first layer is always
            built, so values below 1 still yield a single layer.
        dropout: Used twice: passed to each GATConv as its ``dropout``
            argument, and applied to node features after every layer.
        heads: Number of attention heads per layer.
        out_dim: Number of output logits (default 2, the previously
            hard-coded value).
    """

    def __init__(self, in_dim, hid_dim, num_layers, dropout, heads=1, out_dim=2):
        super().__init__()
        self.convs = nn.ModuleList()
        # First layer
        self.convs.append(GATConv(in_dim, hid_dim, heads=heads, dropout=dropout))
        # Subsequent layers: input dim is hid_dim * heads because the heads'
        # outputs are concatenated (GATConv's default).
        for _ in range(num_layers - 1):
            self.convs.append(
                GATConv(hid_dim * heads, hid_dim, heads=heads, dropout=dropout)
            )
        self.lin = nn.Linear(hid_dim * heads, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        """Compute pooled per-graph logits.

        Args:
            x: Node feature matrix whose last dimension is ``in_dim``.
            edge_index: Graph connectivity, passed unchanged to each GATConv.
            batch: Node-to-graph assignment consumed by ``global_mean_pool``.

        Returns:
            Logits with ``out_dim`` columns, one row per pooled graph.
        """
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # Dropout is only active while the module is in training mode.
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = global_mean_pool(x, batch)
        return self.lin(x)


class GINNet(nn.Module):
    """Graph-level model: stacked GINConv layers, mean pooling, linear head.

    Args:
        in_dim: Size of each input node feature vector.
        hid_dim: Hidden/output size of every GIN MLP.
        num_layers: Number of GINConv layers. Unlike the sibling models, 0
            builds no conv layers at all (then ``in_dim`` must equal
            ``hid_dim`` for the linear head to accept the pooled features).
        dropout: Dropout probability, used inside each GIN MLP and after
            each conv layer.
        out_dim: Number of output logits (default 2, the previously
            hard-coded value).
    """

    def __init__(self, in_dim, hid_dim, num_layers, dropout, out_dim=2):
        super().__init__()
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            # The MLP deliberately ends with a Linear layer. forward() already
            # applies ReLU + dropout after every conv, so a trailing
            # ReLU/Dropout here applied dropout twice per layer. The Linear
            # layers stay at Sequential indices 0 and 3, which keeps the
            # parameter names unchanged.
            mlp = nn.Sequential(
                nn.Linear(in_dim if i == 0 else hid_dim, hid_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hid_dim, hid_dim),
            )
            self.convs.append(GINConv(mlp))
        self.lin = nn.Linear(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        """Compute pooled per-graph logits.

        Args:
            x: Node feature matrix whose last dimension is ``in_dim``.
            edge_index: Graph connectivity, passed unchanged to each GINConv.
            batch: Node-to-graph assignment consumed by ``global_mean_pool``.

        Returns:
            Logits with ``out_dim`` columns, one row per pooled graph.
        """
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # Dropout is only active while the module is in training mode.
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = global_mean_pool(x, batch)
        return self.lin(x)