from __future__ import annotations

from abc import ABC, abstractmethod
import copy
import hashlib
import os
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset

from .helpers import CGInfluenceModule


# -----------------------
# helpers
# -----------------------




class _IndexedDataset(Dataset):
    """Wrap a dataset so each item becomes (x, y, idx).

    The returned idx is the index within this dataset (i.e., 0..len(self)-1).
    This matches fold indices produced by make_folds(len(dataset), ...).
    """

    def __init__(self, ds: Dataset):
        self.ds = ds

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, idx: int):
        item = self.ds[idx]
        if isinstance(item, (tuple, list)) and len(item) >= 2:
            x, y = item[0], item[1]
        else:
            raise ValueError("Dataset item must be a (x,y) tuple/list for PBRF.")
        return x, y, idx


def _loss_per_example(loss_fn: nn.Module, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Compute per-example loss vector of shape (B,).

    Tries to temporarily set loss_fn.reduction='none'. Falls back to calling with
    reduction='none' (functional-style), then to a slow per-example loop.
    """

    if hasattr(loss_fn, "reduction"):
        prev = getattr(loss_fn, "reduction")
        try:
            setattr(loss_fn, "reduction", "none")
            loss = loss_fn(outputs, targets)
        finally:
            setattr(loss_fn, "reduction", prev)
    else:
        try:
            loss = loss_fn(outputs, targets, reduction="none")  # type: ignore[misc]
        except TypeError:
            # Last resort: loop (slow, but keeps baseline usable)
            losses: List[torch.Tensor] = []
            for i in range(int(outputs.shape[0])):
                li = loss_fn(outputs[i : i + 1], targets[i : i + 1]).reshape(())
                losses.append(li)
            loss = torch.stack(losses, dim=0)

    if loss.ndim == 0:
        # Unexpected; coerce to (B,) by repeating
        return loss.expand(int(outputs.shape[0]))
    if loss.shape[0] != outputs.shape[0]:
        # E.g., MSE reduction='none' returns (B, ...) – reduce over non-batch dims.
        loss = loss.view(int(outputs.shape[0]), -1).mean(dim=1)
    elif loss.ndim > 1:
        loss = loss.view(int(outputs.shape[0]), -1).mean(dim=1)
    return loss


def _mean_loss_over_loader(model: nn.Module, loader: DataLoader, loss_fn: nn.Module, device: torch.device) -> float:
    """Return the average per-example loss of ``model`` over all of ``loader``.

    Batches may be (x, y) or (x, y, idx); any extra fields are ignored.
    Side effects: switches ``model`` to eval mode and moves it to ``device``.
    Gradients are disabled for the whole pass.

    Raises:
        ValueError: if a batch is not a tuple/list with at least two fields.
        RuntimeError: if the loader yields no examples.
    """
    model.eval()
    model.to(device)
    loss_sum, count = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            if not (isinstance(batch, (tuple, list)) and len(batch) >= 2):
                raise ValueError("Loader must yield (x,y) or (x,y,idx) batches.")
            inputs = batch[0].to(device, non_blocking=True)
            targets = batch[1].to(device, non_blocking=True)
            per_example = _loss_per_example(loss_fn, model(inputs), targets)
            loss_sum += float(per_example.sum().item())
            count += int(per_example.shape[0])
    if not count:
        raise RuntimeError("Empty loader in _mean_loss_over_loader")
    return loss_sum / float(count)


def _clone_loader_with_index(
    dataset: Dataset, like: DataLoader, *, shuffle: bool, drop_last: bool = False
) -> DataLoader:
    """Build a DataLoader over ``dataset`` that yields (x, y, idx) batches.

    Settings copied from ``like``:
    - batch size (128 when ``like`` has none);
    - ``num_workers``, ``pin_memory`` and ``persistent_workers``.

    ``drop_last`` is deliberately NOT inherited. These loaders drive
    per-example scoring passes. Inheriting a training loader's
    ``drop_last=True`` would skip the final partial batch, leaving those
    examples unscored (they keep the 0.0 initial value in
    ``_per_example_grad_dot``). Pass ``drop_last=True`` explicitly to opt in.
    """
    return DataLoader(
        _IndexedDataset(dataset),
        batch_size=getattr(like, "batch_size", 128) or 128,
        shuffle=shuffle,
        num_workers=int(getattr(like, "num_workers", 0) or 0),
        pin_memory=bool(getattr(like, "pin_memory", False)),
        persistent_workers=bool(getattr(like, "persistent_workers", False)),
        drop_last=bool(drop_last),
    )




def _mean_grad_over_loader(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device,
) -> torch.Tensor:
    """Return the flattened per-example mean gradient of ``loss_fn`` over ``loader``.

    Each batch gradient is multiplied by the batch size before accumulating,
    which assumes ``loss_fn`` averages over the batch. The total is then
    divided by the number of examples seen. Batches may be (x, y) or
    (x, y, idx).

    Side effects: puts ``model`` in eval mode, moves it to ``device`` and
    leaves its gradients cleared (set to None).

    Raises:
        ValueError: if a batch is not a tuple/list with at least two fields.
        RuntimeError: if the loader yields no batches.
    """
    model.eval()
    model.to(device)
    grad_total: Optional[torch.Tensor] = None
    seen = 0
    for batch in loader:
        if not (isinstance(batch, (tuple, list)) and len(batch) >= 2):
            raise ValueError("Loader must yield (x,y) or (x,y,idx) batches.")
        inputs = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        batch_n = int(inputs.shape[0])
        model.zero_grad(set_to_none=True)
        loss_fn(model(inputs), targets).backward()
        # Undo the batch-mean so batches of different sizes are weighted per example.
        weighted = _flatten_grads(model) * float(batch_n)
        grad_total = weighted if grad_total is None else grad_total + weighted
        seen += batch_n
    if grad_total is None or seen == 0:
        raise RuntimeError("Empty loader in _mean_grad_over_loader.")
    model.zero_grad(set_to_none=True)
    return (grad_total / float(seen)).detach()


def _per_example_grad_dot(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    direction: torch.Tensor,
    device: torch.device,
    selected_mask: Optional[np.ndarray] = None,
    scale: float = 1.0,
    sign: float = 1.0,
) -> np.ndarray:
    """Score each example as ``sign * scale * <flat grad of loss_i, direction>``.

    Runs one forward/backward pass per example (batch slices of length 1).

    Args:
        loader: yields (x, y) or (x, y, idx) batches. When idx is present it
            is the example's slot in the output. Otherwise a running offset
            over iteration order is used.
        direction: flat vector dotted with each flattened gradient. It is
            detached and moved to ``device``.
        selected_mask: optional boolean array indexed by example position.
            Only positions where it is truthy are scored.
        scale, sign: factors applied to every dot product.

    Returns:
        float64 array of length ``len(loader.dataset)``. Unscored entries are
        NaN when ``selected_mask`` is given, and 0.0 otherwise.

    Side effects: puts ``model`` in eval mode, moves it to ``device`` and
    clears its gradients before returning.
    """
    model.eval()
    model.to(device)
    direction = direction.detach().to(device)

    n_total = len(loader.dataset)
    if selected_mask is None:
        out_scores = np.zeros(n_total, dtype=np.float64)
    else:
        # Unscored entries stay NaN so downstream selection/ranking ignores them.
        out_scores = np.full(n_total, np.nan, dtype=np.float64)
    factor = sign * scale

    offset = 0
    for batch in loader:
        if not (isinstance(batch, (tuple, list)) and len(batch) >= 2):
            raise ValueError("Loader must yield (x,y) or (x,y,idx) batches.")
        idx_tensor = batch[2] if len(batch) >= 3 else None
        inputs = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        n_batch = int(inputs.shape[0])
        idx_np = None if idx_tensor is None else idx_tensor.detach().cpu().numpy().astype(np.int64)

        for j in range(n_batch):
            pos = int(offset + j) if idx_np is None else int(idx_np[j])
            if selected_mask is not None and not bool(selected_mask[pos]):
                continue
            model.zero_grad(set_to_none=True)
            loss = loss_fn(model(inputs[j : j + 1]), targets[j : j + 1])
            loss.backward()
            dot = torch.dot(_flatten_grads(model), direction).item()
            out_scores[pos] = float(factor * dot)
        offset += n_batch

    model.zero_grad(set_to_none=True)
    return out_scores


def _per_example_scores(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    direction: torch.Tensor,
    total_train: int,
    device: torch.device,
) -> np.ndarray:
    """Influence-style point scores ``-1/total_train * <grad loss_i, direction>``.

    ``loader.dataset`` is re-wrapped in an unshuffled (x, y, idx) loader so
    every score lands at the example's dataset index, independent of the
    original loader's iteration order. ``total_train`` is clamped to >= 1.
    """
    scoring_loader = _clone_loader_with_index(loader.dataset, loader, shuffle=False)
    inv_n = 1.0 / max(1, total_train)
    return _per_example_grad_dot(
        model=model,
        loader=scoring_loader,
        loss_fn=loss_fn,
        direction=direction,
        device=device,
        selected_mask=None,
        scale=inv_n,
        sign=-1.0,
    )


# -----------------------
# baseline protocol
# -----------------------

class Baseline(ABC):
    """Abstract interface shared by every attribution baseline.

    Subclasses implement prepare() (setup against a trained model and its
    data) and delta_theta_subset(). Point-scoring baselines also override
    score_points() to return one score per training example.
    """

    name: str

    @abstractmethod
    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: CGInfluenceModule,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Set up the baseline before any scores or deltas are requested."""
        ...

    @abstractmethod
    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        """Return a parameter-space delta for training subset ``S`` (semantics defined by subclasses)."""
        ...

    def score_points(self) -> Optional[np.ndarray]:
        """Per-training-point scores, or None if this baseline does not produce them."""
        return None

    def score_folds(self, folds: List[np.ndarray]) -> Optional[np.ndarray]:
        """Optional: aggregate point scores over each fold.

        The default sums the finite point scores at each fold's indices.
        A fold stays NaN if it is empty or all its scores are non-finite.
        Returns None when score_points() returns None.
        """
        point_scores = self.score_points()
        if point_scores is None:
            return None
        fold_totals = np.full(len(folds), np.nan, dtype=np.float64)
        for fold_no, members in enumerate(folds):
            members = np.asarray(members, dtype=np.int64)
            if not members.size:
                continue
            picked = np.asarray(point_scores, dtype=float)[members]
            keep = np.isfinite(picked)
            if keep.any():
                fold_totals[fold_no] = float(picked[keep].sum())
        return fold_totals


# -----------------------
# IF one-solve baseline
# -----------------------

@dataclass
class IFOneSolveBaseline(Baseline):
    """Influence-function point scores from a single inverse-HVP solve.

    prepare() obtains ``ihvp_val = cg_mod.inverse_hvp(g_val)['ihvp']`` once,
    where ``g_val`` is the mean validation gradient. Either value may be
    supplied precomputed via ``cfg['shared']``. score_points() then scores
    each training example as ``-1/n_train * <grad loss_i, ihvp_val>``.
    """

    name: str = "if_one_solve"
    batch_size: int = 32  # NOTE(review): not read by any method of this class.

    def __post_init__(self) -> None:
        # Runtime state filled in by prepare(); kept out of the dataclass fields.
        self.device: Optional[torch.device] = None
        self.train_loader: Optional[DataLoader] = None
        self.val_loader: Optional[DataLoader] = None
        self.loss_fn: Optional[nn.Module] = None
        self.cg_mod: Optional[CGInfluenceModule] = None
        self.ihvp_val: Optional[torch.Tensor] = None
        self.n_train: int = 0

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: CGInfluenceModule,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Store run context and compute (or reuse) the validation inverse-HVP.

        Uses ``cfg['shared']['g_val']`` and ``cfg['shared']['ihvp_val']`` when
        present, so at most one CG solve is performed. Falls back to the
        training loader when ``val_loader`` is None.

        Raises:
            RuntimeError: if ``cg_mod.inverse_hvp`` returns no ``'ihvp'``.
        """
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader if val_loader is not None else train_loader
        self.loss_fn = loss_fn
        self.cg_mod = cg_mod
        self.n_train = len(train_loader.dataset)
        shared = cfg.get("shared", {}) if isinstance(cfg, dict) else {}
        g_val = shared.get("g_val")
        ihvp = shared.get("ihvp_val")
        if g_val is None:
            g_val = _mean_grad_over_loader(model, self.val_loader, loss_fn, device)
        if ihvp is None:
            # Single solve; a shared precomputed ihvp_val skips it entirely.
            ihvp_res = cg_mod.inverse_hvp(g_val)
            ihvp = ihvp_res.get("ihvp")
        if ihvp is None:
            raise RuntimeError("CGInfluenceModule.inverse_hvp did not return 'ihvp'.")
        self.ihvp_val = ihvp.detach()

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        raise NotImplementedError("Subset delta is not supported for IF one-solve; use score_points instead.")

    def score_points(self) -> Optional[np.ndarray]:
        """Return ``-1/n_train * <grad loss_i, ihvp_val>`` for every training example.

        Raises:
            RuntimeError: if prepare() has not been called.
        """
        if self.train_loader is None or self.loss_fn is None or self.cg_mod is None or self.ihvp_val is None:
            raise RuntimeError("IFOneSolveBaseline not prepared before scoring.")
        return _per_example_scores(
            model=self.cg_mod.model,
            loader=self.train_loader,
            loss_fn=self.loss_fn,
            direction=self.ihvp_val,
            total_train=self.n_train,
            device=self.device or next(self.cg_mod.model.parameters()).device,
        )


# -----------------------
# TRAK baseline
# -----------------------

def _theta_fingerprint(model: nn.Module) -> str:
    vec = torch.nn.utils.parameters_to_vector([p.detach() for p in model.parameters()]).detach()
    vec = vec.float().cpu().numpy().tobytes()
    return hashlib.sha256(vec).hexdigest()


def _import_trak_traker():
    """Import TRAKer.

    This tries `from trak import TRAKer`, then `from trak.traker import TRAKer`.
    TRAK is an optional dependency; if missing, we raise ImportError with a
    helpful message (no auto-install side effects).
    """

    try:
        from trak import TRAKer  # type: ignore

        return TRAKer
    except Exception:
        try:
            from trak.traker import TRAKer  # type: ignore

            return TRAKer
        except Exception as exc:
            raise ImportError(
                "TRAK is not installed or could not be imported. Install with `pip install trak`."
            ) from exc


def _load_trak_checkpoint_state_dict(path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint file and return a pure model state_dict.

    TRAK expects a dict-like state_dict (parameter_name -> tensor).
    Our training checkpoints are often wrapper dicts, so we unwrap common keys.
    """

    ckpt = torch.load(os.fspath(path), map_location=torch.device("cpu"))
    if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
        state = ckpt["state_dict"]
    elif isinstance(ckpt, dict) and "model_state_dict" in ckpt and isinstance(ckpt["model_state_dict"], dict):
        state = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and all(isinstance(k, str) for k in ckpt.keys()):
        # Heuristic: already looks like a state_dict.
        state = ckpt
    else:
        raise ValueError(f"Unrecognized checkpoint format for TRAK at: {path}")
    # Narrow type for callers.
    return {str(k): v for k, v in state.items() if torch.is_tensor(v)}


def _select_checkpoints(paths: List[str], max_n: int) -> List[str]:
    paths = [str(p) for p in paths if p]
    if not paths:
        return []
    max_n = int(max_n)
    if max_n <= 0 or len(paths) <= max_n:
        return paths
    # Evenly spaced selection (deterministic)
    idx = np.linspace(0, len(paths) - 1, num=max_n)
    idx = np.unique(np.round(idx).astype(int))
    return [paths[int(i)] for i in idx.tolist()]


def _indexed_like(loader: DataLoader, *, shuffle: bool, drop_last: bool = False) -> DataLoader:
    """Rebuild ``loader`` over its own dataset so batches are (x, y, idx).

    Settings copied from ``loader``:
    - batch size (128 when it has none);
    - ``num_workers``, ``pin_memory`` and ``persistent_workers``.

    ``drop_last`` is NOT inherited. These loaders feed TRAK featurization and
    scoring, which must see every example. A training loader's
    ``drop_last=True`` would otherwise silently skip the final partial batch.
    Pass ``drop_last=True`` explicitly to opt in.
    """
    return DataLoader(
        _IndexedDataset(loader.dataset),
        batch_size=getattr(loader, "batch_size", 128) or 128,
        shuffle=shuffle,
        num_workers=int(getattr(loader, "num_workers", 0) or 0),
        pin_memory=bool(getattr(loader, "pin_memory", False)),
        persistent_workers=bool(getattr(loader, "persistent_workers", False)),
        drop_last=bool(drop_last),
    )


@dataclass
class TRAKBaseline(Baseline):
    """TRAK baseline: amortized point attributions for a fixed target set.

    Produces per-point scores s_i (len = n_train). Fold/group scores can be
    computed by summing s_i within each fold (see Baseline.score_folds).

    Notes:
    - Requires multiple checkpoints.
    - This baseline does NOT implement delta-theta.
    - prepare() featurizes the TRAIN set once per checkpoint. Later calls to
      score_points(val_loader) reuse those features and only re-score targets.
    """

    name: str = "trak"
    checkpoint_paths: List[str] = field(default_factory=list)
    use_val_loader: bool = True

    # Internal state populated by prepare().
    _scores: Optional[np.ndarray] = None
    _prepared: bool = False
    _traker: Any = None
    _ckpts: List[str] = field(default_factory=list)
    # model_id -> sha256 of the weights observed right after load_checkpoint during featurization.
    _fingerprints: Dict[int, str] = field(default_factory=dict)
    _exp_name_base: str = "trak"
    _trak_cfg: Dict[str, Any] = field(default_factory=dict)
    _device: Optional[torch.device] = None
    _n_train: int = 0
    # Counter used to give each override scoring run a unique TRAK exp_name.
    _score_calls: int = 0

    @staticmethod
    def _check_fingerprint(model_id: int, expected: str, actual: str) -> None:
        """Raise RuntimeError when both fingerprints are known and differ.

        An empty string on either side means "unknown" and is not an error.
        """
        if expected and actual and actual != expected:
            raise RuntimeError(
                f"TRAK checkpoint/model_id mismatch for model_id={model_id}: {expected[:12]}.. vs {actual[:12]}.."
            )

    @staticmethod
    def _reduce_scores(scores_raw: Any, n_train: int, target_agg: Any) -> np.ndarray:
        """Collapse TRAK output into one float64 score per training point.

        Accepts a (n_train,) vector, or a 2-D matrix in either orientation
        ((n_train, n_targets) or (n_targets, n_train)). Matrices are reduced
        over targets with NaN-aware ``sum`` or ``mean``, per ``target_agg``.

        Raises:
            RuntimeError: if the shape cannot be matched to ``n_train``.
            ValueError: if ``target_agg`` is not 'sum' or 'mean' (only checked
                for 2-D input).
        """
        scores_arr = np.asarray(scores_raw)

        # TRAK commonly returns a (n_train, n_targets) matrix; reduce over targets.
        if scores_arr.ndim == 2:
            if scores_arr.shape[0] == n_train:
                mat = scores_arr
            elif scores_arr.shape[1] == n_train:
                mat = scores_arr.T
            else:
                raise RuntimeError(
                    f"TRAK scores have wrong shape: {scores_arr.shape}, expected ({n_train},) or ({n_train}, n_targets)"
                )

            agg = str(target_agg).lower()
            if agg == "sum":
                scores = np.nansum(mat, axis=1)
            elif agg == "mean":
                scores = np.nanmean(mat, axis=1)
            else:
                raise ValueError(f"Unknown trak.target_agg={agg!r} (expected 'sum' or 'mean')")

            scores = np.asarray(scores, dtype=np.float64)
        elif scores_arr.ndim == 1:
            scores = np.asarray(scores_arr, dtype=np.float64)
        else:
            raise RuntimeError(
                f"TRAK scores have unsupported ndim={scores_arr.ndim} with shape {scores_arr.shape}"
            )

        if scores.shape != (n_train,):
            raise RuntimeError(f"TRAK scores have wrong shape after reduction: {scores.shape}, expected ({n_train},)")
        return scores

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: Any,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Featurize the training set and compute default target scores.

        Returns immediately if a previous call completed.

        Reads these ``cfg['trak']`` keys:
        - ``save_dir`` (required);
        - ``checkpoint_paths`` (used when the instance has none);
        - ``proj_dim`` (2048), ``task`` ("image_classification"),
          ``num_checkpoints_max`` (5), ``exp_name`` ("trak");
        - ``target_agg`` ('sum' or 'mean').

        The caller's ``model`` is deep-copied, so loading checkpoints never
        mutates it.

        Raises:
            ValueError: missing save_dir or unknown target_agg.
            RuntimeError: no checkpoints, empty train set, checkpoint/model_id
                fingerprint mismatch, or wrongly shaped TRAK scores.
            ImportError: TRAK is not importable.
        """
        if self._prepared and self._scores is not None and self._traker is not None:
            return

        trak_cfg = dict(cfg.get("trak", {})) if isinstance(cfg, dict) else {}
        self._trak_cfg = dict(trak_cfg)

        # Target set: usually clean validation.
        target_loader = val_loader if (self.use_val_loader and val_loader is not None) else train_loader

        # Checkpoint discovery.
        if not self.checkpoint_paths:
            self.checkpoint_paths = list(trak_cfg.get("checkpoint_paths", []))
        ckpt_paths = list(self.checkpoint_paths)

        proj_dim = int(trak_cfg.get("proj_dim", 2048))
        task = str(trak_cfg.get("task", "image_classification"))
        save_dir = trak_cfg.get("save_dir", None)
        if save_dir is None:
            raise ValueError("TRAK requires cfg['trak']['save_dir'] to be set (a per-run folder).")
        save_dir = str(save_dir)
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        num_ckpt_max = int(trak_cfg.get("num_checkpoints_max", 5))
        ckpts = _select_checkpoints([os.fspath(p) for p in ckpt_paths], num_ckpt_max)
        if len(ckpts) == 0:
            raise RuntimeError("TRAK enabled but no checkpoint paths were provided.")

        n_train = int(len(train_loader.dataset))
        if n_train <= 0:
            raise RuntimeError("Empty train_loader for TRAK")
        self._n_train = int(n_train)

        # Always score/featurize with deterministic indices.
        train_idx_loader = _indexed_like(train_loader, shuffle=False)
        target_idx_loader = _indexed_like(target_loader, shuffle=False)

        TRAKer = _import_trak_traker()

        # IMPORTANT: do not mutate the caller's model weights by loading checkpoints.
        trak_model = copy.deepcopy(model).eval().to(device)
        traker = TRAKer(
            model=trak_model,
            task=task,
            train_set_size=n_train,
            proj_dim=proj_dim,
            save_dir=save_dir,
        )

        exp_name = str(trak_cfg.get("exp_name", "trak"))
        self._exp_name_base = exp_name

        # Verify a stable model_id <-> checkpoint mapping via a simple fingerprint.
        fingerprints: Dict[int, str] = {}

        # Pass 1: featurize the full training set under every checkpoint.
        for model_id, ckpt in enumerate(ckpts):
            sd = _load_trak_checkpoint_state_dict(os.fspath(ckpt))
            traker.load_checkpoint(sd, model_id=model_id)
            fingerprints[model_id] = _theta_fingerprint(trak_model)

            for xb, yb, idxb in train_idx_loader:
                traker.featurize(
                    batch=(xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)),
                    inds=idxb.detach().cpu().numpy().astype(np.int64),
                )

        traker.finalize_features()

        # Pass 2: score the target set under every checkpoint.
        num_targets = int(len(target_idx_loader.dataset))
        for model_id, ckpt in enumerate(ckpts):
            sd = _load_trak_checkpoint_state_dict(os.fspath(ckpt))
            traker.load_checkpoint(sd, model_id=model_id)
            self._check_fingerprint(model_id, fingerprints.get(model_id, ""), _theta_fingerprint(trak_model))

            traker.start_scoring_checkpoint(
                exp_name=exp_name,
                checkpoint=sd,
                model_id=model_id,
                num_targets=num_targets,
            )

            for xb, yb, idxb in target_idx_loader:
                traker.score(
                    batch=(xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)),
                    inds=idxb.detach().cpu().numpy().astype(np.int64),
                )

        scores = self._reduce_scores(
            traker.finalize_scores(exp_name=exp_name),
            n_train,
            trak_cfg.get("target_agg", "sum"),
        )

        self._scores = scores
        self._traker = traker
        self._ckpts = list(ckpts)
        self._fingerprints = dict(fingerprints)
        self._device = device
        self._prepared = True

    def _score_points_impl(self, val_loader: DataLoader, *, exp_name: str) -> np.ndarray:
        """Score a new target set against the train features cached by prepare().

        Each checkpoint is reloaded under its original model_id and its weight
        fingerprint re-checked (best effort). ``val_loader`` is then scored
        under ``exp_name``, and the result reduced the same way as in
        prepare().

        Raises:
            RuntimeError: if not prepared, device is missing, a fingerprint
                mismatch is detected, or scores have the wrong shape.
        """
        if self._traker is None or not self._prepared:
            raise RuntimeError("TRAKBaseline not prepared.")
        dev = self._device
        if dev is None:
            raise RuntimeError("TRAKBaseline missing device")
        traker = self._traker

        target_idx_loader = _indexed_like(val_loader, shuffle=False)
        num_targets = int(len(target_idx_loader.dataset))

        for model_id, ckpt in enumerate(self._ckpts):
            sd = _load_trak_checkpoint_state_dict(os.fspath(ckpt))
            traker.load_checkpoint(sd, model_id=model_id)

            # Best effort: if the fingerprint cannot be computed, skip the check.
            # The mismatch check itself sits outside the try, so a detected
            # mismatch always propagates instead of being swallowed.
            try:
                fp_now = _theta_fingerprint(traker.model) if hasattr(traker, "model") else ""
            except Exception:
                fp_now = ""
            self._check_fingerprint(model_id, self._fingerprints.get(model_id, ""), fp_now)

            traker.start_scoring_checkpoint(
                exp_name=exp_name,
                checkpoint=sd,
                model_id=model_id,
                num_targets=num_targets,
            )

            for xb, yb, idxb in target_idx_loader:
                traker.score(
                    batch=(xb.to(dev, non_blocking=True), yb.to(dev, non_blocking=True)),
                    inds=idxb.detach().cpu().numpy().astype(np.int64),
                )

        return self._reduce_scores(
            traker.finalize_scores(exp_name=exp_name),
            int(self._n_train),
            dict(self._trak_cfg or {}).get("target_agg", "sum"),
        )

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        raise NotImplementedError("TRAK does not implement delta-theta; use score_points/score_folds.")

    def score_points(self, val_loader: Optional[DataLoader] = None) -> Optional[np.ndarray]:
        """Return per-train-point scores.

        If val_loader is provided, recompute scores for that target set while
        reusing the TRAIN features cached by prepare(). Each override run gets
        a fresh TRAK experiment name.
        """
        if not self._prepared or self._traker is None:
            raise RuntimeError("TRAKBaseline not prepared.")
        if val_loader is None:
            if self._scores is None:
                raise RuntimeError("TRAKBaseline missing default scores.")
            return self._scores

        self._score_calls += 1
        exp_name = f"{self._exp_name_base}_override_{self._score_calls}"
        return self._score_points_impl(val_loader, exp_name=exp_name)


# -----------------------------
# Gradient flattening helper + simple point-score baselines
# -----------------------------

def _flatten_grads(model: nn.Module) -> torch.Tensor:
    grads = []
    for p in model.parameters():
        if p.grad is None:
            grads.append(torch.zeros_like(p).view(-1))
        else:
            grads.append(p.grad.detach().view(-1))
    return torch.cat(grads, dim=0)


@dataclass
class RandomBaseline(Baseline):
    """Control baseline: i.i.d. standard-normal scores for each training point.

    The RNG seed is ``cfg['seed']`` when present, otherwise ``self.seed``.
    """

    name: str = "random"
    seed: int = 0
    _scores: Optional[np.ndarray] = None

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: Any,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        n_points = len(train_loader.dataset)
        generator = np.random.default_rng(int(cfg.get("seed", self.seed)))
        self._scores = generator.standard_normal(n_points).astype(np.float64)

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        raise NotImplementedError

    def score_points(self) -> Optional[np.ndarray]:
        """Return the scores drawn in prepare(); raises if prepare() was not called."""
        if self._scores is not None:
            return self._scores
        raise RuntimeError("RandomBaseline not prepared.")


@dataclass
class BestBaseline(Baseline):
    """
    Oracle upper bound for poisoning detection / cleaning:
    score=1 for known poisoned train indices, else 0.
    Requires runner to pass cfg["oracle_poison_in_train"] = List[int] (local train indices).
    Indices outside [0, len(train set)) are ignored.
    """
    name: str = "best"
    _scores: Optional[np.ndarray] = None

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: Any,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        n_points = len(train_loader.dataset)
        # Indices MUST be local to train_loader.dataset indexing.
        poisoned = np.asarray(list(cfg.get("oracle_poison_in_train", [])), dtype=np.int64)
        oracle = np.zeros(n_points, dtype=np.float64)
        if poisoned.size:
            in_range = (poisoned >= 0) & (poisoned < n_points)
            oracle[poisoned[in_range]] = 1.0
        self._scores = oracle

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        raise NotImplementedError

    def score_points(self) -> Optional[np.ndarray]:
        """Return the 0/1 oracle scores; raises if prepare() was not called."""
        if self._scores is not None:
            return self._scores
        raise RuntimeError("BestBaseline not prepared.")



@dataclass
class IFCPointBaseline(Baseline):
    """Point scores from a compressed-IF cache: each point inherits its cluster's score.

    Cache selection: ``caches[cache_name]`` when present, otherwise the first
    entry of ``caches``. The cache must provide:
    - ``labels``: one cluster id per training point, in train_loader.dataset order;
    - ``v_c``: one tensor per cluster.

    Cluster c scores ``<v_c[c], g_val> / max(1, N)``, where ``N`` is
    ``cache['N']`` (default: train-set size) and ``g_val`` is the mean
    validation gradient (or ``cfg['shared']['g_val']``).
    """

    name: str = "ifc_point"
    cache_name: Optional[str] = None

    _scores: Optional[np.ndarray] = None
    _cache: Optional[Dict[str, Any]] = None

    def _pick_cache(self, caches: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        if self.cache_name and self.cache_name in caches:
            return caches[self.cache_name]
        if not caches:
            raise ValueError("IFCPointBaseline needs at least one cache in `caches`.")
        return next(iter(caches.values()))

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: Any,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Compute per-cluster scores and broadcast them to training points.

        Raises:
            ValueError: missing labels/v_c, a label count different from the
                train-set size, or labels outside ``[0, len(v_c))``.
        """
        cache = self._pick_cache(caches)
        self._cache = cache
        labels = np.asarray(cache.get("labels", []), dtype=np.int64)
        v_c: List[torch.Tensor] = cache.get("v_c", [])
        if labels.size == 0 or not v_c:
            raise ValueError("IFCPointBaseline cache missing labels or v_c.")
        n_train = len(train_loader.dataset)
        if labels.shape[0] != n_train:
            raise ValueError(f"Cache labels length {labels.shape[0]} != train set {n_train}")
        n_clusters = len(v_c)
        # Negative labels (e.g. a -1 "unassigned" marker) would silently wrap
        # around under numpy indexing and pick up another cluster's score.
        # Validate before the costly validation-gradient pass.
        out_of_range = (labels < 0) | (labels >= n_clusters)
        if np.any(out_of_range):
            raise ValueError(
                f"Cache labels must lie in [0, {n_clusters}); "
                f"found {int(np.count_nonzero(out_of_range))} out-of-range label(s)."
            )
        shared = cfg.get("shared", {}) if isinstance(cfg, dict) else {}
        N = int(cache.get("N", n_train))
        val_ld = val_loader if val_loader is not None else train_loader
        g_val = shared.get("g_val")
        if g_val is None:
            g_val = _mean_grad_over_loader(model, val_ld, loss_fn, device)
        g_val = g_val.to(device)

        cluster_scores = []
        for v in v_c:
            score_c = float(torch.dot(v.to(device), g_val).item()) / float(max(1, N))
            cluster_scores.append(score_c)
        cluster_scores_arr = np.asarray(cluster_scores, dtype=float)
        self._scores = cluster_scores_arr[labels]

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        """Delegate to ``delta_ifc_from_cache`` on the cached clusters.

        The result is computed on the device of ``v_c[0]``.
        """
        if self._cache is None:
            raise RuntimeError("IFCPointBaseline not prepared; missing cache")
        from .if_compressed import delta_ifc_from_cache

        S = np.asarray(S, dtype=np.int64)
        # Prefer moving delta to the same device as the cached v_c tensors.
        v0 = self._cache["v_c"][0]
        dev = v0.device
        out = delta_ifc_from_cache(self._cache, S.tolist(), device=dev)
        delta = out.get("delta_theta")
        if delta is None:
            raise RuntimeError("delta_ifc_from_cache did not return delta_theta")
        return delta.detach()

    def score_points(self) -> Optional[np.ndarray]:
        if self._scores is None:
            raise RuntimeError("IFCPointBaseline not prepared.")
        return self._scores


# -----------------------
# Exact IF point baseline
# -----------------------

@dataclass
class IFBaseline(Baseline):
    """Exact influence-function (IF) baseline backed by a CG inverse-HVP solver.

    ``prepare`` caches ``ihvp_val = H^{-1} g_val``, where ``g_val`` is the mean
    loss gradient over the validation loader (or the train loader when no
    validation loader is given), unless ``cfg["shared"]`` already provides it.
    Subset deltas are computed on demand as ``(|S| / N) * H^{-1} g_S``, with
    ``g_S`` the mean training gradient over ``S`` and ``N`` the training-set
    size.
    """

    name: str = "if"
    # NOTE(review): not read by this class; subset loaders reuse the train
    # loader's batch size. Kept for constructor/config compatibility.
    batch_size: int = 32

    def __post_init__(self) -> None:
        self.device: Optional[torch.device] = None
        self.train_loader: Optional[DataLoader] = None
        self.val_loader: Optional[DataLoader] = None
        self.loss_fn: Optional[nn.Module] = None
        self.cg_mod: Optional[CGInfluenceModule] = None
        self.ihvp_val: Optional[torch.Tensor] = None
        self.n_train: int = 0

    @staticmethod
    def _extract_ihvp(sol: Any) -> Optional[torch.Tensor]:
        """Return ``sol.get("ihvp")`` when ``sol`` has a callable ``get``, else ``None``."""
        getter = getattr(sol, "get", None)
        return getter("ihvp") if callable(getter) else None

    def _scaled_ihvp(self, g_mean: torch.Tensor, subset_size: int, missing_msg: str) -> torch.Tensor:
        """Solve ``H^{-1} g_mean`` with the CG module and scale it by ``subset_size / N``.

        Raises:
            RuntimeError: if not prepared, or (with ``missing_msg``) if the
                solver result carries no ``'ihvp'`` entry.
        """
        if self.cg_mod is None:
            raise RuntimeError("IFBaseline not prepared.")
        ihvp = self._extract_ihvp(self.cg_mod.inverse_hvp(g_mean.detach()))
        if ihvp is None:
            raise RuntimeError(missing_msg)
        N = int(self.n_train)
        return (float(int(subset_size)) / float(max(1, N))) * ihvp.detach()

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: CGInfluenceModule,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Store the training context and cache ``ihvp_val = H^{-1} g_val``.

        ``cfg["shared"]`` may carry precomputed ``"ihvp_val"`` and/or
        ``"g_val"``. When ``"ihvp_val"`` is present, no gradient pass or CG
        solve is performed. ``caches`` and ``folds`` are not used.

        Raises:
            RuntimeError: if the CG solver does not return an ``'ihvp'`` entry.
        """
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader if val_loader is not None else train_loader
        self.loss_fn = loss_fn
        self.cg_mod = cg_mod
        self.n_train = len(train_loader.dataset)
        shared = cfg.get("shared", {}) if isinstance(cfg, dict) else {}
        ihvp = shared.get("ihvp_val")
        if ihvp is None:
            # Only pay for the validation-gradient pass when we actually need to solve.
            g_val = shared.get("g_val")
            if g_val is None:
                g_val = _mean_grad_over_loader(model, self.val_loader, loss_fn, device)
            ihvp = self._extract_ihvp(cg_mod.inverse_hvp(g_val))
        if ihvp is None:
            raise RuntimeError("CGInfluenceModule.inverse_hvp did not return 'ihvp'.")
        self.ihvp_val = ihvp.detach()

    def _mean_train_grad(self, S: np.ndarray) -> torch.Tensor:
        """Return the flattened mean loss gradient over training indices ``S``.

        Puts the CG module's model in ``eval()`` mode on ``self.device`` and
        always clears its ``.grad`` buffers afterwards, even on error.

        Raises:
            RuntimeError: if the subset loader yields no examples.
        """
        ds = self.train_loader.dataset
        ld = DataLoader(
            Subset(ds, S.tolist()),
            batch_size=getattr(self.train_loader, "batch_size", 128) or 128,
            shuffle=False,
            num_workers=int(getattr(self.train_loader, "num_workers", 0) or 0),
            pin_memory=bool(getattr(self.train_loader, "pin_memory", False)),
        )
        model = self.cg_mod.model
        device = self.device or next(model.parameters()).device
        model.eval()
        model.to(device)

        g_sum: Optional[torch.Tensor] = None
        n = 0
        try:
            for batch in ld:
                # Only the first two fields (inputs, targets) are used; extra fields are ignored.
                xb = batch[0].to(device, non_blocking=True)
                yb = batch[1].to(device, non_blocking=True)
                bsz = int(xb.shape[0])
                model.zero_grad(set_to_none=True)
                loss = self.loss_fn(model(xb), yb)
                loss.backward()
                g = _flatten_grads(model).detach()
                # Assumes loss_fn averages over the batch (reduction='mean') — TODO confirm;
                # re-weighting by bsz turns batch means into a sum so the final
                # division yields the mean over all of S.
                g_sum = g * float(bsz) if g_sum is None else g_sum + g * float(bsz)
                n += bsz
        finally:
            model.zero_grad(set_to_none=True)
        if g_sum is None or n == 0:
            raise RuntimeError("Empty subset loader in IFBaseline.delta_theta_subset")
        return (g_sum / float(n)).detach()

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        """Return the IF parameter delta ``(|S| / N) * H^{-1} g_S`` for removing ``S``.

        An empty ``S`` yields zeros shaped like ``ihvp_val`` (or a scalar zero
        if ``ihvp_val`` is unset).

        Raises:
            RuntimeError: if not prepared, the subset is empty after loading,
                or the CG solver returns no ``'ihvp'``.
        """
        if self.cg_mod is None or self.loss_fn is None or self.train_loader is None:
            raise RuntimeError("IFBaseline not prepared.")
        S = np.asarray(S, dtype=np.int64)
        if S.size == 0:
            return torch.zeros_like(self.ihvp_val) if self.ihvp_val is not None else torch.tensor(0.0)
        gS = self._mean_train_grad(S)
        return self._scaled_ihvp(
            gS,
            int(S.size),
            "CGInfluenceModule.inverse_hvp did not return 'ihvp' for subset",
        )

    def delta_theta_from_grad_mean(self, g_mean: torch.Tensor, subset_size: int) -> torch.Tensor:
        """Compute subset removal delta from a provided mean gradient.

        This is useful when the runner already computed g_S (e.g. for each fold).
        Returns ``(subset_size / N) * H^{-1} g_mean``.

        Raises:
            RuntimeError: if not prepared or the CG solver returns no ``'ihvp'``.
        """
        if self.cg_mod is None:
            raise RuntimeError("IFBaseline not prepared.")
        return self._scaled_ihvp(
            g_mean,
            subset_size,
            "CGInfluenceModule.inverse_hvp did not return 'ihvp' for provided gradient",
        )

    def score_points(self) -> Optional[np.ndarray]:
        """Return per-training-example scores projected onto the cached ``ihvp_val``.

        Raises:
            RuntimeError: if ``prepare`` has not been called.
        """
        if self.train_loader is None or self.loss_fn is None or self.cg_mod is None or self.ihvp_val is None:
            raise RuntimeError("IFBaseline not prepared before scoring.")
        return _per_example_scores(
            model=self.cg_mod.model,
            loader=self.train_loader,
            loss_fn=self.loss_fn,
            direction=self.ihvp_val,
            total_train=self.n_train,
            device=self.device or next(self.cg_mod.model.parameters()).device,
        )


# -----------------------
# PBRF baseline
# -----------------------


@dataclass
class PBRFBaseline(Baseline):
    """Proximal Bregman Response Function (PBRF) baseline.

    Computes expensive retraining-based "ground truth" loss differences for a
    small set of training points, filling others with NaN.

    For every probe point ``z`` the model is reset to ``theta^s`` and trained
    for ``retrain_epochs`` passes over the training set on::

        mean_i D_i(theta) - eps * L(z; theta) + (lambda_damp / 2) * ||theta - theta^s||^2

    where ``D_i = L(y_i) - L(y_i^s) - <dL/dy(y_i^s), y_i - y_i^s>`` is the
    Bregman divergence of the per-example loss in output space. It is built
    from reference outputs, output-gradients and losses cached at ``theta^s``.
    The score of ``z`` is the mean query loss after training minus the mean
    query loss at ``theta^s``. The query set is the validation loader if
    given, else the train loader.

    Minimal sanity check (recommended): set epsilon=0 and verify scores are
    near-zero (optimizer noise aside), since the objective is minimized at
    theta = theta^s. NOTE: reference tables are computed in ``eval()`` mode
    while retraining runs in ``train()`` mode, so this holds only for models
    whose outputs do not depend on the mode (e.g. no dropout/batch-norm).
    """

    name: str = "pbrf"
    lambda_damp: float = 1e-3
    retrain_epochs: int = 100
    lr: float = 0.01
    optimizer: str = "sgd"  # sgd|adam
    momentum: float = 0.9
    batch_size: int = 128
    num_workers: int = 0
    num_points: int = 20
    point_indices: Optional[List[int]] = None
    epsilon: Optional[float] = None  # default 1/N
    store_delta_theta: bool = False

    _scores: Optional[np.ndarray] = None
    _delta_thetas: Optional[Dict[int, torch.Tensor]] = None

    def __post_init__(self) -> None:
        self.device: Optional[torch.device] = None
        self.loss_fn: Optional[nn.Module] = None
        self.theta_s_state: Optional[Dict[str, torch.Tensor]] = None
        self.theta_s_params: Optional[List[torch.Tensor]] = None
        self.y_s_train: Optional[torch.Tensor] = None
        self.gy_s_train: Optional[torch.Tensor] = None
        self.L_s_train: Optional[torch.Tensor] = None
        self.n_train: int = 0
        self.query_loader: Optional[DataLoader] = None
        self._base_query_loss: Optional[float] = None

    def _make_optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        """Build a fresh SGD (with momentum) or Adam optimizer over ``model``'s parameters."""
        opt = str(self.optimizer).lower()
        if opt == "sgd":
            return torch.optim.SGD(model.parameters(), lr=float(self.lr), momentum=float(self.momentum))
        if opt == "adam":
            return torch.optim.Adam(model.parameters(), lr=float(self.lr))
        raise ValueError(f"Unknown optimizer={self.optimizer!r} (expected 'sgd' or 'adam')")

    def _cache_reference_tables(self, model: nn.Module, train_loader: DataLoader, device: torch.device) -> None:
        """Cache outputs, output-gradients and per-example losses at ``theta^s``.

        Fills ``y_s_train`` and ``gy_s_train`` (shape ``(N,) + output_shape``)
        and ``L_s_train`` (shape ``(N,)``), all float32 on CPU. Rows are
        addressed by the ``idx`` values yielded by ``_clone_loader_with_index``
        (presumably dataset positions; verify against that helper).

        Raises:
            RuntimeError: if ``loss_fn`` is unset or the loader is empty.
        """
        if self.loss_fn is None:
            raise RuntimeError("PBRFBaseline requires loss_fn before caching")
        # Need deterministic reference predictions
        model.eval()
        model.to(device)

        indexed_loader = _clone_loader_with_index(train_loader.dataset, train_loader, shuffle=False)
        N = len(indexed_loader.dataset)

        y_s_cpu: Optional[torch.Tensor] = None
        gy_s_cpu: Optional[torch.Tensor] = None
        L_s_cpu = torch.empty((N,), dtype=torch.float32, device=torch.device("cpu"))

        for xb, yb, idxb in indexed_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            idxb_cpu = idxb.detach().cpu().long()

            with torch.no_grad():
                out = model(xb)
            # Allocate on first batch once we know output shape
            if y_s_cpu is None:
                out_shape = tuple(out.shape[1:])
                y_s_cpu = torch.empty((N,) + out_shape, dtype=torch.float32, device=torch.device("cpu"))
                gy_s_cpu = torch.empty((N,) + out_shape, dtype=torch.float32, device=torch.device("cpu"))

            y_s_cpu[idxb_cpu] = out.detach().float().cpu()

            # Compute gy_s and L_s via autograd on detached outputs (no parameter grads involved)
            out_ref = out.detach().requires_grad_(True)
            per_ex = _loss_per_example(self.loss_fn, out_ref, yb)
            gy = torch.autograd.grad(per_ex.sum(), out_ref, create_graph=False, retain_graph=False)[0]
            gy_s_cpu[idxb_cpu] = gy.detach().float().cpu()
            L_s_cpu[idxb_cpu] = per_ex.detach().float().cpu()

        if y_s_cpu is None or gy_s_cpu is None:
            raise RuntimeError("Failed to build PBRF reference caches (empty train loader)")

        self.y_s_train = y_s_cpu
        self.gy_s_train = gy_s_cpu
        self.L_s_train = L_s_cpu

    def _choose_probe_indices(self, cfg: Dict[str, Any]) -> np.ndarray:
        """Return the probe indices: validated ``point_indices`` or a seeded random sample.

        Raises:
            ValueError: if any explicit index lies outside ``[0, n_train)``.
        """
        if self.point_indices is not None:
            chosen = np.asarray(list(self.point_indices), dtype=np.int64)
            # Reject before retraining: negatives would silently wrap around.
            bad = chosen[(chosen < 0) | (chosen >= self.n_train)]
            if bad.size:
                raise ValueError(
                    f"PBRFBaseline point_indices out of range [0, {self.n_train}): {bad.tolist()}"
                )
            return chosen
        seed = int(cfg.get("seed", 0)) if isinstance(cfg, dict) else 0
        rng = np.random.RandomState(seed)
        k = min(int(self.num_points), int(self.n_train))
        return rng.choice(self.n_train, size=k, replace=False).astype(np.int64)

    def _fit_point(
        self,
        model: nn.Module,
        ds_indexed: Dataset,
        z: int,
        eps: float,
        retrain_loader: DataLoader,
        device: torch.device,
    ) -> None:
        """Reset ``model`` to ``theta^s`` and minimise the PBRF objective for probe ``z``."""
        theta_s_params = self.theta_s_params
        if theta_s_params is None:
            raise RuntimeError("Missing theta_s_params in PBRFBaseline")
        loss_fn = self.loss_fn

        # Reset to theta^s with a fresh optimizer state
        model.load_state_dict(self.theta_s_state, strict=True)
        model.to(device)
        model.train()
        opt = self._make_optimizer(model)

        x_z, t_z, _ = ds_indexed[z]
        x_zb = torch.as_tensor(x_z).unsqueeze(0).to(device)
        t_zb = torch.as_tensor(t_z).unsqueeze(0).to(device)

        eps_f = float(eps)
        half_lambda = 0.5 * float(self.lambda_damp)
        for _epoch in range(int(self.retrain_epochs)):
            for xb, yb, idxb in retrain_loader:
                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)
                idxb_cpu = idxb.detach().cpu().long()

                y = model(xb)

                # Reference tables live on CPU; gather this batch's rows and move them over.
                y_s = self.y_s_train[idxb_cpu].to(device)
                gy_s = self.gy_s_train[idxb_cpu].to(device)
                L_s = self.L_s_train[idxb_cpu].to(device)

                # Bregman divergence L(y) - L(y_s) - <gy_s, y - y_s>, per example
                L_y = _loss_per_example(loss_fn, y, yb)
                dot_term = (gy_s * (y - y_s)).view(int(y.shape[0]), -1).sum(dim=1)
                loss_breg = (L_y - L_s - dot_term).mean()

                # Downweight/remove the single point z (subtracted below)
                loss_z = loss_fn(model(x_zb), t_zb)

                # Proximal term keeps theta near theta^s
                prox_sum = torch.tensor(0.0, device=device)
                for p, p0 in zip(model.parameters(), theta_s_params):
                    prox_sum = prox_sum + torch.sum((p - p0).pow(2))
                prox = half_lambda * prox_sum

                loss = loss_breg - eps_f * loss_z + prox
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()

    def _flat_delta(self, model: nn.Module) -> torch.Tensor:
        """Return ``theta - theta^s`` flattened into a single CPU vector."""
        return torch.cat(
            [(p.detach() - p0).view(-1).cpu() for p, p0 in zip(model.parameters(), self.theta_s_params)],
            dim=0,
        )

    def prepare(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        loss_fn: nn.Module,
        cg_mod: Any,
        caches: Dict[str, Dict[str, Any]],
        device: torch.device,
        cfg: Dict[str, Any],
        folds: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Run PBRF retraining for each probe point and store the resulting scores.

        ``cg_mod``, ``caches`` and ``folds`` are not used. ``model`` is always
        handed back at ``theta^s`` in ``eval()`` mode, including when
        retraining raises.

        Raises:
            ValueError: if the training set is empty or ``point_indices`` holds
                an index outside ``[0, N)``.
        """
        self.device = device
        self.loss_fn = loss_fn
        self.n_train = len(train_loader.dataset)
        if self.n_train <= 0:
            raise ValueError("PBRFBaseline requires non-empty train_loader.dataset")
        eps = float(self.epsilon) if self.epsilon is not None else (1.0 / float(self.n_train))

        # Decide (and validate) probe indices before any expensive work
        chosen = self._choose_probe_indices(cfg)

        # Save theta^s: state dict for resets, parameter snapshot (on device) for the proximal term
        self.theta_s_state = copy.deepcopy(model.state_dict())
        model.to(device)
        self.theta_s_params = [p.detach().clone() for p in model.parameters()]

        # Query loader defaults to val if provided else train
        self.query_loader = val_loader if val_loader is not None else train_loader
        self._base_query_loss = _mean_loss_over_loader(model, self.query_loader, loss_fn, device)

        # Cache reference outputs/gradients for all training points at theta^s
        self._cache_reference_tables(model, train_loader, device)

        scores = np.full(self.n_train, np.nan, dtype=np.float64)
        delta_map: Dict[int, torch.Tensor] = {}

        # Indexed (x, y, idx) loader for retraining steps
        retrain_loader = DataLoader(
            _IndexedDataset(train_loader.dataset),
            batch_size=int(self.batch_size),
            shuffle=True,
            num_workers=int(self.num_workers),
            pin_memory=bool(getattr(train_loader, "pin_memory", False)),
            persistent_workers=bool(int(self.num_workers) > 0),
        )
        ds_indexed = _IndexedDataset(train_loader.dataset)

        try:
            for z in chosen.tolist():
                self._fit_point(model, ds_indexed, int(z), eps, retrain_loader, device)

                # Evaluate query loss difference
                q_loss = _mean_loss_over_loader(model, self.query_loader, loss_fn, device)
                scores[int(z)] = float(q_loss - float(self._base_query_loss))

                if self.store_delta_theta:
                    delta_map[int(z)] = self._flat_delta(model)

            self._scores = scores
            self._delta_thetas = delta_map if self.store_delta_theta else None
        finally:
            # Always hand the caller's model back at theta^s, even if retraining fails.
            model.load_state_dict(self.theta_s_state, strict=True)
            model.to(device)
            model.eval()

    def delta_theta_subset(self, S: np.ndarray) -> torch.Tensor:
        """Not supported: PBRF here scores individual points only."""
        raise NotImplementedError("PBRFBaseline is a point-scoring baseline; use score_points().")

    def score_points(self) -> Optional[np.ndarray]:
        """Return the length-N score array (NaN for points that were not probed).

        Raises:
            RuntimeError: if ``prepare`` has not completed.
        """
        if self._scores is None:
            raise RuntimeError("PBRFBaseline not prepared.")
        return self._scores


@dataclass
class PBRFFoldBaseline:
    """Fold/subset PBRF (Option B).

    Fits a PBRF objective that approximates removal of a subset S. The model
    is left at the fitted parameters so the caller can evaluate the query
    loss change, then call ``restore_theta_s``. This class does not compute
    the query loss itself.

    Intended for running on a small set of folds (e.g. the same subset used for
    CV retraining baselines). Call ``prepare_shared`` once, then
    ``fit_subset`` / ``restore_theta_s`` per fold.
    """

    name: str = "pbrf_fold"
    lambda_damp: float = 1e-3
    retrain_epochs: int = 50
    lr: float = 1e-2
    optimizer: str = "sgd"  # sgd|adam
    momentum: float = 0.9
    batch_size: int = 128
    num_workers: int = 0
    epsilon: Optional[float] = None
    epsilon_mode: str = "m_over_N"  # m_over_N|one_over_N|fixed

    def __post_init__(self) -> None:
        self.device: Optional[torch.device] = None
        self.loss_fn: Optional[nn.Module] = None
        self.theta_s_state: Optional[Dict[str, torch.Tensor]] = None
        self.theta_s_params: Optional[List[torch.Tensor]] = None
        self.y_s_train: Optional[torch.Tensor] = None
        self.gy_s_train: Optional[torch.Tensor] = None
        self.L_s_train: Optional[torch.Tensor] = None
        self.n_train: int = 0
        self._retrain_loader: Optional[DataLoader] = None
        self._base_dataset: Optional[Dataset] = None

    def _make_optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        """Build a fresh SGD (with momentum) or Adam optimizer over ``model``'s parameters."""
        opt = str(self.optimizer).lower()
        if opt == "sgd":
            return torch.optim.SGD(model.parameters(), lr=float(self.lr), momentum=float(self.momentum))
        if opt == "adam":
            return torch.optim.Adam(model.parameters(), lr=float(self.lr))
        raise ValueError(f"Unknown optimizer={self.optimizer!r} (expected 'sgd' or 'adam')")

    def _cache_reference_tables(self, model: nn.Module, train_loader: DataLoader, device: torch.device) -> None:
        """Cache outputs, output-gradients and per-example losses at ``theta^s``.

        Fills ``y_s_train``, ``gy_s_train`` (shape ``(N,) + output_shape``) and
        ``L_s_train`` (shape ``(N,)``), all float32 on CPU.

        Raises:
            RuntimeError: if ``loss_fn`` is unset or the loader is empty.
        """
        if self.loss_fn is None:
            raise RuntimeError("PBRFFoldBaseline requires loss_fn before caching")
        # Need deterministic reference predictions
        model.eval()
        model.to(device)

        indexed_loader = _clone_loader_with_index(train_loader.dataset, train_loader, shuffle=False)
        N = len(indexed_loader.dataset)

        y_s_cpu: Optional[torch.Tensor] = None
        gy_s_cpu: Optional[torch.Tensor] = None
        L_s_cpu = torch.empty((N,), dtype=torch.float32, device=torch.device("cpu"))

        for xb, yb, idxb in indexed_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            idxb_cpu = idxb.detach().cpu().long()

            with torch.no_grad():
                out = model(xb)
            # Allocate on first batch once we know output shape
            if y_s_cpu is None:
                out_shape = tuple(out.shape[1:])
                y_s_cpu = torch.empty((N,) + out_shape, dtype=torch.float32, device=torch.device("cpu"))
                gy_s_cpu = torch.empty((N,) + out_shape, dtype=torch.float32, device=torch.device("cpu"))

            y_s_cpu[idxb_cpu] = out.detach().float().cpu()

            # Compute gy_s and L_s via autograd on detached outputs
            out_ref = out.detach().requires_grad_(True)
            per_ex = _loss_per_example(self.loss_fn, out_ref, yb)
            gy = torch.autograd.grad(per_ex.sum(), out_ref, create_graph=False, retain_graph=False)[0]
            gy_s_cpu[idxb_cpu] = gy.detach().float().cpu()
            L_s_cpu[idxb_cpu] = per_ex.detach().float().cpu()

        if y_s_cpu is None or gy_s_cpu is None:
            raise RuntimeError("Failed to build PBRF reference caches (empty train loader)")

        self.y_s_train = y_s_cpu
        self.gy_s_train = gy_s_cpu
        self.L_s_train = L_s_cpu

    def prepare_shared(
        self,
        *,
        model: nn.Module,
        train_loader: DataLoader,
        loss_fn: nn.Module,
        device: torch.device,
    ) -> None:
        """Snapshot ``theta^s``, build reference caches and the shuffled full-data loader.

        Leaves ``model`` at ``theta^s`` in ``eval()`` mode.

        Raises:
            ValueError: if the training set is empty.
        """
        self.device = device
        self.loss_fn = loss_fn
        self.n_train = int(len(train_loader.dataset))
        if self.n_train <= 0:
            raise ValueError("PBRFFoldBaseline requires non-empty train_loader.dataset")
        self._base_dataset = train_loader.dataset

        # Save theta^s
        self.theta_s_state = copy.deepcopy(model.state_dict())
        model.to(device)
        self.theta_s_params = [p.detach().clone() for p in model.parameters()]

        # Cache reference outputs/gradients for all training points at theta^s
        self._cache_reference_tables(model, train_loader, device)

        # Indexed full-data loader for the Bregman term
        self._retrain_loader = DataLoader(
            _IndexedDataset(train_loader.dataset),
            batch_size=int(self.batch_size),
            shuffle=True,
            num_workers=int(self.num_workers),
            pin_memory=bool(getattr(train_loader, "pin_memory", False)),
            persistent_workers=bool(int(self.num_workers) > 0),
        )

        # Restore model to theta^s (defensive)
        model.load_state_dict(self.theta_s_state, strict=True)
        model.to(device)
        model.eval()

    def _epsilon_for_subset(self, m: int) -> float:
        """Return the downweighting ``eps`` for a subset of size ``m``.

        An explicit ``epsilon`` wins. Otherwise ``m_over_N`` gives ``m / N``
        and ``one_over_N`` gives ``1 / N`` (mode matching is case-insensitive).

        Raises:
            ValueError: for ``fixed`` without ``epsilon``, or an unknown mode.
        """
        if self.epsilon is not None:
            return float(self.epsilon)
        mode = str(self.epsilon_mode).lower()
        if mode == "m_over_n":
            return float(m) / float(max(1, self.n_train))
        if mode == "one_over_n":
            return 1.0 / float(max(1, self.n_train))
        if mode == "fixed":
            raise ValueError("epsilon_mode='fixed' requires epsilon to be provided")
        raise ValueError(f"Unknown epsilon_mode={self.epsilon_mode!r} (expected m_over_N|one_over_N|fixed)")

    @staticmethod
    def _check_subset_indices(indices: np.ndarray, dataset: Dataset) -> None:
        """Raise ValueError if any index is negative or, when ``len(dataset)`` is defined, >= it."""
        try:
            n_ds: Optional[int] = len(dataset)  # type: ignore[arg-type]
        except TypeError:
            n_ds = None
        bad = indices < 0
        if n_ds is not None:
            bad = bad | (indices >= n_ds)
        if bool(bad.any()):
            raise ValueError(
                f"subset_indices out of range for dataset of length {n_ds}: {indices[bad][:10].tolist()}"
            )

    def fit_subset(
        self,
        *,
        model: nn.Module,
        dataset: Dataset,
        subset_indices: np.ndarray,
    ) -> Dict[str, Any]:
        """Fit PBRF objective for removing/downweighting subset S.

        Resets ``model`` to ``theta^s`` and, for ``retrain_epochs`` passes over
        the full training loader, minimises::

            mean_i D_i(theta) - eps * mean_{j in batch of S} L_j(theta)
                + (lambda_damp / 2) * ||theta - theta^s||^2

        drawing one minibatch from ``S`` (cycled when exhausted) per full-data
        minibatch. ``eps`` comes from ``_epsilon_for_subset(|S|)``.

        Leaves model at the fitted parameters (in ``train()`` mode). Call
        restore_theta_s() after evaluation. For an empty ``S`` the model is
        reset to ``theta^s`` (the minimiser when nothing is removed). If
        fitting raises, the model is restored to ``theta^s`` before the
        exception propagates.

        Args:
            model: model whose parameters are overwritten in place.
            dataset: map-style dataset of ``(x, y, ...)`` items indexed by ``subset_indices``.
            subset_indices: indices of ``S`` within ``dataset``.

        Returns:
            ``{"eps": float, "subset_size": int}``.

        Raises:
            RuntimeError: if ``prepare_shared`` has not been called.
            ValueError: if an index lies outside ``[0, len(dataset))``.
        """

        if self.device is None or self.loss_fn is None:
            raise RuntimeError("PBRFFoldBaseline not prepared (missing device/loss_fn)")
        if self.theta_s_state is None or self.theta_s_params is None:
            raise RuntimeError("PBRFFoldBaseline not prepared (missing theta^s)")
        if self._retrain_loader is None:
            raise RuntimeError("PBRFFoldBaseline not prepared (missing retrain loader)")
        if self.y_s_train is None or self.gy_s_train is None or self.L_s_train is None:
            raise RuntimeError("PBRFFoldBaseline not prepared (missing reference caches)")

        device = self.device
        subset_indices = np.asarray(subset_indices, dtype=np.int64).ravel()
        m = int(subset_indices.size)
        if m <= 0:
            # eps = 0: theta^s is the exact solution. Reset so that a stale fit from
            # an earlier call is never evaluated as "removing nothing".
            self.restore_theta_s(model=model)
            return {"eps": 0.0, "subset_size": 0}
        self._check_subset_indices(subset_indices, dataset)
        eps_f = float(self._epsilon_for_subset(m))
        half_lambda = 0.5 * float(self.lambda_damp)
        loss_fn = self.loss_fn

        fitted = False
        try:
            # Reset to theta^s
            model.load_state_dict(self.theta_s_state, strict=True)
            model.to(device)
            model.train()

            # Subset loader (x, y, idx) where idx refers to the base dataset index
            subset_loader = DataLoader(
                Subset(_IndexedDataset(dataset), subset_indices.tolist()),
                batch_size=int(self.batch_size),
                shuffle=True,
                num_workers=int(self.num_workers),
                pin_memory=bool(getattr(self._retrain_loader, "pin_memory", False)),
                persistent_workers=bool(int(self.num_workers) > 0),
            )
            subset_iter = iter(subset_loader)

            opt = self._make_optimizer(model)

            for _epoch in range(int(self.retrain_epochs)):
                for xb, yb, idxb in self._retrain_loader:
                    # Cycle through S: restart its iterator whenever it runs out
                    try:
                        xsb, ysb, _ = next(subset_iter)
                    except StopIteration:
                        subset_iter = iter(subset_loader)
                        xsb, ysb, _ = next(subset_iter)

                    xb = xb.to(device, non_blocking=True)
                    yb = yb.to(device, non_blocking=True)
                    idxb_cpu = idxb.detach().cpu().long()

                    xsb = xsb.to(device, non_blocking=True)
                    ysb = ysb.to(device, non_blocking=True)

                    y = model(xb)

                    # Lookup reference tables (CPU) then move to device
                    y_s = self.y_s_train[idxb_cpu].to(device)
                    gy_s = self.gy_s_train[idxb_cpu].to(device)
                    L_s = self.L_s_train[idxb_cpu].to(device)

                    # Bregman divergence L(y) - L(y_s) - <gy_s, y - y_s>, per example
                    L_y = _loss_per_example(loss_fn, y, yb)
                    dot_term = (gy_s * (y - y_s)).view(int(y.shape[0]), -1).sum(dim=1)
                    loss_breg = (L_y - L_s - dot_term).mean()

                    # Subset loss term (mean loss over a minibatch from S)
                    loss_S = _loss_per_example(loss_fn, model(xsb), ysb).mean()

                    # Proximal term keeps theta near theta^s
                    prox_sum = torch.tensor(0.0, device=device)
                    for p, p0 in zip(model.parameters(), self.theta_s_params):
                        prox_sum = prox_sum + torch.sum((p - p0).pow(2))
                    prox = half_lambda * prox_sum

                    loss = loss_breg - eps_f * loss_S + prox
                    opt.zero_grad(set_to_none=True)
                    loss.backward()
                    opt.step()
            fitted = True
        finally:
            if not fitted:
                # Never leave a half-trained model behind on failure.
                self.restore_theta_s(model=model)

        return {"eps": eps_f, "subset_size": int(m)}

    def restore_theta_s(self, *, model: nn.Module) -> None:
        """Load ``theta^s`` back into ``model`` on the prepared device and switch to ``eval()``.

        Raises:
            RuntimeError: if ``prepare_shared`` has not been called.
        """
        if self.theta_s_state is None or self.device is None:
            raise RuntimeError("PBRFFoldBaseline not prepared")
        model.load_state_dict(self.theta_s_state, strict=True)
        model.to(self.device)
        model.eval()


# Public names exported by `from <module> import *`.
__all__ = [
    "Baseline",
    "IFOneSolveBaseline",
    "TRAKBaseline",
    "IFCPointBaseline",
    "IFBaseline",
    "PBRFBaseline",
    "PBRFFoldBaseline",
    "RandomBaseline",
    "BestBaseline",
]
