
from __future__ import annotations
from typing import List, Tuple, Optional
import torch, torch.nn as nn

class TabMLP(nn.Module):
    """Plain feed-forward MLP mapping tabular feature vectors to ``num_classes`` outputs.

    Layout::

        fc_in -> [bn_in] -> ReLU -> [Dropout]
        then, for each further width in ``hidden``:
            Linear -> [BatchNorm1d] -> ReLU -> [Dropout]   (stored flat in ``hidden_layers``)
        fc_out  (raw outputs, no final activation)

    Submodule names (``fc_in``, ``bn_in``, ``hidden_layers.<i>``, ``fc_out``) and their
    registration order determine ``state_dict`` keys and the hook names built by
    ``choose_tabular_model``, so they are deliberately kept stable.

    Args:
        in_features: Width of each input row.
        hidden: Hidden-layer widths; must contain at least one entry.
        num_classes: Width of the output layer.
        dropout: Dropout probability in ``[0, 1]``; ``0`` disables dropout entirely.
        batch_norm: If True, insert ``BatchNorm1d`` after every hidden Linear.

    Raises:
        ValueError: If ``hidden`` is empty or ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(self, in_features: int, hidden: List[int], num_classes: int, dropout: float = 0.0, batch_norm: bool = False):
        super().__init__()
        # Fail with a clear message instead of an IndexError on hidden[0].
        if not hidden:
            raise ValueError("TabMLP requires at least one hidden layer width")
        # Negative values would otherwise be silently ignored by the `dropout > 0` checks.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")

        # Input block keeps dedicated attribute names (part of the state_dict layout).
        self.fc_in = nn.Linear(in_features, hidden[0])
        self.bn_in = nn.BatchNorm1d(hidden[0]) if batch_norm else None
        self.act_in = nn.ReLU()
        self.drop_in = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Remaining hidden blocks are stored flat, so Linear indices depend on
        # whether BatchNorm1d / Dropout are enabled.
        self.hidden_layers = nn.ModuleList()
        prev = hidden[0]
        for h in hidden[1:]:
            self.hidden_layers.append(nn.Linear(prev, h))
            if batch_norm:
                self.hidden_layers.append(nn.BatchNorm1d(h))
            self.hidden_layers.append(nn.ReLU())
            if dropout > 0:
                self.hidden_layers.append(nn.Dropout(dropout))
            prev = h
        self.fc_out = nn.Linear(prev, num_classes)

    def forward(self, x):
        """Return raw outputs of shape ``(..., num_classes)`` for input ``x``.

        NOTE(review): ``x`` is presumably ``(batch, in_features)``; when batch norm is
        enabled, ``BatchNorm1d`` normalizes over dim 1, so 2-D input is assumed.
        """
        x = self.fc_in(x)
        if self.bn_in is not None:
            x = self.bn_in(x)
        x = self.act_in(x)
        x = self.drop_in(x)
        for m in self.hidden_layers:
            x = m(x)
        return self.fc_out(x)

def available_tabular_models() -> List[str]:
    """Return the preset keys accepted by ``choose_tabular_model``.

    A new list is built on every call, so callers may mutate the result freely.
    """
    presets = ("mlp_small", "mlp_medium", "mlp_complex")
    return list(presets)

def choose_tabular_model(
    key: str,
    num_features: int,
    num_classes: int,
    device: Optional[torch.device] = None,
) -> Tuple["TabMLP", List[str], torch.device, List[int]]:
    """Build one of the preset ``TabMLP`` architectures by name.

    Args:
        key: Preset name, matched case-insensitively after stripping surrounding
            whitespace: ``"mlp_small"``, ``"mlp_medium"`` or ``"mlp_complex"``.
        num_features: Width of the input feature vector.
        num_classes: Width of the model's output layer.
        device: Device to place the model on. When ``None``, CUDA is used if
            available, then Apple MPS, otherwise CPU.

    Returns:
        ``(model, layer_names, device, hidden)``:
        ``model`` is the ``TabMLP`` moved to ``device`` and switched to eval mode;
        ``layer_names`` are the qualified submodule names (as in
        ``model.named_modules()``) of every Linear layer except ``fc_out``, meant for
        registering hooks; ``device`` is the resolved device; ``hidden`` is a fresh
        list of the preset's hidden-layer widths.

    Raises:
        ValueError: If ``key`` is not a known preset.
    """
    # preset key -> (hidden widths, dropout probability, use batch norm)
    presets = {
        "mlp_small": ([64, 64], 0.0, False),
        "mlp_medium": ([128, 64, 64], 0.0, False),
        "mlp_complex": ([512, 256, 256, 128, 128, 64], 0.3, True),
    }
    key = key.strip().lower()
    # Validate before probing devices so a bad key fails fast.
    if key not in presets:
        raise ValueError(
            f"Unknown tabular model key: {key} (available: {', '.join(presets)})"
        )

    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    widths, dropout, batch_norm = presets[key]
    hidden = list(widths)
    model = TabMLP(num_features, hidden, num_classes, dropout=dropout, batch_norm=batch_norm).to(device).eval()

    # Hidden Linear layers do not sit at a fixed stride in `hidden_layers`: each is
    # followed by ReLU, plus BatchNorm1d and/or Dropout when enabled (indices 0, 2
    # for mlp_medium but 0, 4, 8, 12, 16 for mlp_complex), so select them by type.
    layer_names = ["fc_in"] + [
        f"hidden_layers.{idx}"
        for idx, module in enumerate(model.hidden_layers)
        if isinstance(module, nn.Linear)
    ]
    return model, layer_names, device, hidden
