from typing import Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.common import GraphData

class MLP_IDS(nn.Module):
    """Two-layer perceptron that scores every node from its own features.

    ``forward`` reads only ``data.x``; no edge or connectivity information
    is consulted, so each row of ``data.x`` is classified independently.

    Args:
        in_dim: Size of the last dimension of ``data.x`` (input width of the
            first ``nn.Linear``).
        hidden: Width of the single hidden layer.
        num_classes: Number of logits produced per row of ``data.x``.
        dropout: Dropout probability applied after the hidden ReLU.
    """

    def __init__(
        self,
        in_dim: int,
        hidden: int = 64,
        num_classes: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        # The layers are instantiated in this exact order. That keeps parameter
        # initialisation consuming the RNG identically, and keeps the
        # Sequential indices, and hence the state_dict keys
        # (net.0.*, net.3.*), stable for existing checkpoints.
        stages = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, data: GraphData) -> Dict[str, Any]:
        """Run the MLP over ``data.x``.

        Args:
            data: Object exposing a tensor attribute ``x`` whose last
                dimension equals ``in_dim`` (presumably one row per node;
                confirm against ``GraphData``).

        Returns:
            Dict with keys ``"node_logits"`` (MLP output), ``"edge_attn"``
            (always ``None``, since this model has no edge attention) and
            ``"node_emb"`` (the very same tensor as ``"node_logits"``).
        """
        scores = self.net(data.x)
        # NOTE(review): "node_emb" aliases the logits rather than the hidden
        # activation; confirm downstream consumers expect a num_classes-wide
        # embedding before changing this.
        return dict(node_logits=scores, edge_attn=None, node_emb=scores)
