import numpy as np
import torch
import torch.nn as nn
from torch import amp
from torch.utils.data import DataLoader, TensorDataset

class MLPBlock(nn.Module):
    """Residual feed-forward block computing ``x + net(x)``.

    ``net`` applies, in order: LayerNorm, a ``d -> hidden`` linear layer,
    SiLU, dropout, a ``hidden -> d`` linear layer, and a second dropout.

    Args:
        d: Feature width of the block's input and output.
        hidden: Width of the inner linear layer.
        p_drop: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, d, hidden, p_drop):
        super().__init__()
        normalise = nn.LayerNorm(d)
        expand = nn.Linear(d, hidden)
        contract = nn.Linear(hidden, d)
        # Positional order determines the state-dict keys (net.0, net.1, net.4).
        self.net = nn.Sequential(
            normalise,
            expand,
            nn.SiLU(),
            nn.Dropout(p_drop),
            contract,
            nn.Dropout(p_drop),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the normalised feed-forward branch."""
        # Pre-norm residual: the LayerNorm sits inside the branch, the skip
        # path carries ``x`` through untouched.
        branch = self.net(x)
        return x + branch

class RealMLPBackbone(nn.Module):
    """Feature extractor: input projection, residual MLP blocks, final LayerNorm.

    Args:
        d_in: Number of input features.
        d_model: Width produced by the input projection and kept by every block.
        n_blocks: Number of ``MLPBlock`` instances stacked after the projection.
        hidden: Inner width passed to each ``MLPBlock``.
        p_in: Dropout probability applied to the normalised input.
        p_drop: Dropout probability passed to each ``MLPBlock``.
    """

    def __init__(self, d_in, d_model=512, n_blocks=6, hidden=1024, p_in=0.0, p_drop=0.1):
        super().__init__()
        # Stem: normalise raw features, optionally drop some, project to d_model.
        stem = [nn.LayerNorm(d_in), nn.Dropout(p_in), nn.Linear(d_in, d_model)]
        self.inp = nn.Sequential(*stem)

        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(MLPBlock(d_model, hidden, p_drop))

        self.head_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the layer-normalised output of the block stack for ``x``."""
        features = self.inp(x)
        for block in self.blocks:
            features = block(features)
        return self.head_norm(features)

class RealMLPClassifier(nn.Module):
    """
    Strong MLP baseline with pre-norm residual blocks.
    Scikit-style API: fit(X,y[,X_val,y_val]), predict_proba(X), predict(X).
    """
    def __init__(self, params, args):
        super().__init__()
        self.device = torch.device("cuda" if args.use_gpu else "cpu")
        self.n_features = args.num_features
        self.n_classes  = args.num_classes

        # hyperparams
        self.d_model = int(params.get("width", 512))
        self.n_blocks= int(params.get("depth", 6))
        self.hidden  = int(params.get("hidden", self.d_model*2))
        self.p_in    = float(params.get("input_dropout", 0.0))
        self.p_drop  = float(params.get("dropout", 0.1))
        self.epochs  = int(params.get("epochs", getattr(args, "epochs", 100)))
        self.bs      = int(params.get("batch_size", getattr(args, "batch_size", 256)))
        self.pat     = int(params.get("patience", 20))
        self.lr      = float(params.get("learning_rate", 1e-3))
        self.weight_decay = float(params.get("weight_decay", 1e-5))
        self.label_smoothing = float(params.get("label_smoothing", 0.0))
        self.amp      = bool(params.get("amp", True))

        self.backbone = RealMLPBackbone(self.n_features, self.d_model, self.n_blocks, self.hidden, self.p_in, self.p_drop).to(self.device)
        self.head = nn.Linear(self.d_model, self.n_classes).to(self.device)

    def _make_loaders(self, X, y, Xv=None, yv=None):
        Xtr = torch.from_numpy(X).float()
        ytr = torch.from_numpy(y).long()
        tr = DataLoader(TensorDataset(Xtr, ytr), batch_size=self.bs, shuffle=True)
        if Xv is not None and yv is not None:
            Xva = torch.from_numpy(Xv).float()
            yva = torch.from_numpy(yv).long()
            va = DataLoader(TensorDataset(Xva, yva), batch_size=self.bs)
        else:
            va = None
        return tr, va

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        train_loader, val_loader = self._make_loaders(X_train, y_train, X_val, y_val)

        opt = torch.optim.AdamW(list(self.backbone.parameters())+list(self.head.parameters()),
                                lr=self.lr, weight_decay=self.weight_decay)
        ce = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
        # scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
        scaler = amp.GradScaler(device="cuda", enabled=self.amp)

        best_metric = float("inf")
        wait = 0
        best_state = None

        for _ in range(self.epochs):
            self.backbone.train(); self.head.train()
            run = 0.0
            for xb, yb in train_loader:
                xb = xb.to(self.device); yb = yb.to(self.device)
                opt.zero_grad(set_to_none=True)
                with amp.autocast(device_type="cuda", enabled=self.amp):
                    z = self.backbone(xb)
                    logits = self.head(z)
                    loss = ce(logits, yb)
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
                run += float(loss.item())
            train_loss = run / max(1, len(train_loader))

            # simple early stopping on val loss if provided else train loss
            cur = self._val_loss(val_loader, ce) if val_loader is not None else train_loss
            if cur < best_metric:
                best_metric = cur
                wait = 0
                best_state = {
                    "backbone": {k: v.detach().cpu().clone() for k, v in self.backbone.state_dict().items()},
                    "head": {k: v.detach().cpu().clone() for k, v in self.head.state_dict().items()},
                }
            else:
                wait += 1
                if wait >= self.pat:
                    break

        if best_state:
            self.backbone.load_state_dict(best_state["backbone"])
            self.head.load_state_dict(best_state["head"])

    @torch.no_grad()
    def _val_loss(self, val_loader, ce):
        if val_loader is None: return float("inf")
        self.backbone.eval(); self.head.eval()
        run = 0.0; n = 0
        for xb, yb in val_loader:
            xb = xb.to(self.device); yb = yb.to(self.device)
            with amp.autocast(device_type="cuda", enabled=False):
                z = self.backbone(xb)
                logits = self.head(z)
                loss = ce(logits, yb)
            run += float(loss.item()); n += 1
        return run / max(1, n)

    @torch.no_grad()
    def predict_proba(self, X):
        self.backbone.eval(); self.head.eval()
        Xt = torch.from_numpy(X).float().to(self.device)
        probs_all = []
        for i in range(0, Xt.size(0), 4096):
            z = self.backbone(Xt[i:i+4096])
            logits = self.head(z)
            probs = torch.softmax(logits, dim=1)
            probs_all.append(probs.detach().cpu())
        return torch.cat(probs_all, dim=0).numpy()

    @torch.no_grad()
    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)
