import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tabm import TabM

class TabMClassifier(nn.Module):
    """Classifier wrapper around a TabM ensemble with a fit / predict interface.

    The wrapped network returns one prediction per ensemble member, i.e. logits
    of shape ``(batch, k, d_out)``. Multiclass problems (``num_classes > 2``)
    use ``d_out = num_classes`` with cross-entropy; otherwise a single logit
    (``d_out = 1``) is trained with binary cross-entropy, so binary labels are
    expected to be 0/1. The training objective is the mean of the per-member
    losses and predicted probabilities are the mean of the per-member
    probabilities.

    Args:
        params: dict-like hyper-parameters. Recognised keys: ``k`` (ensemble
            size, default 32), ``epochs``, ``batch_size``, ``patience``
            (early-stopping patience in epochs, default 20), ``lr`` (Adam
            learning rate, default 1e-3) and ``eval_batch_size`` (chunk size
            used for validation and prediction, default 4096).
        args: namespace providing ``use_gpu``, ``num_features`` and
            ``num_classes``; optional ``epochs`` / ``batch_size`` attributes
            serve as defaults when ``params`` does not set them.
    """

    def __init__(self, params, args):
        super().__init__()
        self.device = torch.device("cuda" if args.use_gpu else "cpu")
        self.num_features = args.num_features
        self.num_classes = args.num_classes
        self.is_multiclass = self.num_classes > 2
        self.k = int(params.get("k", 32))
        self.epochs = int(params.get("epochs", getattr(args, "epochs", 100)))
        self.batch_size = int(params.get("batch_size", getattr(args, "batch_size", 256)))
        self.patience = int(params.get("patience", 20))
        self.lr = float(params.get("lr", 1e-3))
        # Validation / inference run in chunks of this size so that large inputs
        # never have to pass through the network (or onto the device) at once.
        self.eval_batch_size = max(1, int(params.get("eval_batch_size", 4096)))
        # Binary problems use one logit per member; multiclass one per class.
        d_out = self.num_classes if self.is_multiclass else 1
        self.net = TabM.make(n_num_features=self.num_features, d_out=d_out, k=self.k).to(self.device)

    @staticmethod
    def _as_tensor(data, dtype):
        """Convert an array-like (NumPy array, nested list, tensor) to a tensor of ``dtype``."""
        return torch.as_tensor(data, dtype=dtype)

    @staticmethod
    def _as_labels(labels):
        """Convert labels to a flat ``long`` tensor; ``(N, 1)`` column vectors become ``(N,)``."""
        return torch.as_tensor(labels, dtype=torch.long).reshape(-1)

    def _ensemble_loss(self, logits, targets, reduction="mean"):
        """Loss over every (sample, member) pair of ``(B, k, d_out)`` logits.

        With ``reduction="mean"`` this equals the average of the k per-member
        mean losses, because every member is scored on the same B samples.
        ``reduction="sum"`` returns the plain sum (used for chunked evaluation).
        """
        n_members = logits.shape[1]
        # Row b*k + i of the flattened logits belongs to sample b, so each
        # target is repeated k times consecutively to stay aligned.
        flat_targets = targets.repeat_interleave(n_members)
        if self.is_multiclass:
            flat_logits = logits.reshape(-1, logits.shape[-1])  # (B*k, C)
            return nn.functional.cross_entropy(flat_logits, flat_targets, reduction=reduction)
        flat_logits = logits[..., 0].reshape(-1)  # (B*k,)
        return nn.functional.binary_cross_entropy_with_logits(
            flat_logits, flat_targets.float(), reduction=reduction
        )

    def _batched_forward(self, features):
        """Yield ``(rows, logits)`` for consecutive ``eval_batch_size`` chunks of ``features``.

        ``rows`` is the slice of ``features`` that produced ``logits``. Gradient
        tracking is governed by the caller (the public callers run under
        ``torch.no_grad()``).
        """
        step = self.eval_batch_size
        for start in range(0, features.shape[0], step):
            rows = slice(start, start + step)
            yield rows, self.net(features[rows].to(self.device))

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        """Train the network with Adam and early stopping.

        Early stopping monitors the validation loss when both ``X_val`` and
        ``y_val`` are provided and non-empty; otherwise it monitors the
        (sample-weighted) training loss of each epoch. Training stops after
        ``patience`` epochs without improvement, and the weights from the best
        epoch are restored at the end.
        """
        Xtr = self._as_tensor(X_train, torch.float32).to(self.device)
        ytr = self._as_labels(y_train).to(self.device)
        n_train = ytr.shape[0]
        train_loader = DataLoader(TensorDataset(Xtr, ytr), batch_size=self.batch_size, shuffle=True)

        use_val = X_val is not None and y_val is not None and len(X_val) > 0
        if use_val:
            # Convert once here instead of on every epoch.
            Xv = self._as_tensor(X_val, torch.float32)
            yv = self._as_labels(y_val)

        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        best_loss, wait, best_state = float("inf"), 0, None

        for _ in range(self.epochs):
            self.net.train()
            # Accumulate on-device so there is no host sync after every step.
            running = torch.zeros((), device=self.device)
            for xb, yb in train_loader:
                opt.zero_grad()
                loss = self._ensemble_loss(self.net(xb), yb)
                loss.backward()
                opt.step()
                running += loss.detach() * yb.shape[0]
            epoch_loss = running.item() / max(1, n_train)

            cur = self._val_loss(Xv, yv) if use_val else epoch_loss
            if cur < best_loss:
                best_loss, wait = cur, 0
                best_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in self.net.state_dict().items()
                }
            else:
                wait += 1
                if wait >= self.patience:
                    break

        if best_state is not None:
            self.net.load_state_dict(best_state)

    @torch.no_grad()
    def _val_loss(self, X_val, y_val):
        """Return the mean per-member loss on ``(X_val, y_val)``, evaluated in chunks.

        Returns ``nan`` for an empty set (there is nothing to average).
        """
        self.net.eval()
        Xv = self._as_tensor(X_val, torch.float32)
        yv = self._as_labels(y_val)
        total, count = 0.0, 0
        for rows, logits in self._batched_forward(Xv):
            targets = yv[rows].to(self.device)
            total += self._ensemble_loss(logits, targets, reduction="sum").item()
            count += logits.shape[0] * logits.shape[1]
        # Sum over all (sample, member) pairs / their count == mean of per-member means.
        return total / count if count else float("nan")

    @torch.no_grad()
    def predict_proba(self, X):
        """Return class probabilities averaged over the ensemble members.

        Returns a NumPy array of shape ``(N, num_classes)`` for multiclass
        problems and ``(N, 2)`` with columns ``[P(y=0), P(y=1)]`` otherwise.
        """
        self.net.eval()
        Xt = self._as_tensor(X, torch.float32)
        parts = []
        for _, logits in self._batched_forward(Xt):
            if self.is_multiclass:
                probs = logits.softmax(dim=-1).mean(dim=1)           # (B, C)
            else:
                p1 = torch.sigmoid(logits).mean(dim=1).squeeze(-1)   # (B,)
                probs = torch.stack([1.0 - p1, p1], dim=1)           # (B, 2)
            parts.append(probs.cpu())
        if not parts:
            n_cols = self.num_classes if self.is_multiclass else 2
            return torch.empty((0, n_cols)).numpy()
        return torch.cat(parts).numpy()

    @torch.no_grad()
    def predict(self, X):
        """Return the index of the most probable class for each row of ``X``."""
        probs = self.predict_proba(X)
        return probs.argmax(axis=1)
