import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl

class LinearProbeCallback(pl.Callback):
    """Periodically measure representation quality with a linear probe.

    Every ``every_n_epochs`` completed training epochs (and optionally once
    before training starts) the callback does the following on global rank 0:

      - puts the whole LightningModule in eval mode and extracts features from
        ``pl_module.backbone`` under ``torch.no_grad`` (the backbone receives
        no gradients),
      - caches L2-normalised CLS / flattened features for a bounded number of
        batches from the probe train loader and the probe val/test loader,
      - trains a fresh ``nn.Linear`` head on the cached features with AdamW,
      - logs ``probe/{train,val}_{acc,loss}`` via ``pl_module.log`` (so the
        values reach whatever loggers the Trainer is configured with).

    All other ranks wait at a barrier while rank 0 runs the probe.

    Args:
        every_n_epochs: Probe after every N-th completed epoch. ``0``/``None``
            disables periodic probing; the pre-training probe is unaffected.
        max_epochs: Number of epochs used to train the linear head per probe.
        lr: AdamW learning rate for the linear head.
        weight_decay: AdamW weight decay for the linear head.
        max_train_batches: Max batches drawn from the probe train loader
            (``None`` = no limit).
        max_val_batches: Max batches drawn from the probe eval loader
            (``None`` = no limit).
        batch_size: Batch size used when iterating over the cached features.
        enabled: Master switch for the callback.
        run_before_training: Also probe once in ``on_train_start``.
    """

    def __init__(self, every_n_epochs=100, max_epochs=5, lr=0.1, weight_decay=0.0,
                 max_train_batches=200, max_val_batches=50, batch_size=256,
                 enabled=True, run_before_training=True):
        super().__init__()
        self.every_n_epochs = every_n_epochs
        self.max_epochs = max_epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_train_batches = max_train_batches
        self.max_val_batches = max_val_batches
        self.batch_size = batch_size
        self.enabled = enabled
        self.run_before_training = run_before_training

    @staticmethod
    def _cls_features(backbone, images):
        """Return one feature vector per image from a ViT- or CNN-style backbone.

        - If the backbone exposes a ``vit`` attribute (e.g. a
          MaskedVisionTransformerTIMM wrapper), use ``vit.forward_features``.
          For a 3-D token output, token 0 is returned. This presumes the model
          prepends a CLS token; TODO confirm for models without one.
        - Otherwise call ``backbone.forward_features`` when present, else
          ``backbone(images)``, and flatten every dim after the batch dim.
          NOTE(review): an unpooled (B, C, H, W) map is flattened to C*H*W
          features rather than pooled.
        """
        vit = getattr(backbone, "vit", None)
        if vit is not None:
            out = vit.forward_features(images)
            if out.dim() == 3:
                return out[:, 0]
            return out

        if hasattr(backbone, "forward_features"):
            out = backbone.forward_features(images)
        else:
            out = backbone(images)

        # Flatten spatial dims if present
        if out.dim() > 2:
            out = torch.flatten(out, 1)
        return out

    # ---------- run once before training ----------
    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Probe once before the first training step of this ``fit`` call."""
        if not (self.enabled and self.run_before_training):
            return
        # On a fresh run this is epoch 0 (before any optimizer step). When
        # resuming from a checkpoint it is the restored epoch, so the printed
        # label stays truthful instead of always claiming "epoch 0".
        self._run_probe(trainer, pl_module, epoch_to_log=trainer.current_epoch)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Probe at the end of every ``every_n_epochs``-th completed epoch."""
        # A falsy every_n_epochs (0/None) disables periodic probing; without this
        # guard the modulo below would raise ZeroDivisionError / TypeError.
        if not self.enabled or not self.every_n_epochs:
            return

        epoch_to_log = trainer.current_epoch + 1  # end-of-epoch numbering
        if epoch_to_log % self.every_n_epochs != 0:
            return

        self._run_probe(trainer, pl_module, epoch_to_log=epoch_to_log)

    def _run_probe(self, trainer: pl.Trainer, pl_module: pl.LightningModule, epoch_to_log: int) -> None:
        """Run one probe, doing the work on rank 0 while the other ranks wait.

        Every rank enters exactly two barriers: one before and one after the
        probe. The second barrier is in ``finally``, so it is still reached if
        rank 0 raises. NOTE(review): a long probe can exceed the distributed
        backend's collective timeout; keep ``max_*_batches`` modest.
        """
        dm = trainer.datamodule
        if dm is None:
            return

        strategy = trainer.strategy
        if strategy is not None:
            strategy.barrier()
        try:
            if trainer.is_global_zero:
                self._probe_rank_zero(trainer, pl_module, dm, epoch_to_log)
        finally:
            # Signal completion to the other ranks (pairs with the barrier above).
            if strategy is not None:
                strategy.barrier()

    def _probe_rank_zero(self, trainer: pl.Trainer, pl_module: pl.LightningModule, dm, epoch_to_log: int) -> None:
        """Extract features, fit the linear head and log metrics (rank 0 only)."""
        device = pl_module.device

        # Snapshot every submodule's mode. A blanket ``pl_module.train()``
        # afterwards would also switch submodules that were deliberately in
        # eval mode into train mode.
        prev_modes = [(module, module.training) for module in pl_module.modules()]
        pl_module.eval()  # stable features (e.g. no dropout / batch-stat updates)
        try:
            # Fork the torch RNG so the probe's loader seeding, shuffling and head
            # init do not perturb the training run's random stream on this rank.
            # NOTE(review): Python `random` / NumPy state consumed by augmentations
            # running in this process is not isolated.
            with torch.random.fork_rng(devices=self._rng_devices(device)):
                train_loader, val_loader = self._probe_loaders(dm)
                backbone = pl_module.backbone  # used for inference only

                # Extract fresh features EVERY probe run
                Xtr, Ytr = self.extract_features(train_loader, backbone, device, max_batches=self.max_train_batches)
                Xva, Yva = self.extract_features(val_loader, backbone, device, max_batches=self.max_val_batches)

                num_classes = self._num_classes(dm, Ytr, Yva)
                train_acc, train_loss, val_acc, val_loss = self._fit_linear_head(
                    Xtr, Ytr, Xva, Yva, num_classes, device
                )
        finally:
            # Restore the exact previous per-module mode.
            for module, mode in prev_modes:
                module.training = mode

        metrics = {
            "probe/train_acc": train_acc,
            "probe/train_loss": train_loss,
            "probe/val_acc": val_acc,
            "probe/val_loss": val_loss,
        }
        for name, value in metrics.items():
            # Only rank 0 logs these, so never request a cross-rank reduction.
            pl_module.log(name, value, on_step=False, on_epoch=True, sync_dist=False)

        trainer.print(f"[LinearProbe] epoch={epoch_to_log} train_acc={train_acc:.4f} val_acc={val_acc:.4f}")

    @staticmethod
    def _rng_devices(device: torch.device) -> list:
        """CUDA device indices whose RNG state ``fork_rng`` should save/restore."""
        if device.type != "cuda":
            return []
        index = device.index if device.index is not None else torch.cuda.current_device()
        return [index]

    @staticmethod
    def _probe_loaders(dm):
        """Return ``(train_loader, eval_loader)``.

        Prefers ``dm.probe_train_dataloader`` / ``dm.probe_test_dataloader``
        when present, and otherwise falls back to ``dm.train_dataloader`` /
        ``dm.val_dataloader``.
        """
        make_train = getattr(dm, "probe_train_dataloader", None) or dm.train_dataloader
        make_val = getattr(dm, "probe_test_dataloader", None) or dm.val_dataloader
        return make_train(), make_val()

    @staticmethod
    def _num_classes(dm, *label_tensors) -> int:
        """Resolve the output size of the linear head.

        Lookup order:
          1. ``dm.num_classes`` or ``dm.n_classes``;
          2. ``features["label"].num_classes`` on ``dm.ds_train`` /
             ``dm.train_set`` (presumably a Hugging Face ``datasets``
             ClassLabel — verify against the datamodule);
          3. ``max(label) + 1`` over ALL extracted labels. A single batch may
             miss the largest class id, which would make the head too small and
             ``cross_entropy`` fail with an out-of-range target.
        """
        num_classes = getattr(dm, "num_classes", None) or getattr(dm, "n_classes", None)
        if num_classes is None:
            ds = getattr(dm, "ds_train", None) or getattr(dm, "train_set", None)
            features = getattr(ds, "features", None) if ds is not None else None
            if features is not None and "label" in features:
                try:
                    num_classes = int(features["label"].num_classes)
                except (AttributeError, TypeError, ValueError):
                    # Best effort: the label feature may not carry a class count.
                    pass
        if num_classes is None:
            all_labels = torch.cat([y.reshape(-1) for y in label_tensors])
            num_classes = int(all_labels.max().item()) + 1
        return int(num_classes)

    def _fit_linear_head(self, Xtr, Ytr, Xva, Yva, num_classes: int, device):
        """Train a fresh linear classifier on cached features and evaluate it.

        Returns:
            ``(train_acc, train_loss, val_acc, val_loss)``. The train metrics
            are accumulated online during the last training epoch, while the
            weights are still changing. They are NaN when ``max_epochs <= 0``.
        """
        # cross_entropy needs int64 targets, and the head below is float32 even
        # if the backbone produced reduced-precision features.
        train_dl = DataLoader(TensorDataset(Xtr.float(), Ytr.long()),
                              batch_size=self.batch_size, shuffle=True, num_workers=0)
        val_dl = DataLoader(TensorDataset(Xva.float(), Yva.long()),
                            batch_size=self.batch_size, shuffle=False, num_workers=0)

        clf = nn.Linear(Xtr.shape[1], num_classes).to(device)
        opt = torch.optim.AdamW(clf.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        def run_epoch(dl, train):
            """One pass over ``dl``; returns sample-weighted (accuracy, mean loss)."""
            clf.train(train)
            correct, total, loss_sum = 0, 0, 0.0
            for X, y in dl:
                X = X.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                logits = clf(X)
                loss = F.cross_entropy(logits, y)  # mean over batch

                if train:
                    opt.zero_grad(set_to_none=True)
                    loss.backward()
                    opt.step()

                n = y.size(0)
                correct += (logits.argmax(dim=1) == y).sum().item()
                total += n
                loss_sum += loss.item() * n  # sample-weighted
            return correct / max(total, 1), loss_sum / max(total, 1)

        # Defined even when no training epoch runs (max_epochs <= 0).
        train_acc, train_loss = float("nan"), float("nan")
        with torch.enable_grad():
            for _ in range(self.max_epochs):
                train_acc, train_loss = run_epoch(train_dl, train=True)

        # Evaluation needs no autograd graph.
        with torch.no_grad():
            val_acc, val_loss = run_epoch(val_dl, train=False)

        return train_acc, train_loss, val_acc, val_loss

    def extract_features(self, loader, backbone, device, max_batches=999999):
        """Run ``backbone`` over up to ``max_batches`` batches of ``loader``.

        Each batch is unpacked as ``(views, y)``. When ``views`` is a list or
        tuple (multi-view batches), only ``views[0]`` is used. Features come
        from ``_cls_features`` and are L2-normalised along dim 1. The caller is
        responsible for putting ``backbone`` in eval mode.

        Args:
            loader: Iterable yielding ``(views, labels)`` batches.
            backbone: Model passed to ``_cls_features``.
            device: Device the images are moved to before the forward pass.
            max_batches: Maximum number of batches to consume; ``None`` means
                no limit. Negative values are treated as 0.

        Returns:
            ``(features, labels)``, both concatenated on the CPU.

        Raises:
            RuntimeError: if no batch was consumed (empty loader or
                ``max_batches <= 0``).
        """
        limit = None if max_batches is None else max(int(max_batches), 0)
        feats_list, y_list = [], []

        with torch.no_grad():
            # islice stops before fetching (and augmenting) a batch past the limit.
            for views, y in itertools.islice(loader, limit):
                images = views[0] if isinstance(views, (list, tuple)) else views
                feats = self._cls_features(backbone, images.to(device, non_blocking=True))
                feats_list.append(F.normalize(feats, dim=1).cpu())
                y_list.append(torch.as_tensor(y).cpu())

        if not feats_list:
            raise RuntimeError(
                "LinearProbeCallback.extract_features: loader yielded no batches "
                "(empty loader or max_batches <= 0)."
            )
        return torch.cat(feats_list, dim=0), torch.cat(y_list, dim=0)
