from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Union
import copy
import csv
import json
import logging
import math
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset

from ..evals.metrics import pearson_corr, spearman_corr
from ..estimators.if_compressed import IFCompressedCfg, build_ifc_cache, delta_ifc_from_cache
from ..estimators.baselines import PBRFFoldBaseline
from ..estimators.helpers import BaseObjective, CGInfluenceModule, make_folds, eval_on_indices, flatten_params, set_params_from_vector
from ..utils.logging import setup_logging
from ..utils.seed import set_seed
from ..utils.train import train_model

from .run_ablation import (
    _extract_labels,
    _accuracy_metric,
    _grad_mean_over_indices,
    _log_structured,
    _serialize,
)


class _SimpleObjective(BaseObjective):
    """Minimal supervised objective used to drive the CG influence module.

    Batches are unpacked as ``(x, y)`` pairs: the model is run on ``x`` and its
    outputs are scored against ``y`` with ``loss_fn``. An optional L2 penalty
    ``weight_decay * sum(params ** 2)`` is exposed through
    ``train_regularization``.
    """

    def __init__(self, loss_fn: nn.Module, weight_decay: float = 0.0):
        self.loss_fn = loss_fn
        # Honor the caller-supplied coefficient. It was previously hard-coded to
        # 0.0, which silently discarded the argument.
        # NOTE(review): the penalty below is wd * ||p||^2. Confirm this matches
        # how train_model applies weight_decay. Optimizer-style decay usually
        # corresponds to 0.5 * wd * ||p||^2, i.e. a factor-of-2 difference in
        # the Hessian term.
        self.weight_decay = float(weight_decay)

    def train_outputs(self, model: nn.Module, batch):
        """Run ``model`` on the inputs of an ``(x, y)`` batch."""
        x, _ = batch
        return model(x)

    def train_loss_on_outputs(self, outputs: torch.Tensor, batch):
        """Score precomputed ``outputs`` against the batch targets."""
        _, y = batch
        return self.loss_fn(outputs, y)

    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """Return ``weight_decay * sum(params ** 2)`` (zero when weight_decay == 0)."""
        return self.weight_decay * torch.sum(params ** 2)

    def test_loss(self, model: nn.Module, params: torch.Tensor, batch):
        """Loss of ``model`` on an ``(x, y)`` batch; ``params`` is unused here."""
        x, y = batch
        return self.loss_fn(model(x), y)


def _select_eval_fold_ids(
    pred_if: Dict[int, float],
    pred_ifc: Optional[Dict[int, float]],
    *,
    n_random: int,
    n_top_if: int,
    n_top_ifc: int,
    budget: int,
    seed: int,
    keep_extremes: bool = True,
) -> List[int]:
    """Union of top harmful (largest) + random, capped to budget."""
    keys = sorted(pred_if.keys())
    if not keys:
        return []

    # top harmful = largest predicted Δloss
    top_if = sorted(keys, key=lambda k: pred_if[k], reverse=True)[: max(0, int(n_top_if))]

    top_ifc_list: List[int] = []
    if pred_ifc is not None and len(pred_ifc) > 0:
        keys_ifc = sorted(set(keys) & set(pred_ifc.keys()))
        top_ifc_list = sorted(keys_ifc, key=lambda k: pred_ifc[k], reverse=True)[: max(0, int(n_top_ifc))]

    rng = np.random.default_rng(int(seed))
    n_rand = min(int(n_random), len(keys))
    rand = rng.choice(np.asarray(keys, dtype=int), size=n_rand, replace=False).tolist()

    # union in order: extremes first, then random
    union: List[int] = []
    seen = set()
    for k in (top_if + top_ifc_list + rand):
        k = int(k)
        if k not in seen:
            union.append(k)
            seen.add(k)

    if len(union) <= int(budget):
        return union

    budget = int(budget)

    extremes: List[int] = []
    seen = set()
    for k in (top_if + top_ifc_list):
        k = int(k)
        if k not in seen:
            extremes.append(k)
            seen.add(k)

    if len(extremes) >= budget:
        return extremes[:budget]

    remaining = [k for k in union if k not in set(extremes)]
    need = budget - len(extremes)
    return extremes + remaining
    if len(remaining) <= need:
        return extremes + remaining

    fill = rng.choice(np.asarray(remaining, dtype=int), size=need, replace=False).tolist()
    return extremes + [int(k) for k in fill]


def _choose_ifc_key_for_selection(
    *,
    lam: float,
    ifc_cache_by_setting: Dict[Tuple[float, int, int, int], Optional[dict]],
    select_ifc: Optional[Dict[str, int]],
) -> Optional[Tuple[float, int, int, int]]:
    """
    Pick one IFC cache config (λ, jl_dim, C, U) to build the evaluation subset Q.
    If select_ifc is provided, we try to match it; otherwise pick a sensible default
    (largest C, prefer U=0, then smallest jl_dim).
    """
    cands = [k for k in ifc_cache_by_setting.keys() if abs(float(k[0]) - float(lam)) < 1e-12]
    if not cands:
        return None

    if select_ifc is not None:
        jl = int(select_ifc.get("jl_dim", -1))
        C = int(select_ifc.get("clusters", -1))
        U = int(select_ifc.get("umap_dim", -1))
        filtered = []
        for k in cands:
            _, jl_k, C_k, U_k = k
            if (jl < 0 or jl_k == jl) and (C < 0 or C_k == C) and (U < 0 or U_k == U):
                filtered.append(k)
        if filtered:
            # deterministic: smallest jl among ties, largest C, prefer U=0
            filtered.sort(key=lambda t: (t[3] != 0, -t[2], t[1]))
            return filtered[0]

    # default: prefer U=0, then largest C, then smallest jl_dim
    cands.sort(key=lambda t: (t[3] != 0, -t[2], t[1]))
    return cands[0]


