from typing import Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

from ..utils.common import GraphData

class GraphSAGE_IDS(nn.Module):
    """GraphSAGE baseline for intrusion detection.

    A stack of ``layers`` ``SAGEConv`` layers followed by a linear node
    classifier. ReLU and dropout are applied between consecutive
    convolutions; the output of the final convolution is passed to the
    classifier without any activation and is also returned as the node
    embedding.

    Args:
        in_dim_node: Number of input features per node.
        hidden: Output width of every ``SAGEConv`` and input width of the
            classifier.
        layers: Number of ``SAGEConv`` layers; must be >= 1.
        dropout: Dropout probability applied after each non-final
            convolution. With ``layers == 1`` there is no such gap, so no
            dropout is applied anywhere.
        num_classes: Number of output logits per node.

    Raises:
        ValueError: If ``layers`` is less than 1.
    """

    def __init__(self, in_dim_node: int, hidden: int = 64,
                 layers: int = 2, dropout: float = 0.1, num_classes: int = 2):
        super().__init__()
        if layers < 1:
            # A non-positive depth used to silently produce a 1-layer model.
            raise ValueError(f"layers must be >= 1, got {layers}")

        # Attribute names are kept unchanged because they define the
        # parameter keys in ``state_dict`` ("layers.*", "classifier.*").
        # Convs are created in the same order as before, so seeded weight
        # initialisation is unchanged.
        dims = [in_dim_node] + [hidden] * layers
        self.layers = nn.ModuleList(
            SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        # One dropout per gap between consecutive convs; the final conv's
        # output goes straight to the classifier, so it gets none.
        self.dropouts = nn.ModuleList(
            nn.Dropout(dropout) for _ in range(layers - 1)
        )

        # Output layer
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, data: GraphData) -> Dict[str, Any]:
        """Compute per-node class logits.

        Args:
            data: Graph container; only ``data.x`` (node features) and
                ``data.edge_index`` are read.

        Returns:
            Dict with ``"node_logits"`` (classifier output),
            ``"edge_attn"`` (always ``None``) and ``"node_emb"`` (output of
            the last ``SAGEConv``, i.e. the classifier input).
        """
        x = data.x
        edge_index = data.edge_index

        *inner_convs, last_conv = self.layers
        for conv, drop in zip(inner_convs, self.dropouts):
            x = drop(F.relu(conv(x, edge_index)))
        # No activation/dropout after the final conv: its raw output is both
        # the returned embedding and the classifier input.
        x = last_conv(x, edge_index)

        logits = self.classifier(x)
        return {"node_logits": logits, "edge_attn": None, "node_emb": x}
