from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import copy
import hashlib
import json
import logging
import math
import time
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset

from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, average_precision_score, roc_auc_score
from sklearn.model_selection import train_test_split

from ..evals.metrics import pearson_corr, spearman_corr
from ..evals.poison_detection import poison_detection_report_v2
from ..estimators.baselines import _mean_grad_over_loader, TRAKBaseline
from ..estimators.if_compressed import (
    IFCompressedCfg,
    build_ifc_cache,
    delta_ifc_from_cache,
)
from ..estimators.helpers import CGInfluenceModule, BaseObjective, eval_on_indices, flatten_params
from ..utils.seed import set_seed
from ..utils.train import train_model
from ..utils.logging import setup_logging

from .run_ablation import (
    _extract_labels,
    _compute_dataset_stats,
    _flatten_model_params,
    _accuracy_metric,
    _cosine_torch,
    _angle_from_cos,
    _stream_grad_diagnostics,
    _log_structured,
    _serialize,
)


class _PaperObjective(BaseObjective):
    """Thin wrapper so we can reuse CGInfluenceModule.

    Adapts a plain ``loss_fn(outputs, targets)`` to the objective interface.
    Batches are unpacked as ``(inputs, targets)`` pairs.

    Args:
        loss_fn: Callable applied as ``loss_fn(model(x), y)``.
        weight_decay: Coefficient of the squared-L2 penalty returned by
            ``train_regularization``. Defaults to ``0.0`` (no penalty).
    """

    def __init__(self, loss_fn: nn.Module, weight_decay: float = 0.0):
        self.loss_fn = loss_fn
        # Store the caller-supplied coefficient so that train_regularization
        # actually applies it instead of always evaluating to zero.
        self.weight_decay = float(weight_decay)

    def train_outputs(self, model: nn.Module, batch):
        """Run the model on the input half of ``batch``."""
        x, _ = batch
        return model(x)

    def train_loss_on_outputs(self, outputs: torch.Tensor, batch):
        """Evaluate ``loss_fn`` on precomputed outputs and the batch targets."""
        _, y = batch
        return self.loss_fn(outputs, y)

    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """Return ``weight_decay * sum(params ** 2)``."""
        return self.weight_decay * torch.sum(params ** 2)

    def test_loss(self, model: nn.Module, params: torch.Tensor, batch):
        """Evaluate ``loss_fn`` on a test batch; ``params`` is not used here."""
        x, y = batch
        return self.loss_fn(model(x), y)


class LabelOverrideDataset(Dataset):
    """Dataset view that serves inputs from ``base`` with labels from an array.

    The supplied labels are copied at construction time and are also exposed
    as ``targets``. Each returned label is coerced to the type of the label
    that ``base`` would have produced for the same index.
    """

    def __init__(self, base: Dataset, labels: np.ndarray):
        self.base = base
        self.labels = np.asarray(labels).copy()
        # Alias under the attribute name ``targets`` (same array object).
        self.targets = self.labels

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        sample, original = self.base[idx]
        override = self.labels[idx]

        # Tensor labels keep the dtype of the original label tensor.
        if torch.is_tensor(original):
            return sample, torch.as_tensor(override, dtype=original.dtype)

        caster = original.__class__
        # NumPy scalars are rebuilt with the same scalar type; failures propagate.
        if isinstance(original, np.generic):
            return sample, caster(override)

        # Any other label type: try its constructor, fall back to a plain int.
        try:
            return sample, caster(override)
        except Exception:
            return sample, int(override)


def _mix64(z: int) -> int:
    z &= (1 << 64) - 1
    z = (z + 0x9E3779B97F4A7C15) & ((1 << 64) - 1)
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9 & ((1 << 64) - 1)
    z = (z ^ (z >> 27)) * 0x94D049BB133111EB & ((1 << 64) - 1)
    z = z ^ (z >> 31)
    return z & ((1 << 64) - 1)

def _u01(seed: int, idx: int, salt: int) -> float:
    """Return a deterministic float in [0, 1) derived from (seed, idx, salt).

    The three integers are hashed with ``_mix64`` and combined by XOR; 53 bits
    of the final hash (bits 11..63) are scaled by ``2**-53``.
    """
    seed_bits = int(seed)
    idx_hash = _mix64(int(idx) + 0xD1B54A32D192ED03)
    salt_hash = _mix64(int(salt))
    mixed = _mix64(seed_bits ^ idx_hash ^ salt_hash)
    mantissa = (mixed >> 11) & ((1 << 53) - 1)
    return mantissa / float(1 << 53)

def _sample_range(v, seed: int, idx: int, salt: int, *, as_int: bool = False):
    if isinstance(v, (list, tuple)) and len(v) == 2:
        a, b = float(v[0]), float(v[1])
        lo, hi = (a, b) if a <= b else (b, a)
        u = _u01(seed, idx, salt)
        out = lo + (hi - lo) * u
        return int(round(out)) if as_int else float(out)
    return int(v) if as_int else float(v)

def apply_patch_trigger(x: Any, trigger_spec: Dict[str, Any], idx: Optional[int] = None) -> Any:
    """Stamp a square patch onto a copy of ``x`` as described by ``trigger_spec``.

    Keys read from ``trigger_spec`` (each with its default):
      - ``type`` ("patch"): anything else raises ``ValueError``.
      - ``seed`` (0): combined with ``idx`` for deterministic randomisation.
      - ``size`` (4), ``value`` (1.0), ``blend`` (1.0): scalar or 2-element
        range resolved by ``_sample_range``; ``blend`` is clamped to [0, 1].
      - ``location`` ("bottom_right"): a name, or a non-empty list of names
        from which one is picked deterministically. Recognised names are
        "random", "top_left", "top_right", "bottom_left", "center"; any other
        value falls back to bottom-right.
      - ``channels`` ("all"): "all"/None, a list of channel indices, or a
        single index.

    Returns:
        ``x`` itself when no patch is applied (``trigger_spec`` is None, the
        sampled size is non-positive, the input is not 2-D/3-D, or the patch
        does not fit). Otherwise a patched copy: a cloned tensor for tensor
        input, a NumPy array for anything else.

    Raises:
        ValueError: if ``trigger_spec["type"]`` is not "patch".
    """
    if trigger_spec is None:
        return x
    if str(trigger_spec.get("type", "patch")).lower() != "patch":
        raise ValueError(f"Unsupported trigger type: {trigger_spec.get('type')}")

    sample_idx = 0 if idx is None else int(idx)
    seed = int(trigger_spec.get("seed", 0))

    size = _sample_range(trigger_spec.get("size", 4), seed, sample_idx, 1, as_int=True)
    if size <= 0:
        return x

    loc_spec = trigger_spec.get("location", "bottom_right")
    value = _sample_range(trigger_spec.get("value", 1.0), seed, sample_idx, 2, as_int=False)
    raw_blend = _sample_range(trigger_spec.get("blend", 1.0), seed, sample_idx, 3, as_int=False)
    blend = max(0.0, min(1.0, float(raw_blend)))
    channels = trigger_spec.get("channels", "all")

    arr = x.clone() if torch.is_tensor(x) else np.array(x, copy=True)
    if arr.ndim not in (2, 3):
        return x

    # --- layout inference: 2-D is H x W; 3-D is channel-first only when the
    # leading axis has 1 or 3 entries and the patch fits the other two axes;
    # every other 3-D input is read as H x W x C.
    if arr.ndim == 2:
        chw = False
        H, W, C = int(arr.shape[0]), int(arr.shape[1]), 1
    elif arr.shape[0] in (1, 3) and arr.shape[1] >= size and arr.shape[2] >= size:
        chw = True
        C, H, W = (int(d) for d in arr.shape)
    else:
        chw = False
        H, W, C = (int(d) for d in arr.shape)

    if H < size or W < size:
        return x

    # --- location: a non-empty list means "pick one entry deterministically".
    if isinstance(loc_spec, (list, tuple)) and len(loc_spec) > 0:
        pick = int(_u01(seed, sample_idx, 10) * len(loc_spec))
        pick = max(0, min(pick, len(loc_spec) - 1))
        location = str(loc_spec[pick]).lower()
    else:
        location = str(loc_spec).lower()

    max_r, max_c = H - size, W - size
    if location == "random":
        r0 = int(_u01(seed, sample_idx, 11) * max(1, max_r + 1))
        c0 = int(_u01(seed, sample_idx, 12) * max(1, max_c + 1))
    else:
        anchors = {
            "top_left": (0, 0),
            "top_right": (0, max_c),
            "bottom_left": (max_r, 0),
            "center": (max_r // 2, max_c // 2),
        }
        # Unrecognised names fall back to the bottom-right corner.
        r0, c0 = anchors.get(location, (max_r, max_c))
    r0 = max(0, min(r0, max_r))
    c0 = max(0, min(c0, max_c))
    r1, c1 = r0 + size, c0 + size

    # --- channel selection
    if channels == "all" or channels is None:
        ch_idx = list(range(int(C)))
    elif isinstance(channels, (list, tuple)):
        ch_idx = [int(ci) for ci in channels]
    else:
        try:
            ch_idx = [int(channels)]
        except Exception:
            ch_idx = list(range(int(C)))

    opaque = blend >= 1.0

    def _paint(region) -> None:
        # ``region`` is a basic-slice view of ``arr``, so writing through it
        # updates ``arr`` in place.
        region[...] = value if opaque else (1.0 - blend) * region + blend * value

    if arr.ndim == 2:
        _paint(arr[r0:r1, c0:c1])
    elif chw:
        for ci in ch_idx:
            # Out-of-range (including negative) channel indices are skipped.
            if 0 <= ci < int(arr.shape[0]):
                _paint(arr[ci, r0:r1, c0:c1])
    else:
        for ci in ch_idx:
            if 0 <= ci < int(arr.shape[-1]):
                _paint(arr[r0:r1, c0:c1, ci])

    return arr


class ConstantLabelDataset(Dataset):
    """Dataset view that keeps the inputs of ``base`` but reports one fixed label.

    The constant is stored as an ``int`` and coerced, per item, to the type of
    the label that ``base`` would have returned.
    """

    def __init__(self, base: Dataset, label: int):
        self.base = base
        self.label = int(label)

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        sample, original = self.base[idx]

        # Tensor labels keep the dtype of the original label tensor.
        if torch.is_tensor(original):
            return sample, torch.as_tensor(self.label, dtype=original.dtype)

        caster = original.__class__
        # NumPy scalars are rebuilt with the same scalar type; failures propagate.
        if isinstance(original, np.generic):
            return sample, caster(self.label)

        # Any other label type: try its constructor, fall back to a plain int.
        try:
            return sample, caster(self.label)
        except Exception:
            return sample, int(self.label)


class TriggerAllDataset(Dataset):
    """Dataset view that applies the patch trigger to every input of ``base``.

    Each ``(x, y)`` pair is memoised in ``self.cache`` by index, so ``base``
    is read and the trigger is applied at most once per index.
    """

    def __init__(self, base: Dataset, trigger_spec: Dict[str, Any]):
        self.base = base
        self.trigger_spec = trigger_spec
        self.cache = {}

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        # Serve a previously triggered item if there is one.
        try:
            return self.cache[idx]
        except KeyError:
            pass

        raw, label = self.base[idx]
        item = (apply_patch_trigger(raw, self.trigger_spec, idx=idx), label)
        self.cache[idx] = item
        return item

class BackdoorTriggerDataset(Dataset):
    """Dataset view that applies the patch trigger only at selected indices.

    Indices outside ``trigger_indices`` are passed through untouched. Every
    returned ``(x, y)`` pair is memoised in ``self.cache`` by index.
    """

    def __init__(self, base: Dataset, trigger_indices: Set[int], trigger_spec: Dict[str, Any]):
        self.base = base
        self.trigger_indices = set(trigger_indices)
        self.trigger_spec = trigger_spec
        self.cache = {}

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx):
        # Serve a previously produced item if there is one.
        try:
            return self.cache[idx]
        except KeyError:
            pass

        sample, label = self.base[idx]
        if idx in self.trigger_indices:
            sample = apply_patch_trigger(sample, self.trigger_spec, idx=idx)
        item = (sample, label)
        self.cache[idx] = item
        return item



def make_index_groups(n: int, G: int, grouping: str, rng: np.random.Generator) -> List[np.ndarray]:
    """Split the indices ``0..n-1`` into ``G`` nearly equal groups.

    Args:
        n: Number of indices; ``n <= 0`` yields an empty list.
        G: Requested number of groups (values below 1 are treated as 1).
        grouping: "random" (shuffle with ``rng``), "contiguous" (keep order) or
            "hash" (order by a multiplicative hash of the index). A falsy value
            means "random"; matching is case-insensitive.
        rng: Generator used only for the "random" strategy.

    Returns:
        List of ``G`` int64 arrays produced by ``np.array_split``.

    Raises:
        ValueError: for an unknown grouping strategy.
    """
    n = int(n)
    if n <= 0:
        return []

    n_groups = int(max(1, int(G)))
    strategy = str(grouping or "random").lower()
    positions = np.arange(n, dtype=np.int64)

    if strategy == "random":
        ordered = positions[rng.permutation(n)]
    elif strategy == "contiguous":
        ordered = positions
    elif strategy == "hash":
        keys = (positions.astype(np.int64) * np.int64(2654435761)) % np.int64(2**31 - 1)
        ordered = positions[np.argsort(keys)]
    else:
        raise ValueError(f"Unknown grouping strategy: {strategy}")

    return [np.asarray(part, dtype=np.int64) for part in np.array_split(ordered, n_groups)]


@dataclass
class ClusterMethodSpec:
    """Named pairing of a clustering method with an ``IFCompressedCfg``.

    NOTE(review): instances appear to be built by ``_method_entries`` from the
    ``cluster_methods`` config entries; the construction site is outside the
    visible part of the file, so confirm the exact field semantics there.
    """

    # Identifier of this spec (presumably the entry's display name; confirm).
    name: str
    # Clustering method identifier (presumably mirrors cfg.cluster_method; confirm).
    method: str
    # Estimator configuration associated with this method.
    cfg: IFCompressedCfg
    # Optional free-form description; empty by default.
    description: str = ""


def _select_poison_indices(labels: np.ndarray, poison_cfg: Dict[str, Any], rng: np.random.Generator) -> np.ndarray:
    frac = float(poison_cfg.get("fraction", 0.0))
    if frac <= 0:
        return np.empty(0, dtype=np.int64)
    n = len(labels)
    k = max(1, int(round(frac * n)))
    strategy = poison_cfg.get("strategy", "random")
    if strategy == "targeted":
        target = poison_cfg.get("target_label", None)
        mask = labels == target if target is not None else np.ones_like(labels, dtype=bool)
        candidates = np.where(mask)[0]
        if candidates.size == 0:
            return np.empty(0, dtype=np.int64)
        idx = rng.choice(candidates, size=min(k, candidates.size), replace=False)
        return np.sort(idx)
    idx = rng.choice(n, size=min(k, n), replace=False)
    return np.sort(idx)


def _select_poison_indices_from_pool(
    labels: np.ndarray,
    pool_indices: np.ndarray,
    poison_cfg: Dict[str, Any],
    rng: np.random.Generator,
) -> np.ndarray:
    """Select poison indices from a provided pool of original dataset indices."""
    pool_indices = np.asarray(pool_indices, dtype=np.int64)
    if pool_indices.size == 0:
        return np.empty(0, dtype=np.int64)
    rel = _select_poison_indices(np.asarray(labels, dtype=np.int64)[pool_indices], poison_cfg, rng)
    if rel.size == 0:
        return np.empty(0, dtype=np.int64)
    return np.sort(pool_indices[rel])


def make_groups(
    train_idx: np.ndarray,
    labels_clean: Optional[np.ndarray],
    G: int,
    grouping: str,
    rng: np.random.Generator,
) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]:
    """Partition the training indices into disjoint groups.

    Args:
        train_idx: Original dataset indices of the training points.
        labels_clean: Accepted for interface compatibility; not used here.
        G: Requested number of groups, clamped to ``[1, len(train_idx)]``.
        grouping: "random", "contiguous" (sorted by original index) or "hash";
            a falsy value means "random". Matching is case-insensitive.
        rng: Generator used only for the "random" strategy.

    Returns:
      - groups_orig[g]: original dataset indices
      - groups_train_local[g]: local indices within `Subset(..., train_idx)` (0..n_train-1)
      - group_id_local[i]: group id for each local train point

    Raises:
        ValueError: for an unknown grouping strategy.
        RuntimeError: if the resulting groups do not cover every training
            point or some group id is missing.
    """
    train_idx = np.asarray(train_idx, dtype=np.int64)
    n_train = int(train_idx.size)
    if n_train == 0:
        return [], [], np.empty(0, dtype=np.int64)

    n_groups = max(1, min(int(G), n_train))
    strategy = str(grouping or "random").lower()

    if strategy == "random":
        ordered = train_idx[rng.permutation(n_train)]
    elif strategy == "contiguous":
        ordered = np.sort(train_idx)
    elif strategy == "hash":
        # Non-cryptographic multiplicative hash; it only varies the ordering
        # before the equal-size split below.
        keys = (train_idx.astype(np.int64) * np.int64(2654435761)) % np.int64(2**31 - 1)
        ordered = train_idx[np.argsort(keys)]
    else:
        raise ValueError(f"Unknown grouping strategy: {strategy}")

    # Nearly equal chunks of the ordered indices.
    groups_orig: List[np.ndarray] = [
        np.asarray(part, dtype=np.int64) for part in np.array_split(ordered, n_groups)
    ]

    # original index -> position inside train_idx (i.e. the Subset-local index)
    position_of = {int(orig): int(pos) for pos, orig in enumerate(train_idx.tolist())}

    group_id_local = np.full((n_train,), -1, dtype=np.int64)
    groups_train_local: List[np.ndarray] = []
    for gid, members in enumerate(groups_orig):
        local_arr = np.asarray(
            [position_of[int(m)] for m in members.tolist() if int(m) in position_of],
            dtype=np.int64,
        )
        groups_train_local.append(local_arr)
        group_id_local[local_arr] = int(gid)

    # Sanity checks: every local point assigned, every group id present.
    assigned = int((group_id_local >= 0).sum())
    if assigned != n_train:
        raise RuntimeError("Group partition does not cover all training points exactly once.")
    if int(np.unique(group_id_local).size) != int(n_groups):
        raise RuntimeError("Group partition produced missing/empty groups unexpectedly.")
    return groups_orig, groups_train_local, group_id_local


def select_poison_groups(
    groups_orig: List[np.ndarray],
    labels_clean: Optional[np.ndarray],
    poison_cfg: Dict[str, Any],
    rng: np.random.Generator,
) -> List[int]:
    """Choose which group ids get poisoned.

    ``poison_cfg`` keys read here:
      - ``fraction_groups`` (falls back to ``fraction``, then 0.0): share of
        groups to pick; any positive value picks at least one group.
      - ``strategy`` ("random"): "targeted_majority_label" (only honoured when
        ``labels_clean`` is given) limits candidates to non-empty groups whose
        most frequent clean label equals ``target_label`` (any group if that
        is None).
      - ``target_label`` (None).

    Returns:
        Sorted list of selected group ids; empty when nothing qualifies.
    """
    fraction = float(poison_cfg.get("fraction_groups", poison_cfg.get("fraction", 0.0)))
    if fraction <= 0.0 or not groups_orig:
        return []

    n_groups = len(groups_orig)
    budget = min(max(1, int(round(fraction * float(n_groups)))), n_groups)

    strategy = str(poison_cfg.get("strategy", "random")).lower()
    if not (strategy == "targeted_majority_label" and labels_clean is not None):
        chosen = rng.choice(np.arange(n_groups, dtype=np.int64), size=budget, replace=False)
        return sorted(int(g) for g in chosen.tolist())

    target = poison_cfg.get("target_label", None)

    def _qualifies(members: np.ndarray) -> bool:
        # Empty groups never qualify; otherwise compare the majority label.
        if members.size == 0:
            return False
        group_labels = np.asarray(labels_clean, dtype=np.int64)[members]
        if group_labels.size == 0:
            return False
        majority = int(np.bincount(group_labels).argmax())
        return target is None or majority == int(target)

    candidates: List[int] = [
        int(gid) for gid, members in enumerate(groups_orig) if _qualifies(members)
    ]
    if not candidates:
        return []

    pool = np.asarray(candidates, dtype=np.int64)
    chosen = rng.choice(pool, size=min(budget, pool.size), replace=False)
    return sorted(int(g) for g in chosen.tolist())


def _apply_group_label_mutations_train_only(
    *,
    dataset: Dataset,
    base_labels: Optional[np.ndarray],
    train_idx: np.ndarray,
    groups_orig: List[np.ndarray],
    poisoned_group_ids: List[int],
    poison_cfg: Dict[str, Any],
    rng: np.random.Generator,
) -> Tuple[Dataset, Optional[np.ndarray], Dict[str, Any]]:
    """Apply label corruption only inside selected groups, and only on train indices."""
    meta: Dict[str, Any] = {
        "poisoned_fraction": 0.0,
        "poisoned_fraction_train": 0.0,
        "poisoned_group_ids": list(int(x) for x in poisoned_group_ids),
        "poisoned_indices": [],
        "poison_strategy": str(poison_cfg.get("strategy", "random")),
        "poison_label": poison_cfg.get("poison_label", None),
        "within_group_fraction": float(poison_cfg.get("within_group_fraction", 1.0)),
    }
    if base_labels is None or not poisoned_group_ids:
        return dataset, base_labels, meta

    labels = np.asarray(base_labels, dtype=np.int64).copy()
    train_idx = np.asarray(train_idx, dtype=np.int64)
    train_set = set(int(i) for i in train_idx.tolist())

    within = float(poison_cfg.get("within_group_fraction", 1.0))
    within = max(0.0, min(within, 1.0))
    all_poisoned: List[int] = []
    for g in poisoned_group_ids:
        if g < 0 or g >= len(groups_orig):
            continue
        idxs = np.asarray(groups_orig[g], dtype=np.int64)
        # Safety: groups are already subsets of train_idx, but keep the guard.
        idxs = np.asarray([ii for ii in idxs.tolist() if int(ii) in train_set], dtype=np.int64)
        if idxs.size == 0:
            continue
        if within < 1.0:
            m = max(1, int(round(within * float(idxs.size))))
            idxs = rng.choice(idxs, size=min(m, idxs.size), replace=False)
            idxs = np.sort(np.asarray(idxs, dtype=np.int64))

        new_label = poison_cfg.get("poison_label", None)
        if new_label is None:
            # Random wrong label per point, sampling from labels present in the training split.
            unique = np.unique(labels[train_idx])
            for ii in idxs.tolist():
                choices = unique[unique != labels[int(ii)]]
                labels[int(ii)] = int(rng.choice(choices)) if choices.size else int(labels[int(ii)])
        else:
            labels[idxs] = int(new_label)

        all_poisoned.extend(int(x) for x in idxs.tolist())

    all_poisoned = sorted(set(all_poisoned))
    meta["poisoned_indices"] = all_poisoned
    meta["poisoned_fraction"] = float(len(all_poisoned) / max(1, labels.size))
    meta["poisoned_fraction_train"] = float(len(all_poisoned) / max(1, train_idx.size))

    # If nothing changed, skip wrapping.
    if np.array_equal(labels, np.asarray(base_labels, dtype=np.int64)):
        return dataset, np.asarray(base_labels, dtype=np.int64), meta
    return LabelOverrideDataset(dataset, labels), labels, meta


def _apply_group_backdoor_train_only(
    *,
    dataset: Dataset,
    base_labels: Optional[np.ndarray],
    train_idx: np.ndarray,
    groups_orig: List[np.ndarray],
    poisoned_group_ids: List[int],
    backdoor_cfg: Dict[str, Any],
    rng: np.random.Generator,
) -> Tuple[Dataset, Optional[np.ndarray], Dict[str, Any]]:
    """Apply a simple backdoor to selected groups: trigger inputs + set labels to target label."""
    target_label = backdoor_cfg.get("target_label", None)
    trigger_spec = backdoor_cfg.get("trigger", None)
    mode = str(backdoor_cfg.get("mode", "fixedV"))
    within = float(backdoor_cfg.get("within_group_fraction", 1.0))
    within = max(0.0, min(within, 1.0))

    meta: Dict[str, Any] = {
        "poisoned_fraction": 0.0,
        "poisoned_fraction_train": 0.0,
        "poisoned_group_ids": list(int(x) for x in poisoned_group_ids),
        "poisoned_indices": [],
        "poison_strategy": "backdoor",
        "backdoor": {
            "enabled": True,
            "mode": mode,
            "target_label": int(target_label) if target_label is not None else None,
            "trigger": trigger_spec,
            "within_group_fraction": within,
        },
    }

    if base_labels is None:
        raise ValueError("Backdoor requested but base labels are unavailable.")
    if target_label is None:
        raise ValueError("Backdoor requested but backdoor.target_label is missing.")
    if trigger_spec is None:
        raise ValueError("Backdoor requested but backdoor.trigger is missing.")
    if not poisoned_group_ids:
        return dataset, base_labels, meta

    labels = np.asarray(base_labels, dtype=np.int64).copy()
    train_idx = np.asarray(train_idx, dtype=np.int64)
    train_set = set(int(i) for i in train_idx.tolist())

    poisoned_indices: List[int] = []
    for g in poisoned_group_ids:
        if g < 0 or g >= len(groups_orig):
            continue
        idxs = np.asarray(groups_orig[g], dtype=np.int64)
        idxs = np.asarray([ii for ii in idxs.tolist() if int(ii) in train_set], dtype=np.int64)
        if idxs.size == 0:
            continue
        if within < 1.0:
            m = max(1, int(round(within * float(idxs.size))))
            idxs = rng.choice(idxs, size=min(m, idxs.size), replace=False)
            idxs = np.sort(np.asarray(idxs, dtype=np.int64))
        poisoned_indices.extend(int(x) for x in idxs.tolist())

    poisoned_indices = sorted(set(poisoned_indices))
    if not poisoned_indices:
        return dataset, np.asarray(base_labels, dtype=np.int64), meta

    labels[np.asarray(poisoned_indices, dtype=np.int64)] = int(target_label)

    meta["poisoned_indices"] = poisoned_indices
    meta["poisoned_fraction"] = float(len(poisoned_indices) / max(1, labels.size))
    meta["poisoned_fraction_train"] = float(len(poisoned_indices) / max(1, train_idx.size))

    ds_labeled: Dataset = LabelOverrideDataset(dataset, labels)
    ds_triggered: Dataset = BackdoorTriggerDataset(ds_labeled, set(poisoned_indices), trigger_spec)
    return ds_triggered, labels, meta


@torch.no_grad()
def _eval_loss_acc_over_loader(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    model.eval()
    total = 0
    correct = 0
    loss_sum = 0.0
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        out = model(xb)
        loss = loss_fn(out, yb)
        bsz = int(xb.shape[0])
        total += bsz
        loss_sum += float(loss.detach().float().item()) * float(bsz)
        try:
            pred = out.detach().argmax(dim=1)
            correct += int((pred == yb).sum().item())
        except Exception:
            pass
    if total <= 0:
        return {"loss": float("nan"), "accuracy": float("nan")}
    return {"loss": float(loss_sum / float(total)), "accuracy": float(correct) / float(total)}


def _apply_label_mutations_train_only(
    *,
    dataset: Dataset,
    base_labels: Optional[np.ndarray],
    scenario_cfg: Dict[str, Any],
    rng: np.random.Generator,
    train_idx: np.ndarray,
) -> Tuple[Dataset, Optional[np.ndarray], Dict[str, Any]]:
    """Apply relabeling/poisoning only on the training indices.

    Returns a dataset wrapper (for use with Subset(train_idx)), the full modified
    labels array (length == len(dataset)), and metadata.
    """
    meta: Dict[str, Any] = {
        "poisoned_fraction": 0.0,
        "poisoned_fraction_train": 0.0,
        "poisoned_indices": [],
        "relabel_map": scenario_cfg.get("relabel_map", {}),
    }
    if base_labels is None:
        return dataset, None, meta

    labels = np.asarray(base_labels, dtype=np.int64).copy()
    train_idx = np.asarray(train_idx, dtype=np.int64)

    relabel_map = scenario_cfg.get("relabel_map", None)
    if relabel_map and train_idx.size:
        train_labels = labels[train_idx]
        for src, dst in relabel_map.items():
            train_labels[train_labels == int(src)] = int(dst)
        labels[train_idx] = train_labels

    poison_cfg = scenario_cfg.get("poison", None)
    if poison_cfg and train_idx.size:
        idx = _select_poison_indices_from_pool(labels, train_idx, poison_cfg, rng)
        if idx.size:
            new_label = poison_cfg.get("poison_label")
            if new_label is None:
                unique = np.unique(labels[train_idx])
                for i in idx.tolist():
                    choices = unique[unique != labels[i]]
                    labels[i] = int(rng.choice(choices)) if choices.size else int(labels[i])
            else:
                labels[idx] = int(new_label)
            meta["poisoned_fraction"] = float(idx.size / max(1, len(labels)))
            meta["poisoned_fraction_train"] = float(idx.size / max(1, train_idx.size))
            meta["poisoned_indices"] = [int(i) for i in idx.tolist()]
            meta["poison_label"] = new_label
            meta["poison_strategy"] = poison_cfg.get("strategy", "random")

    if np.array_equal(labels, np.asarray(base_labels, dtype=np.int64)):
        return dataset, np.asarray(base_labels, dtype=np.int64), meta
    wrapped = LabelOverrideDataset(dataset, labels)
    return wrapped, labels, meta



def _cluster_alignment_metrics(cluster_labels: np.ndarray, true_labels: Optional[np.ndarray]) -> Dict[str, Any]:
    if true_labels is None or true_labels.shape[0] != cluster_labels.shape[0]:
        return {
            "purity": float("nan"),
            "nmi": float("nan"),
            "ari": float("nan"),
            "majority_map": {},
        }
    purity = 0.0
    majority_map: Dict[int, Dict[str, Any]] = {}
    total = len(cluster_labels)
    for c in np.unique(cluster_labels):
        mask = cluster_labels == c
        votes = np.bincount(true_labels[mask]) if mask.any() else np.array([0])
        top = int(np.argmax(votes)) if votes.size else -1
        cnt = int(votes[top]) if votes.size else 0
        purity += cnt
        majority_map[int(c)] = {
            "label": top,
            "count": cnt,
            "fraction": float(cnt / max(1, mask.sum())),
        }
    purity = purity / max(1, total)
    try:
        nmi = float(normalized_mutual_info_score(true_labels, cluster_labels))
    except Exception:
        nmi = float("nan")
    try:
        ari = float(adjusted_rand_score(true_labels, cluster_labels))
    except Exception:
        ari = float("nan")
    return {
        "purity": float(purity),
        "nmi": nmi,
        "ari": ari,
        "majority_map": majority_map,
    }



def select_indices_from_scores(scores: np.ndarray, b: int, *, higher_is_better: bool = True) -> np.ndarray:
    """Return the positions of the ``b`` best finite scores.

    NaN and infinite entries are never selected. With
    ``higher_is_better=True`` the largest scores come first, otherwise the
    smallest. At most as many indices as there are finite scores are
    returned; the result is an int64 array (empty if no score is finite).
    """
    values = np.asarray(scores, dtype=float)
    finite = np.isfinite(values)
    if not finite.any():
        return np.empty(0, dtype=np.int64)

    positions = np.where(finite)[0]
    kept = values[finite]
    if higher_is_better:
        ranking = np.argsort(-kept)
    else:
        ranking = np.argsort(kept)

    take = min(int(b), int(ranking.size))
    return positions[ranking[:take]].astype(np.int64)

def _metric_improvement(metric_name: str, base_value: float, new_value: float) -> float:
    """Return improvement where positive means better."""
    if str(metric_name).lower() in {"loss", "val_loss", "nll"}:
        return float(base_value) - float(new_value)
    return float(new_value) - float(base_value)



def _split_train_val_indices(
    n_samples: int,
    labels: Optional[np.ndarray],
    val_fraction: float,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray]:
    val_fraction = float(val_fraction)
    if val_fraction <= 0.0 or n_samples <= 1:
        idx = np.arange(n_samples, dtype=np.int64)
        return idx, np.empty(0, dtype=np.int64)
    val_fraction = min(val_fraction, 0.5)
    stratify = labels if labels is not None else None
    train_idx, val_idx = train_test_split(
        np.arange(n_samples, dtype=np.int64),
        test_size=val_fraction,
        stratify=stratify,
        random_state=seed,
        shuffle=True,
    )
    train_idx = np.asarray(train_idx, dtype=np.int64)
    val_idx = np.asarray(val_idx, dtype=np.int64)
    train_idx.sort()
    val_idx.sort()
    return train_idx, val_idx


def _split_train_val_test_indices(
    n_samples: int,
    labels: Optional[np.ndarray],
    val_fraction: float,
    test_fraction: float,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return disjoint train/val/test indices.

    Splits off the test set first, then val from the remaining pool.
    Fractions are w.r.t. the full dataset.
    """
    val_fraction = float(val_fraction)
    test_fraction = float(test_fraction)

    if n_samples <= 1:
        idx = np.arange(n_samples, dtype=np.int64)
        return idx, np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    val_fraction = max(0.0, min(val_fraction, 0.5))
    test_fraction = max(0.0, min(test_fraction, 0.9))

    all_idx = np.arange(n_samples, dtype=np.int64)
    stratify = labels if labels is not None else None

    if test_fraction > 0.0:
        train_tmp, test_idx = train_test_split(
            all_idx,
            test_size=test_fraction,
            stratify=stratify,
            random_state=seed,
            shuffle=True,
        )
        train_tmp = np.asarray(train_tmp, dtype=np.int64)
        test_idx = np.asarray(test_idx, dtype=np.int64)
        test_idx.sort()
        stratify_tmp = np.asarray(labels)[train_tmp] if labels is not None else None
    else:
        train_tmp = all_idx
        test_idx = np.empty(0, dtype=np.int64)
        stratify_tmp = stratify

    if val_fraction > 0.0 and train_tmp.size > 1:
        rel = min(0.5, val_fraction / float(max(1e-12, 1.0 - test_fraction)))
        train_idx, val_idx = train_test_split(
            train_tmp,
            test_size=rel,
            stratify=stratify_tmp,
            random_state=seed + 1,
            shuffle=True,
        )
        train_idx = np.asarray(train_idx, dtype=np.int64)
        val_idx = np.asarray(val_idx, dtype=np.int64)
        train_idx.sort()
        val_idx.sort()
    else:
        train_idx = np.asarray(train_tmp, dtype=np.int64)
        train_idx.sort()
        val_idx = np.empty(0, dtype=np.int64)

    return train_idx, val_idx, test_idx


def _cache_dataset_to_device(
    dataset: Dataset,
    batch_size: int,
    device: torch.device,
    num_workers: int,
) -> torch.utils.data.TensorDataset:
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=False,
    )
    xs, ys = [], []
    for xb, yb in loader:
        xs.append(xb.to(device, non_blocking=True))
        ys.append(yb.to(device, non_blocking=True))
    return torch.utils.data.TensorDataset(torch.cat(xs, dim=0), torch.cat(ys, dim=0))



def _poison_detection_report_v2(
    scores: np.ndarray,
    poison_mask: np.ndarray,
    topk_list: List[int],
    *,
    candidate_only: bool = False,
    nonfinite_margin: float = 1.0,
) -> Dict[str, Any]:
    """Module-local alias for ``poison_detection_report_v2``.

    Every argument is forwarded unchanged, by keyword, and the callee's result
    is returned as-is. The meaning of the arguments and the layout of the
    returned dict are defined by ``evals.poison_detection`` and are not
    visible from this file.
    """
    return poison_detection_report_v2(
        scores=scores,
        poison_mask=poison_mask,
        topk_list=topk_list,
        candidate_only=candidate_only,
        nonfinite_margin=nonfinite_margin,
    )



def _method_entries(cfg: Dict[str, Any], lam: float, seed: int) -> List[ClusterMethodSpec]:
    """Expand ``cfg["cluster_methods"]`` into fully resolved method specs.

    Shared defaults are read from ``cfg["ifc"]`` (a few options also accept a
    legacy alias key there). Each entry of ``cfg["cluster_methods"]`` is either
    a method name or a dict of per-method overrides layered on top of those
    defaults. ``lam`` becomes the damping of every generated config, and an
    entry without its own ``seed`` gets ``seed`` plus its position in the list.
    """
    ifc_defaults = cfg.get("ifc", {})

    def _default(key: str, fallback: Any, alias: Optional[str] = None) -> Any:
        # The primary key wins; the alias (when given) is consulted before the literal fallback.
        if alias is not None:
            fallback = ifc_defaults.get(alias, fallback)
        return ifc_defaults.get(key, fallback)

    shared = {
        "jl_dim": _default("jl_dim", 256),
        "clusters": _default("clusters", 64),
        "umap_dim": _default("umap_dim", 25, alias="umap_n_components"),
        "use_umap": _default("use_umap", True),
        "max_cg_iters": _default("max_cg_iters", 512, alias="cg_max_iters"),
        "tol": _default("tol", 1e-6, alias="cg_tol"),
        "fisher": _default("fisher", True),
        "normalize": _default("normalize", True),
        "collect_diagnostics": _default("collect_diagnostics", False),
    }

    def _build(position: int, raw: Any) -> ClusterMethodSpec:
        # A bare string is shorthand for an entry whose name and method coincide.
        overrides = {"name": raw, "method": raw} if isinstance(raw, str) else raw
        label = overrides.get("name") or f"method_{position}"
        algorithm = overrides.get("method", label)
        # Every key of ``shared`` is present in ``options``; overrides only replace or add.
        options = dict(shared)
        options.update(overrides)
        projection_dim = int(options["jl_dim"])
        n_clusters = int(max(1, options["clusters"]))
        ifc_cfg = IFCompressedCfg(
            name=f"paper_ifc_{label}",
            jl_dim=projection_dim,
            clusters=n_clusters,
            cluster_method=str(algorithm),
            damping=float(lam),
            max_cg_iters=int(options["max_cg_iters"]),
            tol=float(options["tol"]),
            fisher=bool(options["fisher"]),
            normalize=bool(options["normalize"]),
            recourse_steps=0,
            cache_grads=True,
            seed=int(options.get("seed", seed + position)),
            use_umap=bool(options["use_umap"]),
            umap_n_components=int(options["umap_dim"]),
            collect_diagnostics=bool(options["collect_diagnostics"]),
        )
        return ClusterMethodSpec(
            name=label,
            method=algorithm,
            cfg=ifc_cfg,
            description=overrides.get("description", ""),
        )

    configured = cfg.get("cluster_methods", ["kmeans", "random", "kmedian"])
    return [_build(position, raw) for position, raw in enumerate(configured)]


def _summaries_from_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Collapse per-fold result rows into one summary record per setting.

    Rows are grouped by ``(scenario, cluster_method, lambda_index, K, clusters,
    jl_dim)``. Identifying fields of a group are taken from its first row, and
    each metric is the mean of its finite values across the group (``nan``
    when no finite value exists). Groups are returned in first-seen order.

    Raises:
        KeyError: if a row lacks one of the mandatory columns (``scenario``,
            ``cluster_method``, ``lambda_index``, ``K``, ``lambda``,
            ``d_loss_if``, ``d_loss_ifc``, ``d_loss_cluster_only``).
    """
    missing = float("nan")
    # (accumulator name, mandatory row column)
    mandatory = (
        ("loss_delta_if_mean", "d_loss_if"),
        ("loss_delta_ifc_mean", "d_loss_ifc"),
        ("loss_delta_cluster_mean", "d_loss_cluster_only"),
    )
    # (row column == accumulator name, value used when the row lacks it)
    optional = (
        ("clustering_error_norm", missing),
        ("solver_error_norm", missing),
        ("total_error_norm", missing),
        ("loss_clustering_error", missing),
        ("loss_solver_error", missing),
        ("loss_total_error", missing),
        ("success_if", 0),
        ("success_ifc", 0),
    )
    # (output column, accumulator it is averaged from), in output order
    reported = (
        ("mean_d_loss_if", "loss_delta_if_mean"),
        ("mean_d_loss_ifc", "loss_delta_ifc_mean"),
        ("mean_d_loss_cluster", "loss_delta_cluster_mean"),
        ("clustering_error_norm", "clustering_error_norm"),
        ("solver_error_norm", "solver_error_norm"),
        ("total_error_norm", "total_error_norm"),
        ("mean_loss_clustering_error", "loss_clustering_error"),
        ("mean_loss_solver_error", "loss_solver_error"),
        ("mean_loss_total_error", "loss_total_error"),
        ("success_rate_if", "success_if"),
        ("success_rate_ifc", "success_ifc"),
    )

    def _finite_mean(samples: List[float]) -> float:
        data = np.asarray(samples, dtype=float)
        finite = data[np.isfinite(data)]
        return float(finite.mean()) if finite.size != 0 else float("nan")

    def _fresh_group(row: Dict[str, Any]) -> Dict[str, Any]:
        group: Dict[str, Any] = {
            "scenario": row["scenario"],
            "cluster_method": row["cluster_method"],
            "K": row["K"],
            "lambda": row["lambda"],
            "clusters": row.get("clusters", 0),
            "jl_dim": row.get("jl_dim", 0),
        }
        for accumulator, _ in mandatory:
            group[accumulator] = []
        for column, _ in optional:
            group[column] = []
        return group

    grouped: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
    for row in rows:
        group_key = (
            row["scenario"],
            row["cluster_method"],
            int(row["lambda_index"]),
            int(row["K"]),
            int(row.get("clusters", 0)),
            int(row.get("jl_dim", 0)),
        )
        # The candidate group is built for every row (not only the first of a
        # group), so the identifying columns are validated on each row.
        group = grouped.setdefault(group_key, _fresh_group(row))
        for accumulator, column in mandatory:
            group[accumulator].append(row[column])
        for column, fallback in optional:
            group[column].append(row.get(column, fallback))

    def _pack(group: Dict[str, Any]) -> Dict[str, Any]:
        record: Dict[str, Any] = {
            "scenario": group["scenario"],
            "cluster_method": group["cluster_method"],
            "lambda": group["lambda"],
            "K": group["K"],
            "clusters": int(group.get("clusters", 0)),
            "jl_dim": int(group.get("jl_dim", 0)),
        }
        for out_column, accumulator in reported:
            record[out_column] = _finite_mean(group[accumulator])
        return record

    return [_pack(group) for group in grouped.values()]


def _wo_dump_json(target: Path, payload: Any) -> None:
    """Write ``payload`` to ``target`` as JSON indented by two spaces."""
    with open(target, "w") as f:
        json.dump(payload, f, indent=2)


def _wo_flatten_record(record: Any, expand_precision_at_k: bool = False) -> Optional[Dict[str, Any]]:
    """Turn one record into a flat, CSV-safe dict (``None`` if it does not serialize to a dict)."""
    serialized = _serialize(record)
    if not isinstance(serialized, dict):
        return None
    flat: Dict[str, Any] = {str(k): v for k, v in serialized.items()}
    if expand_precision_at_k:
        # A dict-valued ``precision_at_k`` becomes one column per k; any other
        # value under that key is dropped from the table.
        patk = flat.pop("precision_at_k", None)
        if isinstance(patk, dict):
            for kk, vv in patk.items():
                flat[f"precision_at_k_{kk}"] = vv
    # Nested containers are stored as JSON text so each one fits in a single cell.
    for k, v in list(flat.items()):
        if isinstance(v, (dict, list, tuple)):
            flat[k] = json.dumps(v)
    return flat


def _wo_write_flat_csv(
    target: Path,
    records: List[Any],
    *,
    expand_precision_at_k: bool = False,
    where: Optional[Tuple[str, str]] = None,
) -> None:
    """Best-effort export of ``records`` as a flat CSV with alphabetically sorted columns.

    When ``where=(field, wanted)`` is given, only dict records with
    ``str(record.get(field, "")) == wanted`` are exported and no file is
    written if none match. Any failure is logged as a warning, never raised.
    """
    import csv

    try:
        if where is not None:
            field, wanted = where
            records = [r for r in records if isinstance(r, dict) and str(r.get(field, "")) == wanted]
            if not records:
                return
        candidates = (_wo_flatten_record(r, expand_precision_at_k) for r in records)
        flat = [row for row in candidates if row is not None]
        fieldnames = sorted({k for row in flat for k in row})
        with open(target, "w", newline="") as fcsv:
            writer = csv.DictWriter(fcsv, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(flat)
    except Exception:
        # CSV tables are convenience copies of the JSON outputs: report the
        # problem instead of hiding it, but never abort the run over them.
        logging.getLogger(__name__).warning("Could not write %s", target, exc_info=True)


def _write_outputs(output_dir: Path, rows: List[Dict[str, Any]], summaries: List[Dict[str, Any]], metadata: Dict[str, Any]) -> None:
    """Persist all experiment outputs under ``output_dir`` (created if missing).

    Always written: ``per_fold.json``, ``settings_summary.json`` and
    ``metadata_checkpoint.json``. Each of the metadata lists ``group_scores``,
    ``poisoning``, ``baseline_score_diagnostics``, ``poisoning_sanity`` and
    ``retraining`` is written to its own JSON file when it is a non-empty list;
    all but ``group_scores`` also get CSV companions. ``settings_summary.csv``
    is written when ``summaries`` is non-empty.

    JSON writes and ``settings_summary.csv`` propagate errors. The remaining
    CSV tables are best-effort: a failure is logged as a warning and skipped.

    Args:
        output_dir: Destination directory.
        rows: Per-fold result rows.
        summaries: Aggregated per-setting summaries.
        metadata: Run metadata; the optional sections above are read from it
            when it is a dict.
    """
    import csv

    def _section(key: str) -> List[Any]:
        # Anything other than a list under ``key`` is treated as "section absent".
        value = metadata.get(key, []) if isinstance(metadata, dict) else []
        return value if isinstance(value, list) else []

    output_dir.mkdir(parents=True, exist_ok=True)
    _wo_dump_json(output_dir / "per_fold.json", [_serialize(r) for r in rows])
    _wo_dump_json(output_dir / "settings_summary.json", [_serialize(s) for s in summaries])
    _wo_dump_json(output_dir / "metadata_checkpoint.json", _serialize(metadata))

    # Group-query benchmark outputs
    group_scores = _section("group_scores")
    if group_scores:
        _wo_dump_json(output_dir / "group_scores.json", [_serialize(r) for r in group_scores])

    # Optional: poison detection outputs (explicit logs beyond the metadata checkpoint).
    poisoning = _section("poisoning")
    if poisoning:
        _wo_dump_json(output_dir / "poisoning.json", [_serialize(r) for r in poisoning])
        _wo_write_flat_csv(output_dir / "poisoning.csv", poisoning, expand_precision_at_k=True)
        # Convenience: group-only metrics table
        _wo_write_flat_csv(
            output_dir / "group_poisoning.csv",
            poisoning,
            expand_precision_at_k=True,
            where=("level", "group"),
        )

    for key in ("baseline_score_diagnostics", "poisoning_sanity"):
        records = _section(key)
        if records:
            _wo_dump_json(output_dir / f"{key}.json", [_serialize(r) for r in records])
            _wo_write_flat_csv(output_dir / f"{key}.csv", records)

    # Optional retraining benchmark outputs
    retr = _section("retraining")
    if retr:
        _wo_dump_json(output_dir / "retraining_benchmark.json", [_serialize(r) for r in retr])
        # Small fixed-column summary CSV for easy ranking.
        summary_keys = [
            "scenario",
            "scenario_index",
            "lambda",
            "lambda_index",
            "method",
            "metric",
            "auc_improvement",
            "improvement_at_target",
            "runtime_sec",
        ]
        summary_path = output_dir / "retraining_benchmark_summary.csv"
        try:
            with open(summary_path, "w", newline="") as fcsv:
                writer = csv.DictWriter(fcsv, fieldnames=summary_keys)
                writer.writeheader()
                for r in retr:
                    writer.writerow({k: r.get(k, "") for k in summary_keys})
        except Exception:
            logging.getLogger(__name__).warning("Could not write %s", summary_path, exc_info=True)

        # Convenience: group-only repair table
        _wo_write_flat_csv(output_dir / "group_repair.csv", retr, where=("removal_unit", "group"))

    if summaries:
        fieldnames = sorted({k for s in summaries for k in s.keys()})
        with open(output_dir / "settings_summary.csv", "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(summaries)


def run_poison_experiments(
    cfg: Dict[str, Any],
    logger: Optional[logging.Logger] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, Any]:
    """Orchestrate paper-ready experiments (error decomposition, clustering checks)."""

    outdir = cfg.get("outdir", "runs")
    name = cfg.get("name", "paper_run")
    path = Path(outdir) / name
    path.mkdir(parents=True, exist_ok=True)
    log = logger or setup_logging(cfg.get("log_level", "DEBUG"), log_dir=str(path))
    log.info("Paper experiment start")
    log.info("Config: %s", cfg)

    seed = int(cfg.get("seed", 0))
    set_seed(seed)

    base_model: nn.Module = cfg["model"]
    dataset = cfg["dataset"]
    loss_fn: nn.Module = cfg["loss_fn"]
    if device is not None:
        base_model = base_model.to(device)
    labels_arr = _extract_labels(dataset, cfg.get("labels"))
    base_labels_np = np.asarray(labels_arr) if labels_arr is not None else None

    scenarios = cfg.get("scenarios", cfg.get("poisoning_scenarios", [{"name": "clean"}]))
    if not scenarios:
        scenarios = [{"name": "clean"}]

    damping_grid = list(cfg.get("damping_grid", cfg.get("damping", [1e-3])))
    if isinstance(damping_grid, float):
        damping_grid = [float(damping_grid)]
    Ks = list(cfg.get("Ks", [10]))
    topK = list(cfg.get("TopK", [1, 5, 10, 50, 100, 200, 500, 1000]))

    debug_cfg = cfg.get("debugging", {})
    baseline_cfg = cfg.get("baselines", {})
    report_topk = list(debug_cfg.get("report_topk", topK))
    val_fraction = float(debug_cfg.get("val_fraction", cfg.get("val_fraction", 0.0)))

    group_task_cfg = dict(cfg.get("group_task", {}))
    group_task_enabled = bool(group_task_cfg.get("enabled", False))

    trak_cfg = dict(cfg.get("trak", {}))
    trak_enabled = bool(trak_cfg.get("enable_trak", False))

    train_cfg = cfg.get("train", {"epochs": 1, "lr": 1e-2, "batch_size": 128})
    batch_size = int(cfg.get("batch_size", train_cfg.get("batch_size", 128)))
    max_grad_norm = train_cfg.get("max_grad_norm", train_cfg.get("grad_clip", None))
    if max_grad_norm is not None:
        try:
            max_grad_norm = float(max_grad_norm)
        except Exception:
            max_grad_norm = None
        if max_grad_norm is not None and max_grad_norm <= 0:
            max_grad_norm = None

    all_rows: List[Dict[str, Any]] = []
    scenario_reports: List[Dict[str, Any]] = []
    correlation_reports: List[Dict[str, Any]] = []
    poisoning_reports: List[Dict[str, Any]] = []
    baseline_score_diagnostics: List[Dict[str, Any]] = []  # kept for backward-compat; unused in group task.
    poisoning_sanity: List[Dict[str, Any]] = []
    retraining_reports: List[Dict[str, Any]] = []
    group_scores_records: List[Dict[str, Any]] = []
    t_start = time.time()

    for scenario_idx, scenario_cfg in enumerate(scenarios):
        scen_name = scenario_cfg.get("name", f"scenario_{scenario_idx}")
        log.info("Preparing scenario %s (%d/%d)", scen_name, scenario_idx + 1, len(scenarios))
        scen_rng = np.random.default_rng(seed + scenario_idx)

        backdoor_cfg = dict(scenario_cfg.get("backdoor", {})) if isinstance(scenario_cfg.get("backdoor", {}), dict) else {}
        backdoor_enabled = bool(backdoor_cfg.get("enabled", False))
        backdoor_mode = str(backdoor_cfg.get("mode", "fixedV")).lower() if backdoor_enabled else ""

        scenario_has_mutation = bool(scenario_cfg.get("poison")) or bool(scenario_cfg.get("relabel_map"))
        if group_task_enabled:
            scenario_has_mutation = True
        eff_val_fraction = float(val_fraction)
        if scenario_has_mutation and eff_val_fraction <= 0.0:
            eff_val_fraction = 0.1

        # IMPORTANT: split first on clean labels, then apply label mutations ONLY to train.
        test_fraction = float(cfg.get("test_fraction", 0.1))
        if test_fraction > 0.0:
            train_idx, val_idx, test_idx = _split_train_val_test_indices(
                len(dataset),
                base_labels_np,
                eff_val_fraction,
                test_fraction,
                seed ,
            )
        else:
            train_idx, val_idx = _split_train_val_indices(
                len(dataset),
                base_labels_np,
                eff_val_fraction,
                seed + scenario_idx,
            )
            test_idx = np.empty(0, dtype=np.int64)

        # --- Build groups and apply train-only poisoning ---
        poison_meta: Dict[str, Any] = {}
        groups_orig: List[np.ndarray] = []
        groups_train_local: List[np.ndarray] = []
        group_id_local = np.empty(0, dtype=np.int64)

        if group_task_enabled:
            G = int(group_task_cfg.get("G", 500))
            grouping = str(group_task_cfg.get("grouping", "random"))
            groups_orig, groups_train_local, group_id_local = make_groups(
                train_idx=train_idx,
                labels_clean=base_labels_np,
                G=G,
                grouping=grouping,
                rng=scen_rng,
            )
            poison_groups_cfg = dict(group_task_cfg.get("poison_groups", {}))
            poisoned_group_ids = select_poison_groups(groups_orig, base_labels_np, poison_groups_cfg, scen_rng)
            if backdoor_enabled:
                # Backdoor uses the same group selection policy as poison_groups.
                bd_cfg_eff = dict(backdoor_cfg)
                bd_cfg_eff.setdefault("within_group_fraction", float(poison_groups_cfg.get("within_group_fraction", 1.0)))
                ds_train_variant, labels_variant, poison_meta = _apply_group_backdoor_train_only(
                    dataset=dataset,
                    base_labels=base_labels_np,
                    train_idx=train_idx,
                    groups_orig=groups_orig,
                    poisoned_group_ids=poisoned_group_ids,
                    backdoor_cfg=bd_cfg_eff,
                    rng=scen_rng,
                )
            else:
                ds_train_variant, labels_variant, poison_meta = _apply_group_label_mutations_train_only(
                    dataset=dataset,
                    base_labels=base_labels_np,
                    train_idx=train_idx,
                    groups_orig=groups_orig,
                    poisoned_group_ids=poisoned_group_ids,
                    poison_cfg=poison_groups_cfg,
                    rng=scen_rng,
                )
            poison_meta.update({
                "level": "group",
                "G": int(len(groups_orig)),
                "grouping": grouping,
                "poison_fraction_groups": float(len(poisoned_group_ids) / max(1, len(groups_orig))),
            })
        else:
            # Legacy scenario poison (point-level) still supported, but the new benchmark expects group_task.enabled.
            ds_train_variant, labels_variant, poison_meta = _apply_label_mutations_train_only(
                dataset=dataset,
                base_labels=base_labels_np,
                scenario_cfg=scenario_cfg,
                rng=scen_rng,
                train_idx=train_idx,
            )
        log.info(
            "Scenario %s prepared (poison frac total %.3f, train %.3f)",
            scen_name,
            poison_meta.get("poisoned_fraction", 0.0),
            poison_meta.get("poisoned_fraction_train", 0.0),
        )

        train_labels_variant = labels_variant[train_idx] if labels_variant is not None else None
        train_dataset: Dataset = Subset(ds_train_variant, train_idx.tolist())


        # Keep validation/test CLEAN (original dataset).
        # Also create an explicit clean view of the training split for group-query evaluation sets V_g.
        dataset_clean: Dataset = dataset
        if base_labels_np is not None:
            dataset_clean = LabelOverrideDataset(dataset, base_labels_np)
        train_clean_dataset: Dataset = Subset(dataset_clean, train_idx.tolist())
        val_dataset: Optional[Dataset] = Subset(dataset_clean, val_idx.tolist()) if val_idx.size else None
        test_dataset: Optional[Dataset] = Subset(dataset_clean, test_idx.tolist()) if test_idx.size else None
        log.info(
            "Scenario %s: split train/val/test (train=%d, val=%d, test=%d, val_fraction=%.3f, test_fraction=%.3f)",
            scen_name,
            len(train_dataset),
            len(val_dataset) if val_dataset is not None else 0,
            len(test_dataset) if test_dataset is not None else 0,
            eff_val_fraction,
            test_fraction,
        )
        # NOTE: TracIn/FastIF/checkpoint machinery intentionally removed.

        # Keep a reproducible init snapshot for retraining interventions.
        init_model = copy.deepcopy(base_model)
        init_state = copy.deepcopy(init_model.state_dict())

        model = copy.deepcopy(base_model)
        if device is not None:
            model = model.to(device)
        log.info(
            "Scenario %s: training model (%d epochs, lr=%g, batch_size=%d)",
            scen_name,
            int(train_cfg.get("epochs", 1)),
            float(train_cfg.get("lr", 1e-2)),
            int(train_cfg.get("batch_size", batch_size)),
        )

        # --- Checkpoint policy (used by TRAK) ---
        checkpoint_dir = None
        checkpoint_epochs = None
        checkpoint_fractions = None
        save_initial_ckpt = False
        if trak_enabled:
            # Ensure at least ~3 checkpoints when possible (initial + 2+ during training).
            save_initial_ckpt = bool(trak_cfg.get("save_initial", True))
            ckpt_policy = str(trak_cfg.get("checkpoint_policy", "fractions")).lower()
            ckpt_dir = trak_cfg.get("save_dir", None)
            if ckpt_dir is None:
                ckpt_dir = str(Path(path) / scen_name / "trak")
            ckpt_dir = str(Path(ckpt_dir) / "checkpoints")
            checkpoint_dir = ckpt_dir

            # Default fractions: 25/50/75/100%.
            if ckpt_policy in {"fractions", "epochs"}:
                if ckpt_policy == "epochs":
                    checkpoint_epochs = trak_cfg.get("checkpoint_epochs", None)
                else:
                    checkpoint_fractions = trak_cfg.get("checkpoint_fractions", [0.25, 0.5, 0.75, 1.0])
            # "existing" means the user provided checkpoint_paths explicitly; if missing, still create some.
            if ckpt_policy == "existing" and not trak_cfg.get("checkpoint_paths"):
                checkpoint_fractions = trak_cfg.get("checkpoint_fractions", [0.25, 0.5, 0.75, 1.0])
        t_train0 = time.time()
        train_stats = train_model(
            model,
            train_dataset,
            loss_fn,
            epochs=int(train_cfg.get("epochs", 1)),
            lr=float(train_cfg.get("lr", 1e-2)),
            weight_decay=float(train_cfg.get("weight_decay", 0.0)),
            batch_size=int(train_cfg.get("batch_size", batch_size)),
            num_workers=int(train_cfg.get("num_workers", 0)),
            max_grad_norm=max_grad_norm,
            stop_on_nonfinite=True,
            logger=log,
            checkpoint_dir=checkpoint_dir,
            checkpoint_epochs=checkpoint_epochs,
            checkpoint_fractions=checkpoint_fractions,
            save_initial=save_initial_ckpt,
        )
        checkpoint_paths = list(train_stats.get("checkpoint_paths", [])) if isinstance(train_stats, dict) else []
        time_train = time.time() - t_train0
        log.info("Scenario %s: training finished in %.2fs", scen_name, time_train)
        model.eval()

        env_meta = {
            "scenario": scen_name,
            "seed": seed,
            "train_time": time_train,
            "dataset_stats": _compute_dataset_stats(train_dataset, train_labels_variant),
            "train_size": len(train_dataset),
            "val_size": len(val_dataset) if val_dataset is not None else 0,
            "test_size": len(test_dataset) if test_dataset is not None else 0,
            "poisoning": poison_meta,
            "Ks": Ks,
            "damping_grid": damping_grid,
            "group_task": group_task_cfg if group_task_enabled else {"enabled": False},
            "trak": trak_cfg if trak_enabled else {"enabled": False},
            "backdoor": {"enabled": bool(backdoor_enabled), **backdoor_cfg} if backdoor_enabled else {"enabled": False},
        }

        target_device = next(model.parameters()).device
        cached_train_dataset = _cache_dataset_to_device(
            train_dataset,
            batch_size=batch_size,
            device=target_device,
            num_workers=int(train_cfg.get("num_workers", 0)),
        )

        cached_train_clean_dataset = _cache_dataset_to_device(
            train_clean_dataset,
            batch_size=batch_size,
            device=target_device,
            num_workers=int(train_cfg.get("num_workers", 0)),
        )
        cached_val_dataset = None
        if val_dataset is not None and len(val_dataset) > 0:
            cached_val_dataset = _cache_dataset_to_device(
                val_dataset,
                batch_size=batch_size,
                device=target_device,
                num_workers=int(train_cfg.get("num_workers", 0)),
            )

        cached_test_dataset = None
        if test_dataset is not None and len(test_dataset) > 0:
            cached_test_dataset = _cache_dataset_to_device(
                test_dataset,
                batch_size=batch_size,
                device=target_device,
                num_workers=int(train_cfg.get("num_workers", 0)),
            )

        # Backdoor target datasets (triggered inputs, constant target label).
        cached_val_bd_target_dataset = None
        cached_test_bd_target_dataset = None
        if backdoor_enabled:
            target_label = backdoor_cfg.get("target_label", None)
            trigger_spec = backdoor_cfg.get("trigger", None)
            if target_label is None or trigger_spec is None:
                raise ValueError("backdoor.enabled=true requires backdoor.target_label and backdoor.trigger")
            if val_dataset is not None and len(val_dataset) > 0:
                val_bd = ConstantLabelDataset(TriggerAllDataset(val_dataset, trigger_spec), int(target_label))
                cached_val_bd_target_dataset = _cache_dataset_to_device(
                    val_bd,
                    batch_size=batch_size,
                    device=target_device,
                    num_workers=int(train_cfg.get("num_workers", 0)),
                )
            if test_dataset is not None and len(test_dataset) > 0:
                test_bd = ConstantLabelDataset(TriggerAllDataset(test_dataset, trigger_spec), int(target_label))
                cached_test_bd_target_dataset = _cache_dataset_to_device(
                    test_bd,
                    batch_size=batch_size,
                    device=target_device,
                    num_workers=int(train_cfg.get("num_workers", 0)),
                )

        train_loader = DataLoader(cached_train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        val_loader = (
            DataLoader(cached_val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
            if cached_val_dataset is not None
            else train_loader
        )
        val_loader_bd_target = (
            DataLoader(cached_val_bd_target_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
            if cached_val_bd_target_dataset is not None
            else None
        )
        test_loader_bd_target = (
            DataLoader(cached_test_bd_target_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
            if cached_test_bd_target_dataset is not None
            else None
        )
        n_train = len(train_dataset)

        diagnostics_res = {}

        # Build group index helpers using local indices into the cached training datasets.
        G_eff = int(len(groups_train_local))
        report_topk_groups = list(group_task_cfg.get("report_topk_groups", [1, 5, 10, 20, 50]))
        eval_mode = str(group_task_cfg.get("eval_mode", "self")).lower()
        neighbor_k = int(group_task_cfg.get("neighbor_k", 0))

        # Backdoor-specific eval_mode overrides.
        if backdoor_enabled and backdoor_mode in {"fixedv", "fixed_v", "fixed"}:
            eval_mode = "val_slice"
        if backdoor_enabled and backdoor_mode in {"multiv", "multi_v", "multi"}:
            eval_mode = "val_groups"

        # Multi-V: build a held-out partition of the (clean) val split, then apply trigger+target label.
        val_groups_local: List[np.ndarray] = []
        val_neighbor_k = int(backdoor_cfg.get("V_neighbor_k", 0)) if backdoor_enabled else 0
        if eval_mode == "val_groups":
            if val_dataset is None or len(val_dataset) == 0:
                raise RuntimeError("eval_mode=val_groups requires a non-empty validation split")
            if cached_val_bd_target_dataset is None:
                raise RuntimeError("eval_mode=val_groups requires cached_val_bd_target_dataset")
            v_partition = str(backdoor_cfg.get("V_partition", "contiguous"))
            val_groups_local = make_index_groups(len(val_dataset), G_eff, v_partition, scen_rng)

        # Baseline policy: IFOneSolve/TRAK only for fixedV (or non-backdoor).
        enable_if_one_solve_effective = bool(baseline_cfg.get("enable_if_one_solve", False))
        trak_enabled_effective = bool(trak_enabled)
        if backdoor_enabled and eval_mode != "val_slice":
            enable_if_one_solve_effective = False
        # NOTE: do NOT disable TRAK for multiV; TRAK must be conditioned on V_g.

        # Store backdoor metrics once per scenario (used by repair plots).
        if backdoor_enabled:
            bd_split = str(cfg.get("retraining_benchmark", {}).get("backdoor_eval_split", "test")).lower()
            bd_loader = test_loader_bd_target if bd_split == "test" else val_loader_bd_target
            clean_loader = (
                DataLoader(cached_test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                if (bd_split == "test" and cached_test_dataset is not None)
                else val_loader
            )
            bd_metrics = {}
            if bd_loader is not None:
                bd_metrics.update({
                    f"base_asr_{bd_split}": float(_eval_loss_acc_over_loader(model, bd_loader, loss_fn, target_device).get("accuracy", float("nan")))
                })
            if clean_loader is not None:
                bd_metrics.update({
                    f"base_clean_acc_{bd_split}": float(_eval_loss_acc_over_loader(model, clean_loader, loss_fn, target_device).get("accuracy", float("nan")))
                })
            env_meta["backdoor_metrics"] = bd_metrics

        # --- TRAK: fixedV scores once; multiV scores conditioned on V_g ---
        trak_group_scores: Optional[np.ndarray] = None
        trak_prep_sec: Optional[float] = None
        trak_targets_total_sec: float = 0.0
        if trak_enabled_effective:
            try:
                t_trak0 = time.time()
                if len(checkpoint_paths) < 3:
                    log.warning(
                        "Scenario %s: TRAK enabled but only %d checkpoints available; consider epochs>=2 and save_initial=True.",
                        scen_name,
                        len(checkpoint_paths),
                    )
                trak_save_dir = trak_cfg.get("save_dir", None)
                if trak_save_dir is None:
                    trak_save_dir = str(Path(path) / scen_name / "trak")
                trak_cfg_effective = dict(trak_cfg)
                trak_cfg_effective["save_dir"] = trak_save_dir

                # Pass checkpoints through cfg to keep TRAKBaseline.prepare() signature consistent.
                trak_cfg_effective["checkpoint_paths"] = list(trak_cfg.get("checkpoint_paths", checkpoint_paths))
                trak_cfg_effective["exp_name"] = f"{scen_name}_val_bd" if (backdoor_enabled and eval_mode == "val_slice") else f"{scen_name}_val"

                trak = TRAKBaseline(checkpoint_paths=list(trak_cfg_effective["checkpoint_paths"]))

                # For prepare(), pass any target loader; for multiV we will override.
                prepare_loader = (
                    val_loader_bd_target
                    if (backdoor_enabled and cached_val_bd_target_dataset is not None)
                    else val_loader
                )
                trak.prepare(
                    model=model,
                    train_loader=train_loader,
                    val_loader=prepare_loader,
                    loss_fn=loss_fn,
                    cg_mod=None,
                    caches={},
                    device=target_device,
                    cfg={"trak": trak_cfg_effective},
                    folds=None,
                )

                grp = np.full((G_eff,), np.nan, dtype=float)

                if eval_mode == "val_slice":
                    fixed_loader = (
                        val_loader_bd_target
                        if (backdoor_enabled and val_loader_bd_target is not None)
                        else val_loader
                    )
                    t1 = time.time()
                    point_scores = trak.score_points(val_loader=fixed_loader)
                    trak_targets_total_sec += float(time.time() - t1)

                    for g in range(G_eff):
                        idxs = np.asarray(groups_train_local[g], dtype=np.int64)
                        if idxs.size == 0:
                            continue
                        vals = np.asarray(point_scores, dtype=float)[idxs]
                        m = np.isfinite(vals)
                        if np.any(m):
                            grp[g] = float(np.sum(vals[m]))

                elif eval_mode == "val_groups":
                    if cached_val_bd_target_dataset is None:
                        raise RuntimeError("TRAK multiV requires cached_val_bd_target_dataset")

                    def _val_group_indices(gg: int) -> np.ndarray:
                        """Return the validation indices that make up the target set for group ``gg``.

                        With no validation groups an empty int64 array is returned.
                        When ``val_neighbor_k <= 0`` the group's own indices are
                        returned as stored. Otherwise the indices of groups
                        ``gg - k .. gg + k`` (clamped to ``[0, G_eff - 1]``) are
                        pooled and returned sorted and de-duplicated.
                        """
                        if not val_groups_local:
                            return np.empty(0, dtype=np.int64)
                        if val_neighbor_k <= 0:
                            return np.asarray(val_groups_local[gg], dtype=np.int64)

                        center = int(gg)
                        first = max(0, center - val_neighbor_k)
                        last = min(G_eff - 1, center + val_neighbor_k)
                        pooled: List[int] = []
                        for neighbor in range(first, last + 1):
                            members = np.asarray(val_groups_local[neighbor], dtype=np.int64)
                            pooled += [int(v) for v in members.tolist()]
                        # np.unique already yields its result in ascending order.
                        return np.unique(np.asarray(pooled, dtype=np.int64))

                    for g in range(G_eff):
                        V_local = _val_group_indices(g)
                        if V_local.size == 0:
                            continue
                        loader_Vg = DataLoader(
                            Subset(cached_val_bd_target_dataset, V_local.tolist()),
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                        )
                        t1 = time.time()
                        point_scores_g = trak.score_points(val_loader=loader_Vg)
                        trak_targets_total_sec += float(time.time() - t1)

                        S_local = np.asarray(groups_train_local[g], dtype=np.int64)
                        if S_local.size == 0:
                            continue
                        vals = np.asarray(point_scores_g, dtype=float)[S_local]
                        m = np.isfinite(vals)
                        if np.any(m):
                            grp[g] = float(np.sum(vals[m]))

                else:
                    # Non-backdoor usage: keep the fixed target from prepare().
                    t1 = time.time()
                    point_scores = trak.score_points()
                    trak_targets_total_sec += float(time.time() - t1)
                    for g in range(G_eff):
                        idxs = np.asarray(groups_train_local[g], dtype=np.int64)
                        if idxs.size == 0:
                            continue
                        vals = np.asarray(point_scores, dtype=float)[idxs]
                        m = np.isfinite(vals)
                        if np.any(m):
                            grp[g] = float(np.sum(vals[m]))

                trak_group_scores = grp
                trak_prep_sec = float(time.time() - t_trak0)
                log.info("Scenario %s: TRAK prepared (proj_dim=%s, checkpoints=%d)", scen_name, trak_cfg_effective.get("proj_dim", 2048), len(checkpoint_paths))
            except Exception as exc:
                log.warning("Scenario %s: TRAK failed (%s)", scen_name, exc)
                trak_group_scores = None
                trak_prep_sec = None
                trak_targets_total_sec = float("nan")

        def _eval_local_indices_for_group(g: int) -> np.ndarray:
            """Return the local indices of the evaluation set paired with group ``g``.

            Behaviour per ``eval_mode``:
              * ``"self"``: the training group's own indices, as stored.
              * ``"neighbor"``: union of training groups
                ``g - neighbor_k .. g + neighbor_k`` (clamped to
                ``[0, G_eff - 1]``), sorted and de-duplicated; just the group
                itself when ``neighbor_k <= 0``.
              * ``"val_slice"``: always an empty array (no per-group index set).
              * ``"val_groups"``: the same windowing over ``val_groups_local``
                using ``val_neighbor_k``; empty when no validation groups exist.

            Raises:
                ValueError: if ``eval_mode`` is none of the modes above.
            """
            if eval_mode == "val_slice":
                # No per-group indices in this mode; callers get an empty array.
                return np.empty(0, dtype=np.int64)

            # Pick the group collection and neighbourhood radius for this mode.
            if eval_mode == "self":
                pool, radius = groups_train_local, 0
            elif eval_mode == "neighbor":
                pool, radius = groups_train_local, neighbor_k
            elif eval_mode == "val_groups":
                if not val_groups_local:
                    return np.empty(0, dtype=np.int64)
                pool, radius = val_groups_local, val_neighbor_k
            else:
                raise ValueError(f"Unknown group_task.eval_mode: {eval_mode}")

            if radius <= 0:
                return np.asarray(pool[g], dtype=np.int64)

            center = int(g)
            first = max(0, center - radius)
            last = min(G_eff - 1, center + radius)
            pooled: List[int] = []
            for neighbor in range(first, last + 1):
                members = np.asarray(pool[neighbor], dtype=np.int64)
                pooled += [int(v) for v in members.tolist()]
            # np.unique returns the distinct values in ascending order.
            return np.unique(np.asarray(pooled, dtype=np.int64))

        for lam_idx, lam in enumerate(damping_grid):
            log.info("Scenario %s: damping=%.3e", scen_name, lam)
            objective = _PaperObjective(loss_fn, weight_decay=float(train_cfg.get("weight_decay", 0.0)))
            mod = CGInfluenceModule(
                model=model,
                objective=objective,
                train_loader=train_loader,
                test_loader=val_loader,
                device=next(model.parameters()).device,
                damp=float(lam),
                gnh=cfg.get("fisher", True),
                maxiter=int(cfg.get("cg_max_iters", cfg.get("cg_iters", 512))),
                tol=float(cfg.get("cg_tol", cfg.get("tol", 1e-6))),
            )
            method_specs = _method_entries(cfg, lam, seed + scenario_idx * 17)
            caches: Dict[str, Dict[str, Any]] = {}
            cache_alignments: Dict[str, Dict[str, Any]] = {}
            cache_prep_sec_by_name: Dict[str, float] = {}
            for spec in method_specs:
                log.info("Building cache for %s/%s", scen_name, spec.name)
                t_cache0 = time.time()
                cache = build_ifc_cache(model, train_loader, loss_fn, spec.cfg, logger=log)
                cache_prep_sec_by_name[spec.name] = float(time.time() - t_cache0)
                caches[spec.name] = cache
                labels_np = np.asarray(cache.get("labels", []))
                alignment = _cluster_alignment_metrics(labels_np, train_labels_variant)
 
                cache_alignments[spec.name] = alignment
                cache_timings = cache.get("timings", {})
                cache_cluster_stats = cache.get("cluster_stats", {})
                _log_structured(log, f"cache_{scen_name}_{spec.name}", {
                    "alignments": alignment,
                    "timings": cache_timings,
                    "cluster_stats": cache_cluster_stats,
                })

            # -----------------------
            # Group-query, multi-query benchmark (Q = G)
            # -----------------------
            t_group0 = time.time()
            scores_by_method: Dict[str, np.ndarray] = {}
            timings_by_method: Dict[str, float] = {}
            prep_sec_by_method: Dict[str, float] = {}
            per_query_sec_total_by_method: Dict[str, float] = {}
            per_query_sec_mean_by_method: Dict[str, float] = {}

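            # Define methods: IFOneSolve (fixed V only), Random, GradSim, TRAK (when scores exist),
            # and one IFC entry per cache spec. NOTE: plain "IF" (per-group IHVP) is not appended
            # to this list here, so the `"IF" in method_names` branches below stay inactive.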
            method_names: List[str] = []
            # if bool(baseline_cfg.get("enable_if", False)):
            #     method_names.append("IFOneSolve")
            # One-solve is intentionally not meaningful when V_g varies; keep only when eval_mode == val_slice.
            if enable_if_one_solve_effective and eval_mode == "val_slice":
                method_names.append("IFOneSolve")
            if bool(baseline_cfg.get("enable_random", True)):
                method_names.append("Random")
            method_names.append("GradSim")
            if trak_group_scores is not None:
                method_names.append("TRAK")
                # Set TRAK per-query cost upfront when applicable.
                per_query_sec_total_by_method.setdefault(
                    "TRAK",
                    float(trak_targets_total_sec) if eval_mode == "val_groups" else 0.0,
                )
            # IFC methods correspond to cache entries (one per cluster_method).
            for spec in method_specs:
                method_names.append(f"IFC::{spec.name}")

            # Precompute fixed val gradient for IFOneSolve if requested.
            fixed_g_V: Optional[torch.Tensor] = None
            fixed_ihvp: Optional[torch.Tensor] = None

            if eval_mode == "val_slice":
                t0 = time.time()
                fixed_loader = (
                    val_loader_bd_target
                    if (backdoor_enabled and val_loader_bd_target is not None)
                    else val_loader
                )
                fixed_g_V = _mean_grad_over_loader(model, fixed_loader, loss_fn, target_device)
                prep_sec_by_method["FixedV_gV"] = float(time.time() - t0)

                # Only IFOneSolve needs IHVP(g_V)
                if "IFOneSolve" in method_names:
                    t0 = time.time()
                    fixed_ihvp = mod.inverse_hvp(fixed_g_V).get("ihvp")
                    prep_sec_by_method["IFOneSolve"] = float(time.time() - t0)

            # Compute per-group g_S and g_V then score per method.
            rng_scores = np.random.default_rng(seed + 10_000 + scenario_idx * 1000 + lam_idx)
            scores_if = np.full((G_eff,), np.nan, dtype=float)
            scores_if_one = np.full((G_eff,), np.nan, dtype=float)
            scores_gradsim = np.full((G_eff,), np.nan, dtype=float)
            scores_random = rng_scores.standard_normal(G_eff).astype(float)
            scores_ifc_by_spec: Dict[str, np.ndarray] = {spec.name: np.full((G_eff,), np.nan, dtype=float) for spec in method_specs}

            t_if = 0.0
            t_ifc = {spec.name: 0.0 for spec in method_specs}
            t_gradsim = 0.0
            t_gs = 0.0
            t_gv = 0.0

            # Count how many queries we actually scored (non-empty groups).
            q_count = 0

            for g in range(G_eff):
                S_local = np.asarray(groups_train_local[g], dtype=np.int64)
                if S_local.size == 0:
                    continue
                q_count += 1

                # g_S computed on POISONED training labels (cached_train_dataset)
                t0 = time.time()
                loader_S = DataLoader(Subset(cached_train_dataset, S_local.tolist()), batch_size=batch_size, shuffle=False, num_workers=0)
                g_S = _mean_grad_over_loader(model, loader_S, loss_fn, target_device)
                t_gs += float(time.time() - t0)

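                # g_V source depends on eval_mode: val_slice reuses the fixed gradient (fixed_g_V);
                # val_groups uses a per-group subset of cached_val_bd_target_dataset; otherwise a
                # per-group subset of cached_train_clean_dataset (CLEAN labels).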
                if eval_mode == "val_slice":
                    g_V = fixed_g_V
                elif eval_mode == "val_groups":
                    V_local = _eval_local_indices_for_group(g)
                    t0 = time.time()
                    loader_V = DataLoader(Subset(cached_val_bd_target_dataset, V_local.tolist()), batch_size=batch_size, shuffle=False, num_workers=0)
                    g_V = _mean_grad_over_loader(model, loader_V, loss_fn, target_device)
                    t_gv += float(time.time() - t0)
                else:
                    V_local = _eval_local_indices_for_group(g)
                    t0 = time.time()
                    loader_V = DataLoader(Subset(cached_train_clean_dataset, V_local.tolist()), batch_size=batch_size, shuffle=False, num_workers=0)
                    g_V = _mean_grad_over_loader(model, loader_V, loss_fn, target_device)
                    t_gv += float(time.time() - t0)

                if g_V is None:
                    continue

                denom = n_train - int(S_local.size)   # exact leave-out
                denom = max(1, denom)
                scale = float(S_local.size) / float(denom)  # because g_S is MEAN

                # GradSim baseline: scale * g_V^T mean_g_S, where scale = |S| / max(1, N - |S|)
                # (leave-out denominator) and g_S is the group MEAN gradient; no sign flip is applied.
                t0 = time.time()
                try:
                    scores_gradsim[g] = float((scale) * torch.dot(g_V.detach().view(-1), g_S.detach().view(-1)).item())
                except Exception:
                    scores_gradsim[g] = float("nan")
                t_gradsim += float(time.time() - t0)

                # IF: per-group IHVP solve
                if "IF" in method_names:
                    t0 = time.time()
                    sol = mod.inverse_hvp(g_S)
                    ihvp = sol.get("ihvp")
                    if ihvp is not None:
                        delta = scale * ihvp.detach()
                        try:
                            scores_if[g] = float(torch.dot(g_V.detach().view(-1), delta.view(-1)).item())
                        except Exception:
                            scores_if[g] = float("nan")
                    t_if += float(time.time() - t0)

                # IFOneSolve: fixed IHVP for fixed V (only meaningful when eval_mode == val_slice)
                if "IFOneSolve" in method_names and fixed_ihvp is not None:
                    # Reuse IHVP(fixed_g_V) but still need g_S per group.
                    # NOTE: this is a "fixed-V competitor" only; does NOT solve the per-query variant.
                    try:
                        delta_one = scale * fixed_ihvp.detach()
                        scores_if_one[g] = float(torch.dot(g_S.detach().view(-1), delta_one.view(-1)).item())
                    except Exception:
                        scores_if_one[g] = float("nan")

                # IFC: reuse cache; no IHVP here.
                for spec in method_specs:
                    t0 = time.time()
                    try:
                        out = delta_ifc_from_cache(caches[spec.name], S_local)
                        delta_hat = out.get("delta_theta") if isinstance(out, dict) else None
                        if delta_hat is None:
                            scores_ifc_by_spec[spec.name][g] = float("nan")
                        else:
                            scores_ifc_by_spec[spec.name][g] = float(torch.dot(g_V.detach().view(-1), delta_hat.detach().view(-1)).item())
                    except Exception:
                        scores_ifc_by_spec[spec.name][g] = float("nan")
                    t_ifc[spec.name] += float(time.time() - t0)

            # Pack methods
            if "IF" in method_names:
                scores_by_method["IF"] = scores_if
                timings_by_method["IF_inverse_hvp_sec"] = float(t_if)
                per_query_sec_total_by_method["IF"] = float(t_if)
            if "IFOneSolve" in method_names:
                scores_by_method["IFOneSolve"] = scores_if_one
            scores_by_method["GradSim"] = scores_gradsim
            timings_by_method["GradSim_sec"] = float(t_gradsim)
            per_query_sec_total_by_method["GradSim"] = float(t_gradsim)
            if "Random" in method_names:
                scores_by_method["Random"] = scores_random
            for spec in method_specs:
                scores_by_method[f"IFC::{spec.name}"] = scores_ifc_by_spec[spec.name]
                timings_by_method[f"IFC::{spec.name}_sec"] = float(t_ifc[spec.name])
                per_query_sec_total_by_method[f"IFC::{spec.name}"] = float(t_ifc[spec.name])

            # Prep times for baselines.
            for spec in method_specs:
                prep_sec_by_method[f"IFC::{spec.name}"] = float(cache_prep_sec_by_name.get(spec.name, float("nan")))
            if trak_group_scores is not None:
                prep_sec_by_method["TRAK"] = float(trak_prep_sec) if trak_prep_sec is not None else float("nan")
                # For multiV, TRAK target scoring is per-query; for fixedV treat as one-off.
                if "TRAK" not in per_query_sec_total_by_method:
                    per_query_sec_total_by_method["TRAK"] = float(trak_targets_total_sec) if eval_mode == "val_groups" else 0.0

            # Convert totals -> mean-per-query, leaving NaN when q_count==0.
            if q_count > 0:
                for mname, tot in per_query_sec_total_by_method.items():
                    per_query_sec_mean_by_method[mname] = float(tot) / float(q_count)
            else:
                for mname in per_query_sec_total_by_method.keys():
                    per_query_sec_mean_by_method[mname] = float("nan")

            if "TRAK" in method_names and trak_group_scores is not None:
                scores_by_method["TRAK"] = np.asarray(trak_group_scores, dtype=float)
                # Timing is set above (multiV) or defaults to 0 (fixedV).

            timings_by_method["compute_gS_sec"] = float(t_gs)
            timings_by_method["compute_gV_sec"] = float(t_gv)
            timings_by_method["group_benchmark_total_sec"] = float(time.time() - t_group0)
            timings_by_method["q_count"] = int(q_count)

            # Persist per-baseline timing breakdowns.
            timings_by_method["prep_sec_by_method"] = {k: float(v) for k, v in prep_sec_by_method.items()}
            timings_by_method["per_query_sec_total_by_method"] = {k: float(v) for k, v in per_query_sec_total_by_method.items()}
            timings_by_method["per_query_sec_mean_by_method"] = {k: float(v) for k, v in per_query_sec_mean_by_method.items()}

            # -----------------------
            # Group-level poison detection metrics
            # -----------------------
            poison_mask_groups = np.zeros((G_eff,), dtype=bool)
            poisoned_group_ids = [int(x) for x in poison_meta.get("poisoned_group_ids", [])] if isinstance(poison_meta, dict) else []
            for gg in poisoned_group_ids:
                if 0 <= gg < G_eff:
                    poison_mask_groups[gg] = True

            valid_topk_groups = [int(k) for k in report_topk_groups if int(k) > 0]
            poison_cfg_eval = cfg.get("poison_detection", {})
            candidate_only = bool(poison_cfg_eval.get("candidate_only", False))
            nonfinite_margin = float(poison_cfg_eval.get("nonfinite_margin", 1.0))

            # Track the sign used by poison_detection_report_v2 so retraining removes
            # groups in the same direction that yielded the reported metrics.
            chosen_sign_by_method: Dict[str, float] = {}

            for mname, scores in scores_by_method.items():
                report = _poison_detection_report_v2(
                    scores=np.asarray(scores, dtype=float),
                    poison_mask=poison_mask_groups,
                    topk_list=valid_topk_groups,
                    candidate_only=candidate_only,
                    nonfinite_margin=nonfinite_margin,
                )
                try:
                    cs = float(report.get("chosen_sign", 1.0))
                    if cs not in (-1.0, 1.0):
                        cs = 1.0
                except Exception:
                    cs = 1.0
                chosen_sign_by_method[str(mname)] = cs
                report.update({
                    "level": "group",
                    "G": int(G_eff),
                    "grouping": str(group_task_cfg.get("grouping", "random")),
                    "eval_mode": str(eval_mode),
                    "neighbor_k": int(neighbor_k),
                    "scenario": scen_name,
                    "scenario_index": int(scenario_idx),
                    "lambda": float(lam),
                    "lambda_index": int(lam_idx),
                    "method": str(mname),
                    "poison_fraction_groups": float(poison_mask_groups.mean()) if G_eff > 0 else 0.0,
                })
                poisoning_reports.append(report)

            # Random sanity check: AUROC~0.5, AUPRC~prevalence
            try:
                sanity = _poison_detection_report_v2(
                    scores=np.asarray(scores_random, dtype=float),
                    poison_mask=poison_mask_groups,
                    topk_list=[],
                    candidate_only=False,
                    nonfinite_margin=1.0,
                )
                poisoning_sanity.append(
                    {
                        "level": "group",
                        "scenario": scen_name,
                        "scenario_index": int(scenario_idx),
                        "lambda": float(lam),
                        "lambda_index": int(lam_idx),
                        "method": "__random__",
                        **{k: v for k, v in sanity.items() if k not in {"precision_at_k"}},
                    }
                )
                auroc_s = float(sanity.get("auroc", float("nan")))
                prev = float(sanity.get("prevalence_raw", float("nan")))
                if math.isfinite(auroc_s) and abs(auroc_s - 0.5) > 0.15:
                    log.warning("Scenario %s: group sanity(random) AUROC far from 0.5 (%.3f)", scen_name, auroc_s)
                auprc_s = float(sanity.get("auprc", float("nan")))
                if math.isfinite(auprc_s) and math.isfinite(prev) and abs(auprc_s - prev) > 0.15:
                    log.warning("Scenario %s: group sanity(random) AUPRC far from prevalence", scen_name)
            except Exception as exc:
                log.warning("Scenario %s: group random sanity failed (%s)", scen_name, exc)

            # Record raw group scores for plotting.
            group_scores_records.append(
                {
                    "scenario": scen_name,
                    "scenario_index": int(scenario_idx),
                    "lambda": float(lam),
                    "lambda_index": int(lam_idx),
                    "G": int(G_eff),
                    "group_sizes": [int(np.asarray(grp).size) for grp in groups_train_local],
                    "poisoned_group_ids": [int(x) for x in poisoned_group_ids],
                    "backdoor": {"enabled": bool(backdoor_enabled), "mode": str(backdoor_mode), "target_label": backdoor_cfg.get("target_label", None), "trigger": backdoor_cfg.get("trigger", None), "V_partition": backdoor_cfg.get("V_partition", None), "V_neighbor_k": int(backdoor_cfg.get("V_neighbor_k", 0))} if backdoor_enabled else {"enabled": False},
                    "scores_by_method": {k: np.asarray(v, dtype=float).tolist() for k, v in scores_by_method.items()},
                    "timings": timings_by_method,
                }
            )

            # Correlation sanity: on clean scenario, IF vs IFC should correlate strongly.
            if str(scen_name).lower() in {"clean", "no_poison", "baseline"} or (not poisoned_group_ids):
                if "IF" in scores_by_method:
                    for spec in method_specs:
                        key = f"IFC::{spec.name}"
                        a = np.asarray(scores_by_method["IF"], dtype=float)
                        b = np.asarray(scores_by_method[key], dtype=float)
                        corr_s = float(spearman_corr(a, b))
                        corr_p = float(pearson_corr(a, b))
                        correlation_reports.append(
                            {
                                "level": "group",
                                "scenario": scen_name,
                                "scenario_index": int(scenario_idx),
                                "lambda": float(lam),
                                "lambda_index": int(lam_idx),
                                "a": "IF",
                                "b": key,
                                "spearman": corr_s,
                                "pearson": corr_p,
                            }
                        )

            # -----------------------
            # Optional: group-level repair benchmark (remove top-k groups)
            # -----------------------
            retr_cfg = dict(cfg.get("retraining_benchmark", {}))
            enabled_flag = retr_cfg.get("enabled", None)
            enabled = bool(enabled_flag) if enabled_flag is not None else bool(scenario_has_mutation)
            if enabled:
                if val_idx.size == 0:
                    log.warning("Scenario %s: group repair requested but val split empty; skipping.", scen_name)
                else:
                    k_groups_list = retr_cfg.get("k_groups", retr_cfg.get("budgets", [0, 1, 3, 5, 10]))
                    k_groups_list = [int(x) for x in k_groups_list]
                    metric_name = str(retr_cfg.get("metric", "accuracy" if base_labels_np is not None else "loss"))
                    metrics_eval = {"accuracy": _accuracy_metric} if metric_name.lower() == "accuracy" else None
                    base_eval = eval_on_indices(model, dataset_clean, val_idx, loss_fn, metrics=metrics_eval)
                    base_metric = float(base_eval.get(metric_name, base_eval.get("loss", float("nan"))))

                    also_eval_backdoor = bool(retr_cfg.get("also_eval_backdoor", False)) and bool(backdoor_enabled)
                    bd_split = str(retr_cfg.get("backdoor_eval_split", "test")).lower()
                    bd_loader_base = test_loader_bd_target if bd_split == "test" else val_loader_bd_target
                    base_asr = float("nan")
                    if also_eval_backdoor and bd_loader_base is not None:
                        base_asr = float(_eval_loss_acc_over_loader(model, bd_loader_base, loss_fn, target_device).get("accuracy", float("nan")))

                    for mname, scores in scores_by_method.items():
                        if mname == "Random":
                            continue
                        # Higher-is-better: poison detection assumes high scores are "more suspicious".
                        higher_is_better = bool(retr_cfg.get("higher_is_better", {}).get(mname, True))

                        # By default, align with poison detection's chosen sign so the
                        # retraining benchmark is consistent with reported AUROC/AUPRC.
                        use_chosen_sign = bool(retr_cfg.get("use_chosen_sign", True))
                        cs = float(chosen_sign_by_method.get(str(mname), 1.0))
                        scores_sel = np.asarray(scores, dtype=float)
                        if use_chosen_sign and cs in (-1.0, 1.0):
                            scores_sel = scores_sel * cs
                        for k_groups in k_groups_list:
                            k_groups = max(0, min(int(k_groups), G_eff))
                            if k_groups == 0:
                                removed_local = np.empty(0, dtype=np.int64)
                            else:
                                top_groups = select_indices_from_scores(scores_sel, k_groups, higher_is_better=higher_is_better)
                                removed: List[int] = []
                                for gg in top_groups.tolist():
                                    removed.extend(int(x) for x in np.asarray(groups_train_local[int(gg)], dtype=np.int64).tolist())
                                removed_local = np.unique(np.asarray(removed, dtype=np.int64))
                                removed_local.sort()

                            keep_mask = np.ones(n_train, dtype=bool)
                            if removed_local.size:
                                keep_mask[removed_local] = False
                            keep_local = np.where(keep_mask)[0]

                            model_r = copy.deepcopy(base_model)
                            model_r.load_state_dict(init_state, strict=True)
                            if device is not None:
                                model_r = model_r.to(device)
                            set_seed(int(seed + scenario_idx * 1000 + lam_idx * 100 + k_groups))
                            t_r0 = time.time()
                            train_model(
                                model_r,
                                Subset(train_dataset, keep_local.tolist()),
                                loss_fn,
                                epochs=int(train_cfg.get("epochs", 1)),
                                lr=float(train_cfg.get("lr", 1e-2)),
                                weight_decay=float(train_cfg.get("weight_decay", 0.0)),
                                batch_size=int(train_cfg.get("batch_size", batch_size)),
                                num_workers=int(train_cfg.get("num_workers", 0)),
                                max_grad_norm=max_grad_norm,
                                stop_on_nonfinite=True,
                                logger=log,
                            )
                            time_retrain = float(time.time() - t_r0)
                            model_r.eval()
                            ev = eval_on_indices(model_r, dataset_clean, val_idx, loss_fn, metrics=metrics_eval)
                            val_metric = float(ev.get(metric_name, ev.get("loss", float("nan"))))
                            imp = _metric_improvement(metric_name, base_metric, val_metric)

                            asr_after = float("nan")
                            asr_improvement = float("nan")
                            if also_eval_backdoor and bd_loader_base is not None:
                                asr_after = float(_eval_loss_acc_over_loader(model_r, bd_loader_base, loss_fn, target_device).get("accuracy", float("nan")))
                                if math.isfinite(base_asr) and math.isfinite(asr_after):
                                    asr_improvement = float(base_asr) - float(asr_after)
                            retraining_reports.append(
                                {
                                    "removal_unit": "group",
                                    "level": "group",
                                    "scenario": scen_name,
                                    "scenario_index": int(scenario_idx),
                                    "lambda": float(lam),
                                    "lambda_index": int(lam_idx),
                                    "method": str(mname),
                                    "metric": metric_name,
                                    "k_groups": int(k_groups),
                                    "n_removed_points": int(removed_local.size),
                                    "base_val_metric": float(base_metric),
                                    "val_metric": float(val_metric),
                                    "improvement": float(imp),
                                    "asr_split": str(bd_split) if also_eval_backdoor else "",
                                    "asr_base": float(base_asr) if also_eval_backdoor else float("nan"),
                                    "asr_after": float(asr_after),
                                    "asr_improvement": float(asr_improvement),
                                    "runtime_sec": float(time_retrain),
                                }
                            )

            log.info(
                "Scenario %s: group benchmark done at λ=%.2e (G=%d) total=%.2fs IF_solve=%.2fs gS=%.2fs gV=%.2fs",
                scen_name,
                lam,
                int(G_eff),
                float(timings_by_method.get("group_benchmark_total_sec", float("nan"))),
                float(timings_by_method.get("IF_inverse_hvp_sec", 0.0)),
                float(timings_by_method.get("compute_gS_sec", 0.0)),
                float(timings_by_method.get("compute_gV_sec", 0.0)),
            )

            # Console-friendly baseline timing summary (prep + mean-per-query).
            try:
                parts: List[str] = []
                for mname in sorted(set(list(prep_sec_by_method.keys()) + list(per_query_sec_mean_by_method.keys()))):
                    prep_s = prep_sec_by_method.get(mname, float("nan"))
                    pq_s = per_query_sec_mean_by_method.get(mname, float("nan"))
                    if math.isfinite(prep_s) or math.isfinite(pq_s):
                        parts.append(f"{mname}: prep={prep_s:.3f}s pq={pq_s*1e3:.2f}ms")
                if parts:
                    log.info("Scenario %s: baseline timings at λ=%.2e (q=%d): %s", scen_name, lam, int(q_count), " | ".join(parts))
            except Exception:
                pass

        scenario_reports.append(env_meta)
        summaries = _summaries_from_rows(all_rows)
        _write_outputs(
            path,
            all_rows,
            summaries,
            {
                "scenarios": scenario_reports,
                "correlations": correlation_reports,
                "poisoning": poisoning_reports,
                "baseline_score_diagnostics": baseline_score_diagnostics,
                "poisoning_sanity": poisoning_sanity,
                "retraining": retraining_reports,
                "group_scores": group_scores_records,
            },
        )
        log.info("Scenario %s complete. Rows so far: %d", scen_name, len(all_rows))

    runtime = time.time() - t_start
    summary = {
        "runtime_sec": runtime,
        "n_rows": len(all_rows),
        "n_scenarios": len(scenarios),
    }
    summaries = _summaries_from_rows(all_rows)
    diagnostics = {
        "scenarios": scenario_reports,
        "summaries": summaries,
        "correlations": correlation_reports,
        "poisoning": poisoning_reports,
        "baseline_score_diagnostics": baseline_score_diagnostics,
        "poisoning_sanity": poisoning_sanity,
        "retraining": retraining_reports,
        "runtime_sec": runtime,
        "output_dir": str(path),
    }
    _log_structured(log, "poisoning_diagnostics", diagnostics)
    with open(path / "run_summary.json", "w") as f:
        json.dump(_serialize({"summary": summary, "diagnostics": diagnostics}), f, indent=2)
    log.info("Paper experiments finished in %.2fs (rows=%d)", runtime, len(all_rows))

    return {
        "table": all_rows,
        "summary": summary,
        "settings": summaries,
        "diagnostics": diagnostics,
    }