def run_kfold_cv_compare(
    cfg: Dict[str, Any],
    logger: Optional[logging.Logger] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, Any]:
    """Run k-fold retraining baselines (optionally on a subset of folds) and compare IF/IFC.

    Baseline per fold: train a fresh model on keep-set (all data except fold S), eval loss on S.
    Approx per fold: starting from the full-data trained model, apply IF/IFC delta-theta for removing S, eval loss on S.

    If cfg["gt_eval"]["enable"]=True, we DO NOT retrain all folds. Instead, we:
      - predict Δloss for all folds using IF and a selected IFC cache setting,
      - build an evaluation set Q = top(IF) ∪ top(IFC) ∪ random (capped by budget),
      - retrain only folds in Q (baseline), and evaluate all settings on those folds.
    """
    outdir = cfg.get("outdir", "runs")
    name = cfg.get("name", "kfold_cv_compare")
    path = Path(outdir) / name
    path.mkdir(parents=True, exist_ok=True)
    log = logger or setup_logging(cfg.get("log_level", "DEBUG"), log_dir=str(path))

    seed = int(cfg.get("seed", 0))
    set_seed(seed)

    base_model: nn.Module = cfg["model"]
    dataset = cfg["dataset"]
    loss_fn: nn.Module = cfg["loss_fn"]
    if device is not None:
        base_model = base_model.to(device)

    model_device = next(base_model.parameters()).device
    N = len(dataset)

    labels = cfg.get("labels", None)
    labels_arr = _extract_labels(dataset, labels)
    is_classifier = labels_arr is not None and np.issubdtype(np.asarray(labels_arr).dtype, np.integer)
    metrics_eval = {"accuracy": _accuracy_metric} if is_classifier else None

    # Keep an initialization snapshot for true CV retraining
    init_state = copy.deepcopy(base_model.state_dict())

    # ---- subset-eval config ----
    gt = cfg.get("gt_eval", {}) or {}
    gt_enable = bool(gt.get("enable", False))
    gt_n_random = int(gt.get("n_random", 20))
    gt_n_top_if = int(gt.get("n_top_if", 20))
    gt_n_top_ifc = int(gt.get("n_top_ifc", 20))
    gt_budget = int(gt.get("budget", 60))
    gt_keep_extremes = bool(gt.get("keep_extremes", True))
    # choose IFC setting for selection (optional)
    gt_select_ifc = gt.get("select_ifc", None)  # e.g. {"jl_dim":32,"clusters":50,"umap_dim":0}
    gt_select_union_over_lambdas = bool(gt.get("union_over_lambdas", True))
    gt_budget_total = int(gt.get("budget_total", gt_budget))  # cap after union over lambdas

    # ---- fold PBRF config (Option B) ----
    pbrf = cfg.get("pbrf", {}) or {}
    pbrf_enable = bool(pbrf.get("enable", False))
    pbrf_retrain_epochs = int(pbrf.get("retrain_epochs", 50))
    pbrf_lr = float(pbrf.get("lr", 1e-2))
    pbrf_optimizer = str(pbrf.get("optimizer", "sgd"))
    pbrf_momentum = float(pbrf.get("momentum", 0.9))
    pbrf_lambda_damp = float(pbrf.get("lambda_damp", 1e-3))
    pbrf_epsilon = pbrf.get("epsilon", None)
    pbrf_epsilon_mode = str(pbrf.get("epsilon_mode", "m_over_N"))

    # Train full-data model once (theta-hat)
    train_full = cfg.get("train_full", cfg.get("train", {"epochs": 5, "lr": 1e-2, "batch_size": 128}))
    t_train0 = time.time()
    train_model(
        base_model,
        dataset,
        loss_fn,
        epochs=int(train_full.get("epochs", 5)),
        lr=float(train_full.get("lr", 1e-2)),
        weight_decay=float(train_full.get("weight_decay", 0.0)),
        batch_size=int(train_full.get("batch_size", 128)),
        num_workers=int(train_full.get("num_workers", 0)),
        logger=log,
    )
    time_train_full = time.time() - t_train0
    base_model.eval()

    _log_structured(
        log,
        "train_full_done",
        {
            "time_sec": float(time_train_full),
            "train_full": dict(train_full) if isinstance(train_full, dict) else train_full,
        },
    )

    # Shared full-data loader for IFC + CG module
    batch_size = int(cfg.get("batch_size", train_full.get("batch_size", 128)))
    num_workers = int(cfg.get("num_workers", train_full.get("num_workers", 0)))
    all_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=bool(torch.cuda.is_available()),
        persistent_workers=bool(num_workers > 0),
        num_workers=num_workers,
    )

    # Fold PBRF shared state (prepared at theta^s)
    pbrf_fold: Optional[PBRFFoldBaseline] = None
    if pbrf_enable:
        pbrf_fold = PBRFFoldBaseline(
            lambda_damp=float(pbrf_lambda_damp),
            retrain_epochs=int(pbrf_retrain_epochs),
            lr=float(pbrf_lr),
            optimizer=str(pbrf_optimizer),
            momentum=float(pbrf_momentum),
            batch_size=int(batch_size),
            num_workers=int(num_workers),
            epsilon=float(pbrf_epsilon) if pbrf_epsilon is not None else None,
            epsilon_mode=str(pbrf_epsilon_mode),
        )
        t_pbrf0 = time.time()
        pbrf_fold.prepare_shared(model=base_model, train_loader=all_loader, loss_fn=loss_fn, device=model_device)
        _log_structured(
            log,
            "pbrf_prepare_done",
            {
                "enabled": True,
                "time_sec": float(time.time() - t_pbrf0),
                "retrain_epochs": int(pbrf_retrain_epochs),
                "lr": float(pbrf_lr),
                "optimizer": str(pbrf_optimizer),
                "momentum": float(pbrf_momentum),
                "lambda_damp": float(pbrf_lambda_damp),
                "epsilon": float(pbrf_epsilon) if pbrf_epsilon is not None else None,
                "epsilon_mode": str(pbrf_epsilon_mode),
            },
        )

    # Grids
    Ks = list(cfg.get("Ks", [5]))
    damping_grid = list(cfg.get("damping_grid", cfg.get("damping", [1e-3])))
    jl_dims = list(cfg.get("jl_dims", cfg.get("jl", [128])))
    clusters_grid = list(cfg.get("clusters", [50]))
    cl_jl_grid = list(cfg.get("cl_jl_grid", [(10,32), (10,256), (50, 32), (100, 256), (500, 32), (500, 64)]))
    use_umap_default = bool(cfg.get("use_umap", True))
    umap_dims = list(cfg.get("umap_dims", [25]))
    skip_ifc = bool(cfg.get("skip_ifc", False))

    cg_tol = float(cfg.get("cg_tol", cfg.get("tol", 1e-6)))
    cg_max = int(cfg.get("cg_max_iters", cfg.get("cg_iters", 1000)))
    cg_tol = min(cg_tol, 1e-6)
    cg_max = max(cg_max, 300)
    fisher = bool(cfg.get("fisher", True))

    gamma_fixed = float(cfg.get("gamma_fixed", 1.0))
    recourse_steps = int(cfg.get("recourse_steps", 0))

    # Output containers
    per_setting_rows: Dict[Tuple[float, int, int, int, int], List[dict]] = defaultdict(list)
    setting_summaries: Dict[Tuple[float, int, int, int, int], Dict[str, Any]] = {}
    all_rows_nested: Dict[Tuple[float, int, int, int, int], List[dict]] = defaultdict(list)

    responses_dir = path / "responses"
    responses_dir.mkdir(parents=True, exist_ok=True)

    _log_structured(
        log,
        "kfold_cv_compare_config",
        {
            "outdir": str(outdir),
            "name": str(name),
            "output_dir": str(path),
            "responses_dir": str(responses_dir),
            "seed": int(seed),
            "N": int(N),
            "device": str(device) if device is not None else None,
            "is_classifier": bool(is_classifier),
            "Ks": [int(k) for k in Ks],
            "damping_grid": [float(l) for l in damping_grid],
            "jl_dims": [int(j) for j in jl_dims],
            "clusters": [int(c) for c in clusters_grid],
            "cl_jl_grid": [[int(t[0]), int(t[1])] for t in cl_jl_grid],
            "use_umap": bool(use_umap_default),
            "umap_dims": [int(u) for u in umap_dims],
            "skip_ifc": bool(skip_ifc),
            "cg_tol": float(cg_tol),
            "cg_max_iters": int(cg_max),
            "fisher": bool(fisher),
            "batch_size": int(batch_size),
            "num_workers": int(num_workers),
            "train_full": dict(train_full) if isinstance(train_full, dict) else train_full,
            "train_cv": dict(cfg.get("train_cv")) if isinstance(cfg.get("train_cv"), dict) else cfg.get("train_cv"),
            "recourse_steps": int(recourse_steps),
            "gt_eval": {
                "enable": gt_enable,
                "n_random": gt_n_random,
                "n_top_if": gt_n_top_if,
                "n_top_ifc": gt_n_top_ifc,
                "budget": gt_budget,
                "budget_total": gt_budget_total,
                "keep_extremes": gt_keep_extremes,
                "select_ifc": gt_select_ifc,
                "union_over_lambdas": gt_select_union_over_lambdas,
            },
            "pbrf": {
                "enable": bool(pbrf_enable),
                "retrain_epochs": int(pbrf_retrain_epochs),
                "lr": float(pbrf_lr),
                "optimizer": str(pbrf_optimizer),
                "momentum": float(pbrf_momentum),
                "lambda_damp": float(pbrf_lambda_damp),
                "epsilon": float(pbrf_epsilon) if pbrf_epsilon is not None else None,
                "epsilon_mode": str(pbrf_epsilon_mode),
            },
        },
    )

    # Precompute IFC caches once per (lambda, jl, C, U)
    ifc_cache_by_setting: Dict[Tuple[float, int, int, int], Optional[dict]] = {}
    ifc_cache_meta_by_setting: Dict[Tuple[float, int, int, int], Dict[str, Any]] = {}
    if not skip_ifc:
        for lam in damping_grid:
            for C, jl_dim in cl_jl_grid:
                umap_dims_iter = umap_dims if use_umap_default else [0]
                for umap_dim in umap_dims_iter:
                    if int(umap_dim) > int(C):
                        continue
                    use_umap = int(umap_dim) > 0
                    key4 = (float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0)
                    if key4 in ifc_cache_by_setting:
                        continue
                    ifc_cfg = IFCompressedCfg(
                        name="ifc",
                        jl_dim=int(jl_dim),
                        clusters=int(C),
                        damping=float(lam),
                        max_cg_iters=cg_max,
                        tol=cg_tol,
                        fisher=fisher,
                        normalize=True,
                        use_umap=use_umap,
                        umap_n_components=int(umap_dim) if use_umap else 0,
                        seed=seed,
                        collect_diagnostics=bool(cfg.get("ifc_collect_diagnostics", cfg.get("collect_diagnostics", False))),
                    )
                    log.info("Building IFC cache: λ=%.2e jl=%d C=%d U=%d", float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0)
                    t0_cache = time.time()
                    cache = build_ifc_cache(
                        base_model,
                        all_loader,
                        loss_fn,
                        ifc_cfg,
                        logger=log,
                        distortion_pairs=int(cfg.get("distortion_pairs", 2000)),
                    )
                    t_cache = time.time() - t0_cache
                    ifc_cache_by_setting[key4] = cache
                    ifc_cache_meta_by_setting[key4] = {
                        "lambda": float(lam),
                        "jl_dim": int(jl_dim),
                        "clusters": int(C),
                        "umap_dim": int(umap_dim) if use_umap else 0,
                        "build_time": float(t_cache),
                        "timings": cache.get("timings", {}) if isinstance(cache, dict) else {},
                        "cluster_stats": cache.get("cluster_stats", {}) if isinstance(cache, dict) else {},
                    }
                    _log_structured(log, "ifc_cache_metadata", ifc_cache_meta_by_setting[key4])
    else:
        _log_structured(log, "ifc_cache_skipped", {"skip_ifc": True})

    # CG modules per lambda
    objective = _SimpleObjective(loss_fn, weight_decay=float(train_full.get("weight_decay", 0.0)))
    mod_by_lambda: Dict[float, CGInfluenceModule] = {}
    for lam in damping_grid:
        mod_by_lambda[float(lam)] = CGInfluenceModule(
            model=base_model,
            objective=objective,
            train_loader=all_loader,
            test_loader=all_loader,
            device=model_device,
            damp=float(lam),
            gnh=fisher,
            tol=cg_tol,
            maxiter=cg_max,
        )

    theta_hat = flatten_params(base_model).detach().clone()

    def _eval_with_delta(delta: torch.Tensor, idxs: np.ndarray) -> Dict[str, Any]:
        set_params_from_vector(base_model, theta_hat + delta)
        out = eval_on_indices(base_model, dataset, idxs, loss_fn, metrics=metrics_eval)
        set_params_from_vector(base_model, theta_hat)
        return out

    train_cv = cfg.get("train_cv", cfg.get("train", train_full))

    all_rows: List[dict] = []
    t0 = time.time()

    for K in Ks:
        K = int(K)
        if K <= 1 or K > N:
            log.warning("Skipping invalid K=%d", K)
            continue

        log.info("Starting K-fold compare: K=%d", int(K))
        tK0 = time.time()

        folds = make_folds(N, K, stratify=labels, seed=seed)

        # Precompute full-model loss on each fold (needed for Δloss selection + logging)
        loss_full_by_fold: Dict[int, float] = {}
        for f, S in enumerate(folds):
            S = np.asarray(S, dtype=np.int64)
            if S.size == 0:
                continue
            loss_full_by_fold[int(f)] = float(eval_on_indices(base_model, dataset, S, loss_fn, metrics=metrics_eval)["loss"])

        # --------- IF: compute for ALL folds (no cv_baseline dependency) ----------
        IF_by_lambda: Dict[float, Dict[int, torch.Tensor]] = {}
        IF_stats_by_lambda: Dict[float, Dict[int, Dict[str, Any]]] = {}
        gS_cache_by_lambda: Dict[float, Dict[int, torch.Tensor]] = {}
        for lam in damping_grid:
            lam = float(lam)
            mod = mod_by_lambda[lam]
            IF_by_lambda[lam] = {}
            IF_stats_by_lambda[lam] = {}
            gS_cache_by_lambda[lam] = {}

            t_if0 = time.time()
            for f, S in enumerate(folds):
                f = int(f)
                S = np.asarray(S, dtype=np.int64)
                if S.size == 0:
                    continue

                gS = _grad_mean_over_indices(base_model, dataset, loss_fn, S).detach().to(model_device)
                t_cg0 = time.time()
                sol = mod.inverse_hvp(gS)
                t_cg = time.time() - t_cg0

                v = sol.get("ihvp") if isinstance(sol, dict) else None
                if v is None:
                    v = torch.zeros_like(gS)
                iters = int(sol.get("iterations", 0)) if isinstance(sol, dict) else 0
                resid = float(sol.get("final_residual", float("nan"))) if isinstance(sol, dict) else float("nan")

                delta = (len(S) / N) * v.detach()
                IF_by_lambda[lam][f] = delta.detach().cpu()
                IF_stats_by_lambda[lam][f] = {"cg_time": float(t_cg), "iters": int(iters), "residual": float(resid)}
                gS_cache_by_lambda[lam][f] = gS.detach().cpu()

            t_if = time.time() - t_if0
            cg_times = [float(v.get("cg_time", float("nan"))) for v in IF_stats_by_lambda[lam].values()]
            iters = [float(v.get("iters", float("nan"))) for v in IF_stats_by_lambda[lam].values()]
            _log_structured(
                log,
                "if_done",
                {
                    "K": int(K),
                    "lambda": float(lam),
                    "time_sec": float(t_if),
                    "n_folds": int(len(IF_stats_by_lambda[lam])),
                    "cg_time_mean": float(np.nanmean(cg_times)) if cg_times else float("nan"),
                    "iters_mean": float(np.nanmean(iters)) if iters else float("nan"),
                },
            )

            # Save IF responses
            try:
                if_path = responses_dir / f"IF_K{int(K)}_lam{float(lam):.2e}.pt"
                torch.save(
                    {
                        "K": int(K),
                        "lambda": float(lam),
                        "N": int(N),
                        "deltas": {int(f): t.detach().cpu() for f, t in IF_by_lambda[lam].items()},
                        "stats": IF_stats_by_lambda[lam],
                    },
                    if_path,
                )
            except Exception as e:
                log.warning("Failed to save IF responses for K=%d λ=%.2e: %s", int(K), float(lam), str(e))

        # --------- Select evaluation subset Q and retrain baseline only on Q ----------
        # cv_baseline maps fold_id -> baseline retrain metrics (only for selected folds if gt_enable)
        cv_baseline: Dict[int, Dict[str, Any]] = {}

        if gt_enable:
            if skip_ifc:
                log.warning("gt_eval enabled but skip_ifc=True; selection will use IF+random only.")
            eval_folds_union: List[int] = []

            def _pred_summary(d: Dict[int, float]) -> Dict[str, Any]:
                if not d:
                    return {"n": 0, "mean": float("nan"), "std": float("nan"), "min": float("nan"), "max": float("nan")}
                v = np.asarray(list(d.values()), dtype=float)
                return {
                    "n": int(v.size),
                    "mean": float(np.nanmean(v)),
                    "std": float(np.nanstd(v)),
                    "min": float(np.nanmin(v)),
                    "max": float(np.nanmax(v)),
                }

            def _topk(d: Dict[int, float], k: int) -> List[Dict[str, Any]]:
                if not d or k <= 0:
                    return []
                items = sorted(d.items(), key=lambda kv: float(kv[1]), reverse=True)[: int(k)]
                return [{"fold": int(f), "pred_dloss": float(v)} for f, v in items]

            # union over lambdas (default) or just first lambda
            lambdas_for_sel = [float(l) for l in damping_grid] if gt_select_union_over_lambdas else [float(damping_grid[0])]
            sel_debug: Dict[str, Any] = {"per_lambda": []}

            for lam in lambdas_for_sel:
                t_sel0 = time.time()
                pred_if: Dict[int, float] = {}
                pred_ifc: Optional[Dict[int, float]] = {} if (not skip_ifc) else None

                # choose a single IFC cache config for selection
                key4_sel = None
                cache_sel = None
                if not skip_ifc:
                    key4_sel = _choose_ifc_key_for_selection(
                        lam=lam, ifc_cache_by_setting=ifc_cache_by_setting, select_ifc=gt_select_ifc
                    )
                    cache_sel = ifc_cache_by_setting.get(key4_sel) if key4_sel is not None else None
                    if cache_sel is None:
                        if gt_select_ifc is not None:
                            log.warning(
                                "gt_eval selection requested IFC=%s but no matching cache was found for λ=%.2e; falling back to IF-only selection.",
                                str(gt_select_ifc),
                                float(lam),
                            )
                        pred_ifc = None

                # predicted Δloss = loss(pred under delta) - loss(full model)
                for f, S in enumerate(folds):
                    f = int(f)
                    S = np.asarray(S, dtype=np.int64)
                    if S.size == 0 or f not in loss_full_by_fold:
                        continue
                    loss_full = float(loss_full_by_fold[f])

                    d_if = IF_by_lambda[lam].get(f)
                    if d_if is None:
                        continue
                    loss_if = float(_eval_with_delta(d_if.to(model_device), S)["loss"])
                    pred_if[f] = loss_if - loss_full

                    if pred_ifc is not None and cache_sel is not None:
                        out_ifc = delta_ifc_from_cache(cache_sel, S.tolist(), device=model_device)
                        d_ifc_sel = out_ifc["delta_theta"]
                        loss_ifc_sel = float(_eval_with_delta(d_ifc_sel, S)["loss"])
                        pred_ifc[f] = loss_ifc_sel - loss_full  # type: ignore[index]

                _log_structured(
                    log,
                    "gt_eval_predictions",
                    {
                        "K": int(K),
                        "lambda": float(lam),
                        "ifc_key_selected": tuple(key4_sel) if key4_sel is not None else None,
                        "using_ifc": bool(pred_ifc is not None and cache_sel is not None),
                        "pred_if": _pred_summary(pred_if),
                        "pred_ifc": _pred_summary(pred_ifc) if isinstance(pred_ifc, dict) else None,
                        "top_if": _topk(pred_if, k=min(10, int(gt_n_top_if))),
                        "top_ifc": _topk(pred_ifc, k=min(10, int(gt_n_top_ifc))) if isinstance(pred_ifc, dict) else None,
                        "selection_params": {
                            "n_random": int(gt_n_random),
                            "n_top_if": int(gt_n_top_if),
                            "n_top_ifc": int(gt_n_top_ifc),
                            "budget": int(gt_budget),
                            "keep_extremes": bool(gt_keep_extremes),
                        },
                    },
                )

                q_ids = _select_eval_fold_ids(
                    pred_if=pred_if,
                    pred_ifc=pred_ifc,
                    n_random=gt_n_random,
                    n_top_if=gt_n_top_if,
                    n_top_ifc=gt_n_top_ifc,
                    budget=gt_budget,
                    seed=seed + int(1e6 * lam) % 10_000_000,
                    keep_extremes=gt_keep_extremes,
                )

                sel_debug["per_lambda"].append(
                    {
                        "lambda": float(lam),
                        "key4_sel": key4_sel,
                        "n_selected": int(len(q_ids)),
                        "selected_folds": [int(x) for x in q_ids],
                        "time_sec": float(time.time() - t_sel0),
                    }
                )
                eval_folds_union.extend(q_ids)

            # dedupe + cap total budget
            eval_folds = []
            seen = set()
            for f in eval_folds_union:
                f = int(f)
                if f not in seen:
                    eval_folds.append(f)
                    seen.add(f)

            if len(eval_folds) > gt_budget_total:
                # cap by keeping extremes from the UNION: take highest predicted IF folds first, then IFC, then random
                # simplest defensible cap: keep first gt_budget_total in the already-extremes-first construction
                eval_folds = eval_folds[:gt_budget_total]

            _log_structured(
                log,
                "gt_eval_selection",
                {
                    "K": int(K),
                    "enabled": True,
                    "budget_total": int(gt_budget_total),
                    "selected_folds": [int(f) for f in eval_folds],
                    **sel_debug,
                },
            )

            # retrain baseline only on selected folds
            t_baseline0 = time.time()
            for f in eval_folds:
                S = np.asarray(folds[int(f)], dtype=np.int64)
                if S.size == 0:
                    continue
                keep_mask = np.ones(N, dtype=bool)
                keep_mask[S] = False
                keep_idx = np.where(keep_mask)[0]
                if keep_idx.size == 0:
                    continue

                model_cv = copy.deepcopy(base_model)
                model_cv.load_state_dict(init_state)
                model_cv.to(model_device)
                model_cv.train()
                t_cv0 = time.time()
                train_model(
                    model_cv,
                    Subset(dataset, keep_idx.tolist()),
                    loss_fn,
                    epochs=int(train_cv.get("epochs", 5)),
                    lr=float(train_cv.get("lr", 1e-2)),
                    weight_decay=float(train_cv.get("weight_decay", 0.0)),
                    batch_size=int(train_cv.get("batch_size", batch_size)),
                    num_workers=int(train_cv.get("num_workers", 0)),
                    logger=None,
                )
                t_cv = time.time() - t_cv0
                model_cv.eval()
                base_cv = eval_on_indices(model_cv, dataset, S, loss_fn, metrics=metrics_eval)

                cv_baseline[int(f)] = {
                    "loss": float(base_cv["loss"]),
                    "acc": float(base_cv.get("accuracy", float("nan"))),
                    "time_train_cv": float(t_cv),
                    "n_keep": int(keep_idx.size),
                    "n_S": int(S.size),
                    "loss_full": float(loss_full_by_fold.get(int(f), float("nan"))),
                    "d_loss_actual": float(base_cv["loss"]) - float(loss_full_by_fold.get(int(f), float("nan"))),
                }

            t_baseline = time.time() - t_baseline0
            _log_structured(
                log,
                "cv_baseline_done",
                {
                    "K": int(K),
                    "n_folds": int(len(cv_baseline)),
                    "subset_eval": True,
                    "time_sec": float(t_baseline),
                    "mean_train_time": float(np.mean([v.get("time_train_cv", float("nan")) for v in cv_baseline.values()])) if cv_baseline else float("nan"),
                },
            )
        else:
            # original behavior: retrain ALL folds
            t_baseline0 = time.time()
            progress_every = max(1, int(math.ceil(K / 5)))
            for f, S in enumerate(folds):
                S = np.asarray(S, dtype=np.int64)
                if S.size == 0:
                    continue
                keep_mask = np.ones(N, dtype=bool)
                keep_mask[S] = False
                keep_idx = np.where(keep_mask)[0]
                if keep_idx.size == 0:
                    continue

                model_cv = copy.deepcopy(base_model)
                model_cv.load_state_dict(init_state)
                model_cv.to(model_device)
                model_cv.train()
                t_cv0 = time.time()
                train_model(
                    model_cv,
                    Subset(dataset, keep_idx.tolist()),
                    loss_fn,
                    epochs=int(train_cv.get("epochs", 5)),
                    lr=float(train_cv.get("lr", 1e-2)),
                    weight_decay=float(train_cv.get("weight_decay", 0.0)),
                    batch_size=int(train_cv.get("batch_size", batch_size)),
                    num_workers=int(train_cv.get("num_workers", 0)),
                    logger=None,
                )
                t_cv = time.time() - t_cv0
                model_cv.eval()
                base_cv = eval_on_indices(model_cv, dataset, S, loss_fn, metrics=metrics_eval)
                cv_baseline[int(f)] = {
                    "loss": float(base_cv["loss"]),
                    "acc": float(base_cv.get("accuracy", float("nan"))),
                    "time_train_cv": float(t_cv),
                    "n_keep": int(keep_idx.size),
                    "n_S": int(S.size),
                    "loss_full": float(loss_full_by_fold.get(int(f), float("nan"))),
                    "d_loss_actual": float(base_cv["loss"]) - float(loss_full_by_fold.get(int(f), float("nan"))),
                }

                if (int(f) == 0) or ((int(f) + 1) % progress_every == 0) or (int(f) + 1 == int(K)):
                    log.info(
                        "CV baseline fold %d/%d: loss=%.6f train=%.2fs n_keep=%d n_S=%d",
                        int(f) + 1,
                        int(K),
                        float(cv_baseline[int(f)]["loss"]),
                        float(t_cv),
                        int(keep_idx.size),
                        int(S.size),
                    )

            t_baseline = time.time() - t_baseline0
            _log_structured(
                log,
                "cv_baseline_done",
                {
                    "K": int(K),
                    "n_folds": int(len(cv_baseline)),
                    "subset_eval": False,
                    "time_sec": float(t_baseline),
                    "mean_train_time": float(np.mean([v.get("time_train_cv", float("nan")) for v in cv_baseline.values()])) if cv_baseline else float("nan"),
                },
            )

        # --------- Fold PBRF (Option B): run only on folds with baseline ----------
        pbrf_by_fold: Dict[int, Dict[str, Any]] = {}
        if pbrf_enable and pbrf_fold is not None:
            fold_ids = sorted([int(f) for f in cv_baseline.keys()])
            t_p0 = time.time()
            for f in fold_ids:
                S = np.asarray(folds[int(f)], dtype=np.int64)
                if S.size == 0:
                    continue

                base_loss = float(loss_full_by_fold.get(int(f), float("nan")))
                t_f0 = time.time()
                info = pbrf_fold.fit_subset(model=base_model, dataset=dataset, subset_indices=S)
                r_p = eval_on_indices(base_model, dataset, S, loss_fn, metrics=metrics_eval)
                loss_p = float(r_p["loss"])
                acc_p = float(r_p.get("accuracy", float("nan")))
                pbrf_fold.restore_theta_s(model=base_model)

                pbrf_by_fold[int(f)] = {
                    "loss": float(loss_p),
                    "acc": float(acc_p),
                    "loss_full": float(base_loss),
                    "d_loss_pbrf": float(loss_p - base_loss) if np.isfinite(base_loss) else float("nan"),
                    "time_sec": float(time.time() - t_f0),
                    "eps": float(info.get("eps", float("nan"))),
                    "subset_size": int(info.get("subset_size", int(S.size))),
                }

                _log_structured(
                    log,
                    "pbrf_fold_done",
                    {
                        "K": int(K),
                        "fold": int(f),
                        "subset_size": int(S.size),
                        "eps": float(pbrf_by_fold[int(f)]["eps"]),
                        "loss_full": float(base_loss),
                        "loss_pbrf": float(loss_p),
                        "d_loss_pbrf": float(pbrf_by_fold[int(f)]["d_loss_pbrf"]),
                        "time_sec": float(pbrf_by_fold[int(f)]["time_sec"]),
                    },
                )

            _log_structured(
                log,
                "pbrf_done",
                {
                    "K": int(K),
                    "enabled": True,
                    "n_folds": int(len(pbrf_by_fold)),
                    "time_sec": float(time.time() - t_p0),
                    "mean_time_sec": float(np.mean([v.get("time_sec", float("nan")) for v in pbrf_by_fold.values()])) if pbrf_by_fold else float("nan"),
                },
            )

        # --------- Evaluate comparisons for every (lambda, jl, C, U) on folds with baseline ----------
        theta_hat = flatten_params(base_model).detach().clone()

        for lam in damping_grid:
            lam = float(lam)
            if skip_ifc:
                settings4 = [(lam, 0, 0, 0)]
            else:
                settings4 = [k for k in ifc_cache_by_setting.keys() if abs(float(k[0]) - lam) < 1e-12]

            for key4 in settings4:
                _, jl_dim, C, umap_dim = key4
                use_umap = int(umap_dim) > 0
                ifc_cache = None if skip_ifc else ifc_cache_by_setting.get(key4)
                ifc_meta = ifc_cache_meta_by_setting.get(key4, {"lambda": lam, "jl_dim": jl_dim, "clusters": C, "umap_dim": umap_dim})

                ifc_deltas_to_save: Dict[int, torch.Tensor] = {}

                for f, S in enumerate(folds):
                    f = int(f)
                    S = np.asarray(S, dtype=np.int64)
                    if S.size == 0 or f not in cv_baseline:
                        continue

                    base_cv_loss = float(cv_baseline[f]["loss"])
                    loss_full = float(cv_baseline[f].get("loss_full", float("nan")))

                    d_if = IF_by_lambda[lam].get(f)
                    if d_if is None:
                        continue
                    d_if = d_if.to(model_device)

                    # IF approx loss on S
                    r_if = _eval_with_delta(d_if, S)
                    loss_if = float(r_if["loss"])

                    # IFC delta and loss
                    if skip_ifc or ifc_cache is None:
                        d_ifc = torch.zeros_like(d_if)
                    else:
                        out_ifc = delta_ifc_from_cache(ifc_cache, S.tolist(), device=model_device)
                        d_ifc = out_ifc["delta_theta"]
                        if recourse_steps > 0:
                            mod = mod_by_lambda[lam]
                            theta0 = theta_hat.detach().clone()
                            set_params_from_vector(base_model, theta0 + d_ifc)
                            gS_new = _grad_mean_over_indices(base_model, dataset, loss_fn, S).detach().to(model_device)
                            set_params_from_vector(base_model, theta0)
                            gS_old = gS_cache_by_lambda[lam].get(f)
                            if gS_old is not None:
                                corr_sol = mod.inverse_hvp((gS_new - gS_old.to(model_device)).detach())
                                corr = corr_sol.get("ihvp") if isinstance(corr_sol, dict) else None
                                if corr is None:
                                    corr = torch.zeros_like(d_ifc)
                                d_ifc = d_ifc + (len(S) / N) * corr.detach()

                    ifc_deltas_to_save[f] = d_ifc.detach().cpu()
                    r_ifc = _eval_with_delta(d_ifc, S)
                    loss_ifc = float(r_ifc["loss"])

                    # Δloss relative to full model (useful for "harmfulness" + calibration)
                    dloss_if = loss_if - loss_full if np.isfinite(loss_full) else float("nan")
                    dloss_ifc = loss_ifc - loss_full if np.isfinite(loss_full) else float("nan")
                    dloss_actual = float(cv_baseline[f].get("d_loss_actual", float("nan")))

                    pbrf_row = pbrf_by_fold.get(int(f), {}) if pbrf_enable else {}
                    loss_pbrf = float(pbrf_row.get("loss", float("nan")))

                    row = {
                        "fold": f,
                        "K": int(K),
                        "lambda": float(lam),
                        "jl_dim": int(jl_dim),
                        "clusters": int(C),
                        "umap_dim": int(umap_dim) if use_umap else 0,
                        "recourse": int(recourse_steps),
                        "loss_full": float(loss_full),
                        "baseline_cv_loss": float(base_cv_loss),
                        "loss_if": float(loss_if),
                        "loss_ifc": float(loss_ifc),
                        "d_loss_actual": float(dloss_actual),
                        "d_loss_if": float(dloss_if),
                        "d_loss_ifc": float(dloss_ifc),
                        "err_if": float(loss_if - base_cv_loss),
                        "err_ifc": float(loss_ifc - base_cv_loss),
                        "abs_err_if": float(abs(loss_if - base_cv_loss)),
                        "abs_err_ifc": float(abs(loss_ifc - base_cv_loss)),
                        "loss_pbrf": float(loss_pbrf) if pbrf_enable else float("nan"),
                        "d_loss_pbrf": float(pbrf_row.get("d_loss_pbrf", float("nan"))) if pbrf_enable else float("nan"),
                        "err_pbrf": float(loss_pbrf - base_cv_loss) if (pbrf_enable and np.isfinite(loss_pbrf)) else float("nan"),
                        "abs_err_pbrf": float(abs(loss_pbrf - base_cv_loss)) if (pbrf_enable and np.isfinite(loss_pbrf)) else float("nan"),
                        "acc_cv": float(cv_baseline[f].get("acc", float("nan"))),
                        "acc_if": float(r_if.get("accuracy", float("nan"))),
                        "acc_ifc": float(r_ifc.get("accuracy", float("nan"))),
                        "acc_pbrf": float(pbrf_row.get("acc", float("nan"))) if pbrf_enable else float("nan"),
                        "gamma": float(gamma_fixed),
                        "||d_if||": float(d_if.norm().item()),
                        "||d_ifc||": float(d_ifc.norm().item()),
                        "if_stats": IF_stats_by_lambda[lam].get(f, {}),
                        "ifc_cache_metadata": ifc_meta,
                        "gt_subset_eval": bool(gt_enable),
                        "pbrf_enabled": bool(pbrf_enable),
                    }

                    key5 = (float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, int(K))
                    per_setting_rows[key5].append(row)
                    all_rows_nested[key5].append(row)
                    all_rows.append(row)

                # Save IFC responses per setting
                try:
                    ifc_path = responses_dir / (
                        f"IFC_K{int(K)}_lam{float(lam):.2e}_jl{int(jl_dim)}_C{int(C)}_U{int(umap_dim) if use_umap else 0}.pt"
                    )
                    torch.save(
                        {
                            "K": int(K),
                            "lambda": float(lam),
                            "jl_dim": int(jl_dim),
                            "clusters": int(C),
                            "umap_dim": int(umap_dim) if use_umap else 0,
                            "recourse_steps": int(recourse_steps),
                            "N": int(N),
                            "deltas": {int(f): t.detach().cpu() for f, t in ifc_deltas_to_save.items()},
                            "ifc_cache_metadata": ifc_meta,
                        },
                        ifc_path,
                    )
                except Exception as e:
                    log.warning(
                        "Failed to save IFC responses for K=%d λ=%.2e jl=%d C=%d U=%d: %s",
                        int(K), float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, str(e),
                    )

                # Setting summary (computed on folds that were retrained)
                key5 = (float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, int(K))
                rows = per_setting_rows.get(key5, [])
                if rows:
                    b = np.array([r["baseline_cv_loss"] for r in rows], dtype=float)
                    li = np.array([r["loss_if"] for r in rows], dtype=float)
                    lc = np.array([r["loss_ifc"] for r in rows], dtype=float)
                    lp = np.array([r.get("loss_pbrf", float("nan")) for r in rows], dtype=float)

                    mae_if = float(np.mean(np.abs(li - b)))
                    mae_ifc = float(np.mean(np.abs(lc - b)))
                    mae_pbrf = float(np.mean(np.abs(lp - b))) if np.isfinite(lp).any() else float("nan")
                    rmse_if = float(np.sqrt(np.mean((li - b) ** 2)))
                    rmse_ifc = float(np.sqrt(np.mean((lc - b) ** 2)))
                    rmse_pbrf = float(np.sqrt(np.mean((lp - b) ** 2))) if np.isfinite(lp).any() else float("nan")

                    def _corr(x, y):
                        """Return ``(pearson, spearman)`` correlation over the finitely-valued pairs of x and y.

                        Pairs where either ``x`` or ``y`` is non-finite (NaN/inf) are dropped first,
                        so rows whose estimate is NaN (e.g. ``loss_pbrf`` when PBRF did not run for
                        that fold) do not poison the statistic. A correlation is undefined for fewer
                        than two paired points, so ``(nan, nan)`` is returned in that case instead
                        of calling the metric helpers.
                        """
                        x = np.asarray(x, dtype=float)
                        y = np.asarray(y, dtype=float)
                        m = np.isfinite(x) & np.isfinite(y)
                        # Need at least two paired points for any correlation to be defined.
                        if int(m.sum()) < 2:
                            return float("nan"), float("nan")
                        return pearson_corr(x[m], y[m]), spearman_corr(x[m], y[m])

                    pear_if, spear_if = _corr(li, b)
                    pear_ifc, spear_ifc = _corr(lc, b)
                    pear_pbrf, spear_pbrf = _corr(lp, b)

                    setting_summaries[key5] = {
                        "K": int(K),
                        "lambda": float(lam),
                        "jl_dim": int(jl_dim),
                        "clusters": int(C),
                        "umap_dim": int(umap_dim) if use_umap else 0,
                        "n_folds": int(len(rows)),
                        "mae_if": mae_if,
                        "mae_ifc": mae_ifc,
                        "mae_pbrf": float(mae_pbrf),
                        "rmse_if": rmse_if,
                        "rmse_ifc": rmse_ifc,
                        "rmse_pbrf": float(rmse_pbrf),
                        "pearson_if_vs_cv": float(pear_if),
                        "spearman_if_vs_cv": float(spear_if),
                        "pearson_ifc_vs_cv": float(pear_ifc),
                        "spearman_ifc_vs_cv": float(spear_ifc),
                        "pearson_pbrf_vs_cv": float(pear_pbrf),
                        "spearman_pbrf_vs_cv": float(spear_pbrf),
                        "ifc_cache_metadata": ifc_meta,
                        "gt_subset_eval": bool(gt_enable),
                        "pbrf_enabled": bool(pbrf_enable),
                    }

                    log.info(
                        "Summary K=%d λ=%.2e jl=%d C=%d U=%d (subset=%s, pbrf=%s): mae_if=%.4g mae_ifc=%.4g mae_pbrf=%.4g pear_if=%.3f pear_ifc=%.3f pear_pbrf=%.3f",
                        int(K), float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0,
                        str(bool(gt_enable)),
                        str(bool(pbrf_enable)),
                        float(mae_if), float(mae_ifc), float(mae_pbrf) if np.isfinite(mae_pbrf) else float("nan"),
                        float(pear_if) if np.isfinite(pear_if) else float("nan"),
                        float(pear_ifc) if np.isfinite(pear_ifc) else float("nan"),
                        float(pear_pbrf) if np.isfinite(pear_pbrf) else float("nan"),
                    )

        # Write per-K CSV
        out_csv = path / f"kfold_cv_compare_K{int(K)}.csv"
        rows_k = [r for r in all_rows if int(r.get("K", -1)) == int(K)]
        if rows_k:
            fieldnames = sorted({k for r in rows_k for k in r.keys()})
            with open(out_csv, "w", newline="") as f:
                w = csv.DictWriter(f, fieldnames=fieldnames)
                w.writeheader()
                w.writerows(rows_k)

        _log_structured(
            log,
            "k_done",
            {"K": int(K), "rows": int(len(rows_k)), "time_sec": float(time.time() - tK0), "csv": str(out_csv)},
        )

    runtime = time.time() - t0

    # Persist outputs
    with open(path / "per_fold.json", "w") as f:
        json.dump([_serialize(r) for r in all_rows], f, indent=2)
    with open(path / "settings_summary.json", "w") as f:
        json.dump([_serialize(v) for v in setting_summaries.values()], f, indent=2)
    if setting_summaries:
        fieldnames = sorted({k for s in setting_summaries.values() for k in s.keys()})
        with open(path / "settings_summary.csv", "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=fieldnames)
            w.writeheader()
            w.writerows([_serialize(v) for v in setting_summaries.values()])

    summary = {
        "runtime_sec": float(runtime),
        "time_train_full": float(time_train_full),
        "n_rows": int(len(all_rows)),
        "Ks": [int(k) for k in Ks],
        "damping_grid": [float(l) for l in damping_grid],
        "skip_ifc": bool(skip_ifc),
        "gt_subset_eval": bool(gt_enable),
        "pbrf_enabled": bool(pbrf_enable),
        "output_dir": str(path),
    }
    diagnostics = {"summary": summary, "settings": [_serialize(v) for v in setting_summaries.values()]}
    with open(path / "run_summary.json", "w") as f:
        json.dump(_serialize({"summary": summary, "diagnostics": diagnostics}), f, indent=2)

    log.info("kfold cv compare done in %.2fs (rows=%d). Outputs in %s", runtime, len(all_rows), str(path))
    return {
        "table": all_rows,
        "table_nested": all_rows_nested,
        "summary": summary,
        "settings": list(setting_summaries.values()),
        "diagnostics": diagnostics,
    }
