import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def l2_normalize(x, eps=1e-8, dim=1):
    """Scale ``x`` to (approximately) unit L2 norm along ``dim``.

    Args:
        x: Tensor to normalize.
        eps: Constant added to the norm so all-zero slices divide by ``eps``
            instead of zero (they stay zero rather than becoming NaN).
        dim: Axis along which the norm is taken. Defaults to 1, i.e. each row
            of a ``(batch, features)`` tensor is normalized independently.

    Returns:
        A tensor of the same shape as ``x``.
    """
    return x / (x.norm(dim=dim, keepdim=True) + eps)

class _Backbone(nn.Module):
    """MLP encoder: two (Linear -> ReLU -> Dropout) stages followed by a Linear projection."""

    def __init__(self, d_in, d_hid=512, d_emb=256, dropout=0.1):
        super().__init__()
        stages = []
        # Two hidden stages; layer order is kept so state-dict keys stay net.0 / net.3 / net.6.
        for fan_in, fan_out in ((d_in, d_hid), (d_hid, d_hid)):
            stages.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        # Final projection to the embedding width, with no activation.
        stages.append(nn.Linear(d_hid, d_emb))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``x`` through the stack and return the (un-normalized) embedding."""
        return self.net(x)

class ModernNCAClassifier(nn.Module):
    """
    ModernNCA-style baseline:
      - Train backbone with in-batch SoftNCA loss under L2 distance (equiv to cosine after norm)
      - Store train embeddings
      - Inference via label-weighted kNN on embeddings with temperature
    Scikit API: fit(X,y[,Xv,yv]), predict_proba(X), predict(X)

    Constructor arguments:
      params: mapping read with ``.get``; recognised keys are hidden_dim, emb_dim,
              dropout, epochs, batch_size, learning_rate, patience, tau, k, knn_batch.
      args:   object exposing ``use_gpu``, ``num_features`` and ``num_classes``;
              its ``epochs`` / ``batch_size`` attributes, when present, are the
              fallbacks for the corresponding missing ``params`` keys.
    """
    def __init__(self, params, args):
        super().__init__()
        self.device = torch.device("cuda" if args.use_gpu else "cpu")
        self.n_features = args.num_features
        self.n_classes  = args.num_classes

        # hparams
        self.d_hid   = int(params.get("hidden_dim", 512))
        self.d_emb   = int(params.get("emb_dim", 256))
        self.dropout = float(params.get("dropout", 0.1))
        self.epochs  = int(params.get("epochs", getattr(args, "epochs", 100)))
        self.bs      = int(params.get("batch_size", getattr(args, "batch_size", 256)))
        self.lr      = float(params.get("learning_rate", 1e-3))
        self.pat     = int(params.get("patience", 15))
        self.tau     = float(params.get("tau", 0.07))          # temperature for retrieval
        self.k       = int(params.get("k", 64))
        self.knn_bs  = int(params.get("knn_batch", 4096))      # query batch at inference

        self.backbone = _Backbone(self.n_features, self.d_hid, self.d_emb, self.dropout).to(self.device)

        # memory
        self.train_emb = None            # (N,d)
        self.y_train   = None            # (N,)

    def _check_fitted(self):
        """Raise ``RuntimeError`` if the kNN memory has not been filled by ``fit``."""
        if self.train_emb is None or self.y_train is None:
            raise RuntimeError("ModernNCAClassifier is not fitted yet; call fit() before predicting.")

    def _softnca_loss(self, z, y):
        """
        z: (B,d) L2-normalized
        y: (B,)
        For each anchor i, positives are j with y_j == y_i, j!=i.
        p(j|i) ∝ exp( -||z_i - z_j||^2 / tau )
        Loss = - mean_i log sum_{j in P(i)} p(j|i)

        Anchors without any positive in the batch are left out of the mean.
        If no anchor has a positive, a zero scalar that requires grad is returned
        so the caller can still call ``.backward()`` on it.
        """
        B = z.size(0)
        self_mask = torch.eye(B, dtype=torch.bool, device=z.device)
        with torch.no_grad():
            pos_mask = (y.unsqueeze(1) == y.unsqueeze(0)) & ~self_mask
        # pairwise squared L2
        sq_norm = (z ** 2).sum(1, keepdim=True)
        d2 = sq_norm + sq_norm.T - 2.0 * (z @ z.T)
        logits = -d2 / max(self.tau, 1e-6)                               # (B,B)
        # mask self so an anchor can never retrieve itself
        logits = logits.masked_fill(self_mask, float("-inf"))
        # log-prob over all candidates
        logp_all = logits.log_softmax(dim=1)
        # log-sum over positives
        logp_pos = torch.logsumexp(logp_all.masked_fill(~pos_mask, float("-inf")), dim=1)
        # handle classes with single example in batch: no positives -> ignore
        valid = pos_mask.any(dim=1)
        if valid.any():
            return -(logp_pos[valid].mean())
        # fallback to zero if no positives in batch
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    @torch.no_grad()
    def _embed(self, X):
        """
        Embed a numpy array with the backbone in eval mode, ``knn_bs`` rows at a time.
        Returns L2-normalized embeddings (N,d) on ``self.device``; an input with
        zero rows yields an empty (0, d_emb) tensor.
        """
        self.backbone.eval()
        Xt = torch.from_numpy(X).float().to(self.device)
        n_rows = Xt.size(0)
        if n_rows == 0:
            # torch.cat rejects an empty list, so build the empty result directly
            return Xt.new_zeros((0, self.d_emb))
        out = [
            l2_normalize(self.backbone(Xt[i:i + self.knn_bs]))
            for i in range(0, n_rows, self.knn_bs)
        ]
        return torch.cat(out, dim=0)  # (N,d)

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        """
        Train the backbone with the SoftNCA loss, then store the training
        embeddings and labels as the kNN memory.

        X_train / y_train are numpy arrays (converted with ``torch.from_numpy``).
        X_val / y_val are accepted for API compatibility but are not used:
        early stopping monitors the mean training loss per epoch, with patience
        ``self.pat``, and the best-loss weights are restored afterwards.

        Returns ``self``.
        """
        Xtr = torch.from_numpy(X_train).float().to(self.device)
        ytr = torch.from_numpy(y_train).long().to(self.device)
        dataset = TensorDataset(Xtr, ytr)
        # Drop the ragged last batch only when at least one full batch remains;
        # with fewer samples than batch_size, drop_last=True would yield zero
        # batches and the backbone would never be updated.
        drop_last = len(dataset) > self.bs
        train_loader = DataLoader(dataset, batch_size=self.bs, shuffle=True, drop_last=drop_last)

        opt = torch.optim.Adam(self.backbone.parameters(), lr=self.lr)
        best, wait, best_state = float("inf"), 0, None

        for _ in range(self.epochs):
            self.backbone.train()
            run = 0.0
            for xb, yb in train_loader:
                opt.zero_grad()
                z = l2_normalize(self.backbone(xb))
                loss = self._softnca_loss(z, yb)
                loss.backward()
                opt.step()
                run += float(loss.item())

            cur = run / max(1, len(train_loader))
            if cur < best:
                best, wait = cur, 0
                # snapshot on CPU so the best weights do not hold extra device memory
                best_state = {k: v.detach().cpu().clone() for k, v in self.backbone.state_dict().items()}
            else:
                wait += 1
                if wait >= self.pat:
                    break

        if best_state is not None:
            self.backbone.load_state_dict(best_state)

        # _embed already runs under no_grad and switches the backbone to eval mode
        self.train_emb = self._embed(X_train)                   # (N,d)
        self.y_train   = ytr                                    # (N,)
        return self

    @torch.no_grad()
    def _knn_probs(self, Q):
        """
        Q: (B,d) normalized, on the same device as the stored train embeddings
        Returns class probabilities (B,C) via kNN with temperature tau.
        Raises RuntimeError if called before ``fit``.
        """
        self._check_fitted()
        Kmat = self.train_emb                                     # (N,d)
        d2 = (
            (Q**2).sum(1, keepdim=True)
            + (Kmat**2).sum(1, keepdim=True).T
            - 2.0 * (Q @ Kmat.T)
        )                                                         # (B,N)
        k = min(self.k, Kmat.size(0))
        dist, idx = torch.topk(d2, k=k, dim=1, largest=False)     # (B,k)
        w = torch.softmax(-dist / max(self.tau, 1e-6), dim=1)     # (B,k)
        yk = self.y_train[idx]                                    # (B,k)
        B = Q.size(0)
        probs = torch.zeros(B, self.n_classes, device=Q.device)
        # accumulate the neighbour weights per class label
        for c in range(self.n_classes):
            probs[:, c] = (w * (yk == c).float()).sum(dim=1)
        # renormalize for safety
        probs = probs / (probs.sum(dim=1, keepdim=True) + 1e-12)
        return probs

    @torch.no_grad()
    def predict_proba(self, X):
        """
        Return class probabilities as a numpy array of shape (N, n_classes).

        X is a numpy array; it is embedded and scored ``knn_bs`` rows at a time.
        An input with zero rows returns an empty (0, n_classes) array.
        Raises RuntimeError if called before ``fit``.
        """
        self._check_fitted()
        self.backbone.eval()
        Xt = torch.from_numpy(X).float().to(self.device)
        n_rows = Xt.size(0)
        if n_rows == 0:
            # torch.cat rejects an empty list, so answer the empty query directly
            return torch.zeros(0, self.n_classes).numpy()
        outs = []
        for i in range(0, n_rows, self.knn_bs):
            z = l2_normalize(self.backbone(Xt[i:i + self.knn_bs]))
            probs = self._knn_probs(z)
            outs.append(probs.detach().cpu())
        return torch.cat(outs, dim=0).numpy()

    @torch.no_grad()
    def predict(self, X):
        """Return the index of the most probable class for each row of X."""
        return self.predict_proba(X).argmax(axis=1)
