from typing import Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv

from ..utils.common import GraphData

class GIN_IDS(nn.Module):
    """Graph Isomorphism Network baseline for intrusion detection.

    Stacks ``layers`` GINConv blocks, each followed by BatchNorm, ReLU and
    dropout, and puts a linear per-node classifier on top.

    Args:
        in_dim_node: Width of the input node features (the last dimension of
            ``data.x``).
        hidden: Width of every GIN layer's MLP and of the node embeddings.
        layers: Number of GIN message-passing layers. ``0`` degenerates to a
            linear classifier applied directly to the raw node features.
        dropout: Dropout probability applied after every GIN layer.
        num_classes: Number of output classes per node.

    Raises:
        ValueError: If ``layers`` is negative.
    """

    def __init__(self, in_dim_node: int, hidden: int = 64,
                 layers: int = 2, dropout: float = 0.1, num_classes: int = 2):
        super().__init__()
        if layers < 0:
            raise ValueError(f"layers must be non-negative, got {layers}")

        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # Only the first GIN layer consumes the raw node features; every later
        # layer consumes the previous layer's hidden-sized output.
        in_dim = in_dim_node
        for _ in range(layers):
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
            )
            self.layers.append(GINConv(mlp))
            self.batch_norms.append(nn.BatchNorm1d(hidden))
            in_dim = hidden

        self.dropout = nn.Dropout(dropout)
        # After the loop ``in_dim`` is the width of the final node
        # representation: ``hidden``, or ``in_dim_node`` when layers == 0.
        # The classifier therefore always matches what forward() feeds it.
        self.classifier = nn.Linear(in_dim, num_classes)

    def forward(self, data: GraphData) -> Dict[str, Any]:
        """Run message passing and classify every node.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` attributes. Presumably ``x`` is
                (num_nodes, in_dim_node) and ``edge_index`` is a PyG-style
                (2, num_edges) tensor; TODO confirm against GraphData.

        Returns:
            Dict with keys:
              - ``"node_logits"``: per-node class logits from the classifier.
              - ``"edge_attn"``: always ``None``, since this model computes
                no edge attention. Presumably it is kept so the output schema
                matches attention-based models; verify against callers.
              - ``"node_emb"``: the final node representation fed to the
                classifier. It is taken after dropout, so in training mode it
                is the dropped-out tensor.
        """
        x = data.x
        edge_index = data.edge_index

        for layer, bn in zip(self.layers, self.batch_norms):
            x = layer(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = self.dropout(x)

        logits = self.classifier(x)
        return {"node_logits": logits, "edge_attn": None, "node_emb": x}
