from typing import Optional, Tuple, Any, Dict, List
import copy
import csv
import json
import logging
import math
import os
import platform
import subprocess
import threading
import time
import traceback
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import sys, subprocess
from typing import Union
import psutil  # optional but preferred for RAM info

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset

from ..evals.metrics import pearson_corr, spearman_corr
import torch.multiprocessing as mp
from ..estimators.if_compressed import (
    IFCompressedCfg,
    build_ifc_cache,
    delta_ifc_from_cache,
)
from ..estimators.helpers import BaseObjective, CGInfluenceModule, eval_on_indices, flatten_params, make_folds, set_params_from_vector
from ..utils.seed import set_seed
from ..utils.train import train_model
from ..utils.logging import setup_logging
from torch.profiler import profile, ProfilerActivity

class _SimpleObjective(BaseObjective):
    """Objective adapter for plain supervised ``(x, y)`` batches.

    Bundles a loss module with an L2 penalty coefficient and exposes them
    through the ``train_*`` / ``test_loss`` hooks defined below.

    Args:
        loss_fn: Callable applied as ``loss_fn(model_outputs, y)``.
        weight_decay: Coefficient of the squared-norm penalty returned by
            :meth:`train_regularization`.
    """

    def __init__(self, loss_fn: nn.Module, weight_decay: float = 0.0):
        self.loss_fn = loss_fn
        # Keep the caller-supplied coefficient. It used to be overwritten with
        # 0.0, which silently disabled the regularization term.
        self.weight_decay = float(weight_decay)

    def train_outputs(self, model: nn.Module, batch):
        """Run the model on the inputs of an ``(x, y)`` batch."""
        x, _ = batch
        return model(x)

    def train_loss_on_outputs(self, outputs: torch.Tensor, batch):
        """Apply ``loss_fn`` to precomputed outputs and the batch targets."""
        _, y = batch
        return self.loss_fn(outputs, y)

    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """Return ``weight_decay * sum(params ** 2)``."""
        # NOTE(review): there is no 1/2 factor here — confirm this matches the
        # weight-decay convention used by the training routine.
        return self.weight_decay * torch.sum(params ** 2)

    def test_loss(self, model: nn.Module, params: torch.Tensor, batch):
        """Evaluate ``loss_fn`` on a batch; ``params`` is accepted but unused."""
        x, y = batch
        return self.loss_fn(model(x), y)


def _try_get_git_commit() -> str:
    try:
        out = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL)
        return out.decode().strip()
    except Exception:
        return "unknown"


def _extract_labels(dataset, explicit=None):
    if explicit is not None:
        arr = np.asarray(explicit)
        if arr.shape[0] == len(dataset):
            return arr
    for attr in ("targets", "labels", "y"):
        if hasattr(dataset, attr):
            data = getattr(dataset, attr)
            try:
                arr = np.asarray(data)
                if arr.shape[0] == len(dataset):
                    return arr
            except Exception:
                continue
    try:
        arr = np.array([dataset[i][1] for i in range(len(dataset))])
        return arr
    except Exception:
        return None


def _entropy_from_counts(counts: np.ndarray) -> float:
    total = float(np.sum(counts))
    if total <= 0:
        return float("nan")
    probs = counts[counts > 0] / total
    if probs.size == 0:
        return float("nan")
    return float(-np.sum(probs * np.log(probs + 1e-12)))


def _kl_divergence(counts_p: np.ndarray, counts_q: np.ndarray) -> float:
    counts_p = np.asarray(counts_p, dtype=float)
    counts_q = np.asarray(counts_q, dtype=float)
    L = max(counts_p.size, counts_q.size)
    if L == 0:
        return float("nan")
    counts_p = np.pad(counts_p, (0, L - counts_p.size))
    counts_q = np.pad(counts_q, (0, L - counts_q.size))
    sum_p = counts_p.sum()
    sum_q = counts_q.sum()
    if sum_p <= 0 or sum_q <= 0:
        return float("nan")
    p = counts_p / sum_p
    q = counts_q / sum_q
    mask = (p > 0) & (q > 0)
    if not np.any(mask):
        return float("nan")
    return float(np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask]))))


def _topk_counts(counts: Counter, k: int = 5) -> Dict[Any, int]:
    return dict(counts.most_common(k))


def _compute_dataset_stats(dataset, labels) -> Dict[str, Any]:
    N = len(dataset)
    if labels is None or len(labels) != N:
        return {
            "N": int(N),
            "n_classes": int(0),
            "label_entropy": float("nan"),
            "class_counts": {},
        }
    counts = Counter(labels.tolist()) if isinstance(labels, np.ndarray) else Counter(labels)
    arr_counts = np.array([counts[c] for c in sorted(counts)], dtype=float)
    return {
        "N": int(N),
        "n_classes": int(len(counts)),
        "label_entropy": _entropy_from_counts(arr_counts),
        "class_counts": {int(k): int(v) for k, v in counts.items()},
    }


def _cosine_torch(a: torch.Tensor, b: torch.Tensor) -> float:
    na = a.view(-1).norm().item()
    nb = b.view(-1).norm().item()
    if na == 0 or nb == 0:
        return float("nan")
    return float((a.view(-1) @ b.view(-1)).item() / (na * nb))


def _angle_from_cos(cos_val: float) -> float:
    if not math.isfinite(cos_val):
        return float("nan")
    clipped = max(-1.0, min(1.0, cos_val))
    return float(math.degrees(math.acos(clipped)))


def _flatten_model_params(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.detach().reshape(-1) for p in model.parameters() if p.requires_grad])


def _accuracy_metric(logits: torch.Tensor, targets: torch.Tensor) -> float:
    if logits.ndim >= 2 and logits.shape[-1] > 1:
        preds = torch.argmax(logits, dim=1)
    else:
        preds = (logits.view(-1) > 0).long()
    return float((preds == targets).float().mean().item())



def _collect_env_metadata(cfg: dict, model: nn.Module, dataset, labels, train_cfg: dict, seed: int) -> Dict[str, Any]:
    """Assemble a flat metadata dict describing the run.

    The result has three groups, in this order: environment fields (seed,
    git commit, torch/CUDA versions, device, dtype), dataset statistics
    (each key prefixed with ``dataset_``), and parameter/training fields
    read from ``train_cfg``.
    """
    flat_params = _flatten_model_params(model)
    commit = _try_get_git_commit()
    first_param = next(model.parameters())

    environment = {
        "seed": int(seed),
        "git_commit": commit,
        # Falls back to the git commit when the config names no version.
        "code_version": str(cfg.get("code_version", commit)),
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda or "cpu",
        "cudnn_deterministic": bool(getattr(torch.backends.cudnn, "deterministic", False)),
        "device": str(first_param.device),
        "dtype": str(first_param.dtype),
    }

    stats = _compute_dataset_stats(dataset, labels)
    dataset_fields = {f"dataset_{name}": value for name, value in stats.items()}

    training_fields = {
        "param_count": int(flat_params.numel()),
        "param_norm": float(flat_params.norm().item()),
        "grad_clip": bool(train_cfg.get("grad_clip", False)),
        "epochs": int(train_cfg.get("epochs", 0)),
        "lr": float(train_cfg.get("lr", 0.0)),
        "weight_decay": float(train_cfg.get("weight_decay", 0.0)),
    }

    return {**environment, **dataset_fields, **training_fields}




def _serialize(obj):
    if isinstance(obj, (float, int, str, bool)) or obj is None:
        return obj
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {str(k): _serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [_serialize(v) for v in obj]
    return str(obj)


def _log_structured(log: Optional[logging.Logger], label: str, payload: Any) -> None:
    if log is None:
        return
    try:
        msg = json.dumps(_serialize(payload), indent=2, sort_keys=True)
    except TypeError:
        msg = str(_serialize(payload))
    log.info("%s:\n%s", label, msg)


def _grad_mean_over_indices(model: nn.Module, dataset, loss_fn: nn.Module, idxs: np.ndarray) -> torch.Tensor:
    device = next(model.parameters()).device
    loader = DataLoader(Subset(dataset, idxs.tolist()), batch_size=128, shuffle=False)
    params = [p for p in model.parameters() if p.requires_grad]
    total = 0
    g_sum = None
    model.eval()
    with torch.enable_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = loss_fn(model(xb), yb)
            grads = torch.autograd.grad(loss, params, retain_graph=False)
            g_flat = torch.cat([g.reshape(-1) for g in grads])
            bs = xb.shape[0]
            total += bs
            g_sum = g_flat * bs if g_sum is None else g_sum + g_flat * bs
    return g_sum / max(total, 1)




def _save_checkpoint(output_dir: Path, all_rows_nested: dict, setting_summaries: dict, env_meta: dict, log: Optional[logging.Logger] = None) -> None:
    """Write the current results to ``output_dir``, overwriting earlier files.

    Files written:
        * ``per_fold_nested.json`` — rows grouped under a string key built
          from the five components of each setting tuple.
        * ``per_fold.json`` — the same rows as one list, settings in sorted
          key order.
        * ``settings_summary.json`` / ``settings_summary.csv`` — only when
          ``setting_summaries`` is non-empty; CSV columns are the sorted
          union of all summary keys.
        * ``metadata_checkpoint.json`` — the serialized ``env_meta``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    n_rows = 0
    for rows in all_rows_nested.values():
        n_rows += len(rows)

    # Grouped rows, one JSON entry per setting.
    with open(output_dir / "per_fold_nested.json", "w") as fh:
        grouped = {}
        for key, rows in all_rows_nested.items():
            tag = f"lambda={key[0]:.2e}_jl={key[1]}_C={key[2]}_U={key[3]}_K={key[4]}"
            grouped[tag] = [_serialize(row) for row in rows]
        json.dump(grouped, fh, indent=2)

    # All rows in a single list, ordered by sorted setting key.
    flat_rows = [row for key in sorted(all_rows_nested.keys()) for row in all_rows_nested[key]]
    with open(output_dir / "per_fold.json", "w") as fh:
        json.dump([_serialize(row) for row in flat_rows], fh, indent=2)

    # Per-setting summaries, as JSON and CSV.
    if setting_summaries:
        summaries = list(setting_summaries.values())
        with open(output_dir / "settings_summary.json", "w") as fh:
            json.dump([_serialize(summary) for summary in summaries], fh, indent=2)

        columns = set()
        for summary in summaries:
            columns.update(summary.keys())
        with open(output_dir / "settings_summary.csv", "w", newline="") as fh:
            table = csv.DictWriter(fh, fieldnames=sorted(columns))
            table.writeheader()
            table.writerows(summaries)

    # Environment metadata snapshot.
    with open(output_dir / "metadata_checkpoint.json", "w") as fh:
        json.dump(_serialize(env_meta), fh, indent=2)

    if log:
        log.info(f"✓ Checkpoint saved: {n_rows} folds, {len(setting_summaries)} settings, {len(all_rows_nested)} configurations")


def run_ablation(
    cfg: dict,
    logger: Optional[logging.Logger] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> dict:
    outdir = cfg.get("outdir", "runs")
    name = cfg.get("name", "ablation_run")
    path = Path(outdir) / name
    path.mkdir(parents=True, exist_ok=True)
    log = logger or setup_logging(cfg.get("log_level", "DEBUG"), log_dir=str(path))
    log.info("Ablation start")
    log.info("Config: %s", cfg)
    t0 = time.time()
    seed = int(cfg.get("seed", 10))
    set_seed(seed)

    # 1) Init: model/dataset/loss and training
    model: nn.Module = cfg["model"]
    dataset = cfg["dataset"]
    loss_fn: nn.Module = cfg["loss_fn"]
    if device is not None:
        model = model.to(device)
    model_device = next(model.parameters()).device
    N = len(dataset)
    log.info("Dataset size: N=%d", N)
    # Folds: either explicit fold sizes (counts) or K
    labels = cfg.get("labels", None)
    labels_arr = _extract_labels(dataset, labels)
    train_cfg = cfg.get("train", {"epochs": 5, "lr": 1e-2, "batch_size": 128})
    max_grad_norm = train_cfg.get("max_grad_norm", train_cfg.get("grad_clip", 5))
    if max_grad_norm is not None:
        try:
            max_grad_norm = float(max_grad_norm)
        except Exception:
            max_grad_norm = None
        if max_grad_norm is not None and max_grad_norm <= 0:
            max_grad_norm = None
    t_train0 = time.time()
    train_stats = train_model(
        model,
        dataset,
        loss_fn,
        epochs=int(train_cfg.get("epochs", 5)),
        lr=float(train_cfg.get("lr", 1e-2)),
        weight_decay=float(train_cfg.get("weight_decay", 0.0)),
        batch_size=int(train_cfg.get("batch_size", 128)),
        num_workers=int(train_cfg.get("num_workers", 0)),
        max_grad_norm=max_grad_norm,
        stop_on_nonfinite=True,
        logger=log,
    )
    time_train = time.time() - t_train0
    model.eval()
    env_meta = _collect_env_metadata(cfg, model, dataset, labels_arr, train_cfg, seed)
    
    try:
        save_path = path / "trained_model.pth"
        torch.save(model.state_dict(), save_path)
        log.info("✓ Model saved to %s", str(save_path))
    except Exception as e:
        log.warning("Failed to save model: %s", str(e))

    is_classifier = labels_arr is not None and np.issubdtype(np.asarray(labels_arr).dtype, np.integer)
    metrics_eval = {"accuracy": _accuracy_metric} if is_classifier else None
    train_eval = eval_on_indices(model, dataset, np.arange(N), loss_fn, metrics=metrics_eval)
    env_meta.update({
        "train_loss": float(train_eval["loss"]),
        "train_acc": float(train_eval.get("accuracy", float("nan"))),
        "val_loss": float("nan"),
        "val_acc": float("nan"),
        "early_stop": bool(cfg.get("early_stop", False)),
        "best_epoch": int(cfg.get("best_epoch", train_cfg.get("epochs", 0))),
        "time_train": float(time_train),
    })
    if isinstance(train_stats, dict):
        env_meta["grad_clip_count"] = int(train_stats.get("grad_clip_count", 0))
        env_meta["update_clip_count"] = int(train_stats.get("grad_clip_count", 0))
        env_meta["nan_loss_batches"] = int(train_stats.get("nan_batches", 0))
        env_meta["epochs_run"] = int(train_stats.get("epochs_run", train_cfg.get("epochs", 0)))
    else:
        env_meta["grad_clip_count"] = int(train_cfg.get("grad_clip_count", 0))
        env_meta["update_clip_count"] = int(train_cfg.get("update_clip_count", 0))
    params_data = [p.detach() for p in model.parameters() if p.requires_grad]
    env_meta["nan_in_params"] = bool(any(torch.isnan(p).any().item() for p in params_data))
    env_meta["overflow_detected"] = bool(any(torch.isinf(p).any().item() for p in params_data))
    _log_structured(log, "environment_metadata_initial", env_meta)

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(model_device)

    eval_time_total = 0.0
    eval_call_count = 0
    cg_time_if_total = 0.0
    cg_iters_if_total = 0
    cg_plateau_flags: List[bool] = []
    per_setting_rows: Dict[Tuple[float, int, int, int, int], List[dict]] = defaultdict(list)
    setting_summaries: Dict[Tuple[float, int, int, int, int], Dict[str, Any]] = {}
    all_rows_nested: Dict[Tuple[float, int, int, int, int], List[dict]] = defaultdict(list)
    nan_in_loss_flag = False

    # Persist raw IF / IFC responses (delta vectors)
    responses_dir = path / "responses"
    responses_dir.mkdir(parents=True, exist_ok=True)

    def timed_eval(indices: np.ndarray, metrics=None):
        nonlocal eval_time_total, eval_call_count
        t0_eval = time.time()
        res = eval_on_indices(model, dataset, indices, loss_fn, metrics=metrics)
        eval_time_total += time.time() - t0_eval
        eval_call_count += 1
        return res

    # --------------------
    # Setup grids (K-independent)
    # --------------------
    damping_grid = list(cfg.get("damping_grid", cfg.get("damping", [1e-3])))
    jl_dims = list(cfg.get("jl_dims", cfg.get("jl", [128])))
    clusters_grid = list(cfg.get("clusters", [50]))
    use_umap_default = bool(cfg.get("use_umap", True))
    umap_dims = list(cfg.get("umap_dims", [25]))

    cg_tol = float(cfg.get("cg_tol", cfg.get("tol", 1e-6)))
    cg_max = int(cfg.get("cg_max_iters", cfg.get("cg_iters", 1000)))
    cg_tol = min(cg_tol, 1e-6)
    cg_max = max(cg_max, 300)
    fisher = bool(cfg.get("fisher", True))

    gamma_fixed = float(cfg.get("gamma_fixed", 1.0))
    recourse_steps = int(cfg.get("recourse_steps", 0))

    # Optional: skip IFC computation (only IF)
    skip_ifc = bool(cfg.get("skip_ifc", False))

    # Shared full-data loader
    batch_size_eval = int(cfg.get("batch_size", 128))
    num_workers = int(cfg.get("num_workers", 4))
    all_loader = DataLoader(
        dataset,
        batch_size=batch_size_eval,
        shuffle=False,
        pin_memory=bool(torch.cuda.is_available()),
        persistent_workers=bool(num_workers > 0),
        num_workers=num_workers,
    )

    # Global mean gradient (used for keep-grad diagnostics)
    t_g0 = time.time()
    global_grad_mean = _grad_mean_over_indices(model, dataset, loss_fn, np.arange(N)).detach().to(model_device)
    t_global_grad = time.time() - t_g0
    _log_structured(log, "global_grad_metadata", {"time_sec": float(t_global_grad)})

    # --------------------
    # 1) Precompute IFC caches for the entire grid (fold/K-agnostic)
    # --------------------
    ifc_cache_by_setting: Dict[Tuple[float, int, int, int], Optional[dict]] = {}
    ifc_cache_meta_by_setting: Dict[Tuple[float, int, int, int], Dict[str, Any]] = {}

    if not skip_ifc:
        for lam in damping_grid:
            for jl_dim in jl_dims:
                for C in clusters_grid:
                    umap_dims_iter = umap_dims if use_umap_default else [0]
                    for umap_dim in umap_dims_iter:
                        if int(umap_dim) > int(C):
                            continue
                        use_umap = int(umap_dim) > 0
                        setting4 = (float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0)
                        if setting4 in ifc_cache_by_setting:
                            continue

                        ifc_cfg = IFCompressedCfg(
                            name="ifc",
                            jl_dim=int(jl_dim),
                            clusters=int(C),
                            damping=float(lam),
                            max_cg_iters=cg_max,
                            tol=cg_tol,
                            fisher=fisher,
                            normalize=True,
                            use_umap=use_umap,
                            umap_n_components=int(umap_dim) if use_umap else 0,
                            seed=seed,
                            collect_diagnostics=bool(cfg.get("ifc_collect_diagnostics", cfg.get("collect_diagnostics", False))),
                        )
                        log.info("Building IFC cache: λ=%.2e jl=%d C=%d U=%d", float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0)
                        t_ifc0 = time.time()
                        ifc_cache = build_ifc_cache(
                            model,
                            all_loader,
                            loss_fn,
                            ifc_cfg,
                            logger=log,
                            distortion_pairs=int(cfg.get("distortion_pairs", 2000)),
                        )
                        t_ifc = time.time() - t_ifc0
                        cache_timings = ifc_cache.get("timings", {}) if isinstance(ifc_cache, dict) else {}
                        cache_cluster_stats = ifc_cache.get("cluster_stats", {}) if isinstance(ifc_cache, dict) else {}
                        ifc_cache_by_setting[setting4] = ifc_cache
                        ifc_cache_meta_by_setting[setting4] = {
                            "lambda": float(lam),
                            "jl_dim": int(jl_dim),
                            "clusters": int(C),
                            "umap_dim": int(umap_dim) if use_umap else 0,
                            "build_time": float(t_ifc),
                            "timings": cache_timings,
                            "cluster_stats": cache_cluster_stats,
                        }
                        _log_structured(log, "ifc_cache_metadata", ifc_cache_meta_by_setting[setting4])
                        _save_checkpoint(path, all_rows_nested, setting_summaries, env_meta, log)
                        torch.cuda.empty_cache()
    else:
        _log_structured(log, "ifc_cache_skipped", {"skip_ifc": True})

    # --------------------
    # 2) Build IF modules per damping (K-agnostic)
    # --------------------
    objective = _SimpleObjective(loss_fn, weight_decay=float(train_cfg.get("weight_decay", 0.0)))
    mod_by_lambda: Dict[float, CGInfluenceModule] = {}
    for lam in damping_grid:
        mod_by_lambda[float(lam)] = CGInfluenceModule(
            model=model,
            objective=objective,
            train_loader=all_loader,
            test_loader=all_loader,
            device=model_device,
            damp=float(lam),
            gnh=fisher,
            tol=cg_tol,
            maxiter=cg_max,
        )

    # --------------------
    # 3) For each K: compute IF per fold, then evaluate IF vs IFC for every cache
    # --------------------
    for K in cfg.get("Ks", [5, 10, 20, 50, 100]):
        if K <= 1 or K > N:
            log.warning("Skipping invalid K=%d", K)
            continue

        folds = make_folds(N, K, stratify=labels, seed=seed)

        for lam in damping_grid:
            lam = float(lam)
            log.info("K=%d λ=%.3g: computing IF baselines", int(K), lam)
            mod = mod_by_lambda[lam]

            # IF exact baseline per fold (REMOVAL update, apply with +)
            IF: Dict[int, torch.Tensor] = {}
            IF_stats: Dict[int, Dict[str, Any]] = {}
            gS_cache: Dict[int, torch.Tensor] = {}

            t_gs0 = time.time()
            for f, S in enumerate(folds):
                idx = np.asarray(S, dtype=np.int64)
                if idx.size == 0:
                    continue
                if idx.min() < 0 or idx.max() >= N:
                    log.warning("Fold %d: invalid indices (min=%d max=%d N=%d)", f, int(idx.min()), int(idx.max()), int(N))
                    continue

                # Mean gradient over S (batched) -> much cheaper than per-sample streaming
                gS = _grad_mean_over_indices(model, dataset, loss_fn, idx).detach().to(model_device)
                t_cg_fold0 = time.time()
                sol = mod.inverse_hvp(gS)
                cg_time = time.time() - t_cg_fold0
                v_f = sol.get("ihvp") if isinstance(sol, dict) else None
                if v_f is None:
                    v_f = torch.zeros_like(gS)

                iters = int(sol.get("iterations", 0)) if isinstance(sol, dict) else 0
                resid = float(sol.get("final_residual", float("nan"))) if isinstance(sol, dict) else float("nan")
                hist = sol.get("residual_history", []) if isinstance(sol, dict) else []
                sol_norm = float(v_f.view(-1).norm().item())
                rhs_norm = float(gS.view(-1).norm().item())
                rayleigh = float(((v_f.view(-1) @ gS.view(-1)).item()) / max(sol_norm ** 2, 1e-12)) if sol_norm > 0 else float("nan")
                plateau = False
                if hist:
                    start_idx = max(0, int(0.9 * len(hist)) - 1)
                    tail = hist[start_idx:] or [hist[-1]]
                    if len(tail) >= 2:
                        diffs = [abs(tail[i] - tail[i - 1]) for i in range(1, len(tail))]
                        baseline = max(tail[:-1]) if tail[:-1] else tail[0]
                        improvement = max(diffs) / max(baseline, 1e-12) if diffs else 0.0
                        plateau = improvement < 1e-2

                delta_f = (len(idx) / N) * v_f.detach()
                IF[f] = delta_f.detach().cpu()
                IF_stats[f] = {
                    "iters": iters,
                    "residual": resid,
                    "residual_history": hist,
                    "cg_time": float(cg_time),
                    "rhs_norm": rhs_norm,
                    "sol_norm": sol_norm,
                    "rayleigh": rayleigh,
                    "rr_plateau": bool(plateau),
                    "stability": sol.get("stability_info") if isinstance(sol, dict) else None,
                    "verification": sol.get("verification_info") if isinstance(sol, dict) else None,
                    "ihvp_norm": sol_norm,
                }
                cg_time_if_total += cg_time
                cg_iters_if_total += iters
                cg_plateau_flags.append(bool(plateau))
                _log_structured(log, f"if_stats_fold_{f}", IF_stats[f])

                gS_cache[f] = gS.detach().cpu()

            t_gs = time.time() - t_gs0
            _log_structured(log, "fold_grad_metadata", {"K": int(K), "lambda": float(lam), "time_sec": float(t_gs)})
            torch.cuda.empty_cache()

            # Save IF responses (per-fold delta vectors) to disk
            try:
                if_path = responses_dir / f"IF_K{int(K)}_lam{float(lam):.2e}.pt"
                torch.save(
                    {
                        "K": int(K),
                        "lambda": float(lam),
                        "N": int(N),
                        "deltas": {int(f): t.detach().cpu() for f, t in IF.items()},
                        "stats": IF_stats,
                    },
                    if_path,
                )
            except Exception as e:
                log.warning("Failed to save IF responses for K=%d λ=%.2e: %s", int(K), float(lam), str(e))

            # Evaluate IF vs IFC for every (jl_dim, C, umap_dim) cache at this lambda
            if skip_ifc:
                settings4_iter = [(lam, 0, 0, 0)]
            else:
                settings4_iter = [k for k in ifc_cache_by_setting.keys() if abs(float(k[0]) - lam) < 1e-12]

            for setting4 in settings4_iter:
                _, jl_dim, C, umap_dim = setting4
                use_umap = int(umap_dim) > 0
                ifc_cache = None if skip_ifc else ifc_cache_by_setting.get(setting4)
                ifc_cache_metadata = ifc_cache_meta_by_setting.get(setting4, {
                    "lambda": float(lam),
                    "jl_dim": int(jl_dim),
                    "clusters": int(C),
                    "umap_dim": int(umap_dim) if use_umap else 0,
                    "build_time": float("nan"),
                    "timings": {},
                    "cluster_stats": {},
                })
                cluster_stats = ifc_cache_metadata.get("cluster_stats", {}) or {}

                def make_ifc_delta_for_fold(f: int) -> torch.Tensor:
                    if skip_ifc or ifc_cache is None:
                        return torch.zeros_like(IF[f]).to(model_device)
                    S = np.asarray(folds[f], dtype=np.int64)
                    out = delta_ifc_from_cache(ifc_cache, S.tolist(), device=model_device)
                    delta0 = out["delta_theta"]
                    if recourse_steps <= 0:
                        return delta0
                    # One-step correction: recompute g_S at θ+Δ and correct with IHVP
                    theta0 = flatten_params(model).detach().clone()
                    set_params_from_vector(model, theta0 + delta0)
                    gS_new = _grad_mean_over_indices(model, dataset, loss_fn, S).detach().to(model_device)
                    set_params_from_vector(model, theta0)
                    gS_old = gS_cache.get(f)
                    if gS_old is None:
                        return delta0
                    corr_sol = mod.inverse_hvp((gS_new - gS_old.to(model_device)).detach())
                    corr_vec = corr_sol.get("ihvp") if isinstance(corr_sol, dict) else None
                    if corr_vec is None:
                        corr_vec = torch.zeros_like(delta0)
                    return delta0 + (len(S) / N) * corr_vec.detach()

                # Evaluate per fold and log
                ifc_deltas_to_save: Dict[int, torch.Tensor] = {}
                for f, S in enumerate(folds):
                    S = np.asarray(S, dtype=np.int64)
                    if S.size == 0:
                        continue
                    theta = flatten_params(model).detach().clone()
                    theta_norm = float(theta.norm().item())

                    base = timed_eval(S, metrics=metrics_eval)
                    keep_mask = np.ones(N, dtype=bool)
                    keep_mask[S] = False
                    keep_idx = np.where(keep_mask)[0]
                    base_keep = timed_eval(keep_idx, metrics=metrics_eval) if keep_idx.size > 0 else {"loss": float("nan")}
                    if not math.isfinite(base["loss"]):
                        nan_in_loss_flag = True
                    if keep_idx.size > 0 and not math.isfinite(base_keep["loss"]):
                        nan_in_loss_flag = True

                    row = {
                        "fold": int(f),
                        "lambda": float(lam),
                        "jl_dim": int(jl_dim),
                        "clusters": int(C),
                        "K": int(K),
                        "umap_dim": int(umap_dim) if use_umap else 0,
                        "recourse": int(recourse_steps),
                        "gamma_policy": "fixed",
                        "time_per_sample_grads": float("nan"),
                        "samples_per_sec_grads": float("nan"),
                        "G_row_norm_mean": float("nan"),
                        "G_row_norm_std": float("nan"),
                        "G_row_norm_p95": float("nan"),
                        "grad_nan_inf": int(0),
                        "jl_rel_distortion_mean": float(ifc_cache.get("jl_rel_distortion_mean", float("nan"))) if isinstance(ifc_cache, dict) else float("nan"),
                        "jl_rel_distortion_p95": float(ifc_cache.get("jl_rel_distortion_p95", float("nan"))) if isinstance(ifc_cache, dict) else float("nan"),
                        "proj_distortion_mean": float(ifc_cache.get("umap_proj_distortion_mean", float("nan"))) if isinstance(ifc_cache, dict) else float("nan"),
                        "proj_distortion_p95": float(ifc_cache.get("umap_proj_distortion_p95", float("nan"))) if isinstance(ifc_cache, dict) else float("nan"),
                        "umap_trustworthiness": float(ifc_cache.get("umap_trustworthiness", float("nan"))) if isinstance(ifc_cache, dict) else float("nan"),
                        "silhouette": float((ifc_cache.get("silhouette") if isinstance(ifc_cache, dict) else None) or cluster_stats.get("silhouette", float("nan"))),
                        "between_within_ratio": float((ifc_cache.get("between_within_ratio") if isinstance(ifc_cache, dict) else None) or cluster_stats.get("between_within_ratio", float("nan"))),
                        "counts_cv": float((ifc_cache.get("counts_cv") if isinstance(ifc_cache, dict) else None) or cluster_stats.get("counts_cv", float("nan"))),
                        "cluster_size_min": int(cluster_stats.get("size_min", 0)),
                        "cluster_size_max": int(cluster_stats.get("size_max", 0)),
                        "cluster_size_mean": float(cluster_stats.get("size_mean", 0.0)),
                        "cluster_n_empty": int(cluster_stats.get("n_empty", 0)),
                        "proj_dim": int(ifc_cache.get("proj_dim", 0)) if isinstance(ifc_cache, dict) else 0,
                        "umap_used": bool(ifc_cache.get("umap_used", False)) if isinstance(ifc_cache, dict) else False,
                    }

                    # Fold audit
                    row["n_S"] = int(len(S))
                    row["n_keep"] = int(N - len(S))
                    row["min_idx"] = int(S.min()) if len(S) else -1
                    row["max_idx"] = int(S.max()) if len(S) else -1
                    row["unique_idx"] = bool(len(np.unique(S)) == len(S))
                    row["has_duplicates"] = not row["unique_idx"]
                    row["stratified"] = bool(labels_arr is not None)
                    if labels_arr is not None and len(S) > 0:
                        mix_counts = Counter(labels_arr[S])
                        row["class_mix_S"] = _topk_counts(mix_counts, k=5)
                        if env_meta.get("dataset_class_counts"):
                            classes_sorted = sorted(int(k) for k in env_meta["dataset_class_counts"].keys())
                            counts_S_vec = np.array([mix_counts.get(cls, 0) for cls in classes_sorted], dtype=float)
                            dataset_counts_vec = np.array([env_meta["dataset_class_counts"][cls] for cls in classes_sorted], dtype=float)
                            row["class_kl_S_vs_all"] = _kl_divergence(counts_S_vec, dataset_counts_vec)
                        else:
                            row["class_kl_S_vs_all"] = float("nan")
                    else:
                        row["class_mix_S"] = {}
                        row["class_kl_S_vs_all"] = float("nan")

                    # IF exact
                    d_if = IF[f].to(model_device)
                    g_if = float(gamma_fixed)
                    upd_if = g_if * d_if
                    set_params_from_vector(model, theta + d_if)
                    r_if = timed_eval(S, metrics=metrics_eval)
                    r_if_keep = timed_eval(keep_idx, metrics=metrics_eval) if keep_idx.size > 0 else {"loss": float("nan")}
                    set_params_from_vector(model, theta)

                    # IFC
                    d_ifc = make_ifc_delta_for_fold(f)
                    ifc_deltas_to_save[int(f)] = d_ifc.detach().cpu()
                    g_ifc = float(gamma_fixed)
                    upd_ifc = g_ifc * d_ifc
                    set_params_from_vector(model, theta + d_ifc)
                    r_ifc = timed_eval(S, metrics=metrics_eval)
                    r_ifc_keep = timed_eval(keep_idx, metrics=metrics_eval) if keep_idx.size > 0 else {"loss": float("nan")}
                    set_params_from_vector(model, theta)

                    orig_loss = float(base["loss"])
                    row["orig_loss"] = orig_loss
                    row["loss_if"] = float(r_if["loss"])
                    row["loss_ifc"] = float(r_ifc["loss"])
                    row["gamma_if"] = g_if
                    row["gamma_ifc"] = g_ifc
                    row["||d_if||"] = float(d_if.norm().item())
                    row["||d_ifc||"] = float(d_ifc.norm().item())
                    row["||upd_if||"] = float(upd_if.norm().item())
                    row["||upd_ifc||"] = float(upd_ifc.norm().item())

                    cos_if_ifc = _cosine_torch(d_if, d_ifc)
                    row["cos_if_ifc"] = cos_if_ifc
                    row["angle_if_ifc_deg"] = _angle_from_cos(cos_if_ifc)

                    gS = gS_cache.get(f)
                    if gS is not None and keep_idx.size > 0 and (N - len(S)) > 0:
                        sum_keep = global_grad_mean * N - gS.to(model_device) * len(S)
                        g_keep = sum_keep / (N - len(S))
                    else:
                        g_keep = torch.zeros_like(d_if)
                    row["cos_if_keepgrad"] = _cosine_torch(d_if, g_keep)
                    row["cos_ifc_keepgrad"] = _cosine_torch(d_ifc, g_keep)
                    row["rel_upd_norm_if"] = row["||upd_if||"] / max(theta_norm, 1e-12)
                    row["rel_upd_norm_ifc"] = row["||upd_ifc||"] / max(theta_norm, 1e-12)

                    eps = 1e-12
                    row["d_loss_if"] = row["loss_if"] - orig_loss
                    row["d_loss_ifc"] = row["loss_ifc"] - orig_loss
                    row["rel_d_loss_if"] = row["d_loss_if"] / max(orig_loss, eps)
                    row["rel_d_loss_ifc"] = row["d_loss_ifc"] / max(orig_loss, eps)
                    row["success_if"] = int(row["d_loss_if"] > 0)
                    row["success_ifc"] = int(row["d_loss_ifc"] > 0)
                    if not math.isfinite(row["loss_if"]) or not math.isfinite(row["loss_ifc"]):
                        nan_in_loss_flag = True

                    base_keep_loss = float(base_keep.get("loss", float("nan")))
                    row["base_keep_loss"] = base_keep_loss
                    row["loss_keep_if"] = float(r_if_keep.get("loss", float("nan")))
                    row["loss_keep_ifc"] = float(r_ifc_keep.get("loss", float("nan")))
                    row["d_loss_keep_if"] = row["loss_keep_if"] - base_keep_loss if math.isfinite(base_keep_loss) else float("nan")
                    row["d_loss_keep_ifc"] = row["loss_keep_ifc"] - base_keep_loss if math.isfinite(base_keep_loss) else float("nan")

                    # CG diagnostics
                    row["cg_iters_if"] = int(IF_stats.get(f, {}).get("iters", 0))
                    row["cg_resid_if"] = float(IF_stats.get(f, {}).get("residual", float("nan")))
                    row["cg_time_if"] = float(IF_stats.get(f, {}).get("cg_time", float("nan")))
                    row["cg_rr_plateau"] = int(IF_stats.get(f, {}).get("rr_plateau", False))
                    row["cg_rhs_norm"] = float(IF_stats.get(f, {}).get("rhs_norm", float("nan")))
                    row["cg_sol_norm"] = float(IF_stats.get(f, {}).get("sol_norm", float("nan")))
                    row["rayleigh_if"] = float(IF_stats.get(f, {}).get("rayleigh", float("nan")))
                    if torch.cuda.is_available():
                        row["gpu_mem_max_mb"] = float(torch.cuda.max_memory_allocated(model_device) / 1e6)
                    else:
                        row["gpu_mem_max_mb"] = float("nan")

                    if gS is not None:
                        gS_dev = gS.to(model_device)
                        first_if = float((gS_dev @ upd_if).item())
                        second_if = 0.5 * float(IF_stats[f]["rayleigh"]) * (row["||upd_if||"] ** 2) if f in IF_stats and math.isfinite(float(IF_stats[f].get("rayleigh", float("nan")))) else float("nan")
                        row["pred1_dloss_if"] = first_if
                        row["pred2_dloss_if"] = second_if
                        row["pred_dloss_if"] = first_if + (second_if if math.isfinite(second_if) else 0.0)
                        row["pred_err_if"] = row["pred_dloss_if"] - row["d_loss_if"]
                        first_ifc = float((gS_dev @ upd_ifc).item())
                        second_ifc = 0.5 * float(IF_stats[f]["rayleigh"]) * (row["||upd_ifc||"] ** 2) if f in IF_stats and math.isfinite(float(IF_stats[f].get("rayleigh", float("nan")))) else float("nan")
                        row["pred1_dloss_ifc"] = first_ifc
                        row["pred2_dloss_ifc"] = second_ifc
                        row["pred_dloss_ifc"] = first_ifc + (second_ifc if math.isfinite(second_ifc) else 0.0)
                        row["pred_err_ifc"] = row["pred_dloss_ifc"] - row["d_loss_ifc"]
                    else:
                        row["pred1_dloss_if"] = float("nan")
                        row["pred2_dloss_if"] = float("nan")
                        row["pred_dloss_if"] = float("nan")
                        row["pred_err_if"] = float("nan")
                        row["pred1_dloss_ifc"] = float("nan")
                        row["pred2_dloss_ifc"] = float("nan")
                        row["pred_dloss_ifc"] = float("nan")
                        row["pred_err_ifc"] = float("nan")

                    # Line-search diagnostics (not used under fixed gamma)
                    row["ls_steps_if"] = int(0)
                    row["ls_found_increase_if"] = int(0)
                    row["ls_steps_ifc"] = int(0)
                    row["ls_found_increase_ifc"] = int(0)

                    log.info(
                        "fold=%d λ=%.2e jl=%d C=%d γ_if=%.2f γ_ifc=%.2f base=%.4f if=%.4f ifc=%.4f",
                        int(f), float(lam), int(jl_dim), int(C), float(g_if), float(g_ifc), float(row["orig_loss"]), float(row["loss_if"]), float(row["loss_ifc"])
                    )
                    _log_structured(log, "per_fold_row", row)

                    key5 = (float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, int(K))
                    per_setting_rows[key5].append(row)
                    all_rows_nested[key5].append(row)
                    _save_checkpoint(path, all_rows_nested, setting_summaries, env_meta, log)
                    torch.cuda.empty_cache()

                # Save IFC responses (per-fold delta vectors) to disk
                try:
                    ifc_path = responses_dir / (
                        f"IFC_K{int(K)}_lam{float(lam):.2e}_jl{int(jl_dim)}_C{int(C)}_U{int(umap_dim) if use_umap else 0}.pt"
                    )
                    torch.save(
                        {
                            "K": int(K),
                            "lambda": float(lam),
                            "jl_dim": int(jl_dim),
                            "clusters": int(C),
                            "umap_dim": int(umap_dim) if use_umap else 0,
                            "recourse_steps": int(recourse_steps),
                            "N": int(N),
                            "deltas": {int(f): t.detach().cpu() for f, t in ifc_deltas_to_save.items()},
                            "ifc_cache_metadata": ifc_cache_metadata,
                        },
                        ifc_path,
                    )
                except Exception as e:
                    log.warning(
                        "Failed to save IFC responses for K=%d λ=%.2e jl=%d C=%d U=%d: %s",
                        int(K), float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, str(e),
                    )

                # Aggregate stats after all folds for this (λ, jl_dim, C, umap_dim, K)
                key5 = (float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, int(K))
                rows_setting = per_setting_rows.get(key5, [])
                topK = cfg.get("TopK", [1, 5, 10, 50, 100, 200, 500, 1000])

                def _topk_overlap(x: np.ndarray, y: np.ndarray, k: int) -> float:
                    m = np.isfinite(x) & np.isfinite(y)
                    x = x[m]
                    y = y[m]
                    n = x.size
                    if n == 0 or k <= 0:
                        return float("nan")
                    k = min(k, n)
                    idx_x = np.argsort(-x)[:k]
                    idx_y = np.argsort(-y)[:k]
                    inter = np.intersect1d(idx_x, idx_y)
                    return float(len(inter)) / float(k)

                orig = np.array([r["orig_loss"] for r in rows_setting], dtype=float)
                d_if = np.array([r["loss_if"] for r in rows_setting], dtype=float) - orig
                d_ifc = np.array([r["loss_ifc"] for r in rows_setting], dtype=float) - orig

                topk_if_ifc = {kk: _topk_overlap(d_if, d_ifc, int(kk)) for kk in topK}
                topk_results = {
                    "lambda": float(lam),
                    "jl_dim": int(jl_dim),
                    "clusters": int(C),
                    "umap_dim": int(umap_dim) if use_umap else 0,
                    "folds": int(K),
                    "topk_if_ifc": topk_if_ifc,
                }
                _log_structured(log, "topk_results", topk_results)

                def _mask_pair(x, y):
                    m = np.isfinite(x) & np.isfinite(y)
                    return x[m], y[m]

                x, y = _mask_pair(d_if, d_ifc)
                pear_if_ifc = pearson_corr(x, y) if x.size else float("nan")
                spear_if_ifc = spearman_corr(x, y) if x.size else float("nan")
                log.info(
                    "Correlations (K=%d, λ=%.2e, jl=%d, C=%d, U=%d): IF↔IFC P=%.3f S=%.3f",
                    int(K), float(lam), int(jl_dim), int(C), int(umap_dim) if use_umap else 0, float(pear_if_ifc), float(spear_if_ifc)
                )

                def _finite_stats(arr: np.ndarray) -> Tuple[float, float, float]:
                    arr = arr[np.isfinite(arr)]
                    if arr.size == 0:
                        return float("nan"), float("nan"), float("nan")
                    return float(arr.mean()), float(np.median(arr)), float(np.quantile(arr, 0.9))

                m_if = _finite_stats(d_if)
                m_ifc = _finite_stats(d_ifc)
                succ_if = float(np.mean([r.get("success_if", 0) for r in rows_setting])) if rows_setting else float("nan")
                succ_ifc = float(np.mean([r.get("success_ifc", 0) for r in rows_setting])) if rows_setting else float("nan")
                win_rate = float(np.mean([
                    (r.get("d_loss_if", float("nan")) > r.get("d_loss_ifc", float("nan")))
                    for r in rows_setting
                    if math.isfinite(r.get("d_loss_if", float("nan"))) and math.isfinite(r.get("d_loss_ifc", float("nan")))
                ])) if rows_setting else float("nan")

                pred_if = np.array([r.get("pred_dloss_if", float("nan")) for r in rows_setting], dtype=float)
                pred_ifc = np.array([r.get("pred_dloss_ifc", float("nan")) for r in rows_setting], dtype=float)
                mask_if = np.isfinite(pred_if) & np.isfinite(d_if)
                mask_ifc = np.isfinite(pred_ifc) & np.isfinite(d_ifc)
                pear_pred_if = pearson_corr(pred_if[mask_if], d_if[mask_if]) if mask_if.any() else float("nan")
                pear_pred_ifc = pearson_corr(pred_ifc[mask_ifc], d_ifc[mask_ifc]) if mask_ifc.any() else float("nan")
                spear_pred_if = spearman_corr(pred_if[mask_if], d_if[mask_if]) if mask_if.any() else float("nan")
                spear_pred_ifc = spearman_corr(pred_ifc[mask_ifc], d_ifc[mask_ifc]) if mask_ifc.any() else float("nan")

                setting_summary = {
                    "lambda": float(lam),
                    "jl_dim": int(jl_dim),
                    "clusters": int(C),
                    "umap_dim": int(umap_dim) if use_umap else 0,
                    "K": int(K),
                    "mean_d_loss_if": m_if[0],
                    "median_d_loss_if": m_if[1],
                    "p90_d_loss_if": m_if[2],
                    "mean_d_loss_ifc": m_ifc[0],
                    "median_d_loss_ifc": m_ifc[1],
                    "p90_d_loss_ifc": m_ifc[2],
                    "success_rate_if": succ_if,
                    "success_rate_ifc": succ_ifc,
                    "win_rate_if_vs_ifc": win_rate,
                    "pear_if_ifc": pear_if_ifc,
                    "spear_if_ifc": spear_if_ifc,
                    "pear_pred_vs_actual_if": pear_pred_if,
                    "pear_pred_vs_actual_ifc": pear_pred_ifc,
                    "spear_pred_vs_actual_if": spear_pred_if,
                    "spear_pred_vs_actual_ifc": spear_pred_ifc,
                    "ls_pred_vs_actual_corr_if": pear_pred_if,
                    "ls_pred_vs_actual_corr_ifc": pear_pred_ifc,
                    "topk_if_ifc": topk_if_ifc,
                    "topk_results": topk_results,
                    "ifc_cache_metadata": ifc_cache_metadata,
                }
                setting_summaries[key5] = setting_summary
                _log_structured(log, "setting_summary_partial", setting_summary)
                _save_checkpoint(path, all_rows_nested, setting_summaries, env_meta, log)

            # Write per-λ CSVs for this K
            out_rows = []
            for setting_key, rows in all_rows_nested.items():
                if abs(float(setting_key[0]) - float(lam)) < 1e-12 and int(setting_key[4]) == int(K):
                    out_rows.extend(rows)
            if out_rows:
                csv_out_dir = Path(outdir) / name
                csv_out_dir.mkdir(parents=True, exist_ok=True)
                csv_path = csv_out_dir / f"{lam}_{K}.csv"
                fieldnames = sorted({key for row in out_rows for key in row.keys()})
                with open(csv_path, "w", newline="") as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(out_rows)

    runtime_total = time.time() - t0
    env_meta["nan_in_loss"] = bool(nan_in_loss_flag)
    env_meta["platform"] = platform.platform()
    env_meta["python_version"] = platform.python_version()
    env_meta["gamma_policy"] = "fixed"

    # Flatten nested structure for compatibility with existing downstream code
    all_rows = []
    for setting_key in sorted(all_rows_nested.keys()):
        all_rows.extend(all_rows_nested[setting_key])

    def _median_from_rows(field: str) -> float:
        """Return the median of ``field`` over ``all_rows``, ignoring non-finite values.

        Rows lacking ``field`` contribute NaN and are therefore ignored. Returns
        NaN when no finite value is available.
        """
        column = np.array([row.get(field, float("nan")) for row in all_rows], dtype=float)
        finite = column[np.isfinite(column)]
        if not finite.size:
            return float("nan")
        return float(np.median(finite))

    env_meta["if_norm_median"] = _median_from_rows("||upd_if||")
    env_meta["ifc_norm_median"] = _median_from_rows("||upd_ifc||")
    env_meta["gamma_if_median"] = _median_from_rows("gamma_if")
    env_meta["gamma_ifc_median"] = _median_from_rows("gamma_ifc")

    if torch.cuda.is_available():
        env_meta["gpu_mem_max_mb_global"] = float(torch.cuda.max_memory_allocated(model_device) / 1e6)

    summary = {
        "n_rows": len(all_rows),
        "runtime_sec": runtime_total,
        "time_full_pipeline": runtime_total,
        "time_train": float(time_train),
        "time_eval_total": float(eval_time_total),
        "time_eval_per_call": float(eval_time_total / max(eval_call_count, 1)),
        "time_eval_per_fold_mean": float(eval_time_total / max(len(all_rows), 1)),
        "eval_calls": int(eval_call_count),
        "time_cg_if_total": float(cg_time_if_total),
        "cg_iters_if_total": int(cg_iters_if_total),
        "cg_iters_per_sec": float(cg_iters_if_total / max(cg_time_if_total, 1e-12)),
        "cg_rr_plateau_rate": float(np.mean(cg_plateau_flags)) if cg_plateau_flags else float("nan"),
    }

    throughput_vals = [r.get("samples_per_sec_grads", float("nan")) for r in all_rows]
    throughput_vals = [v for v in throughput_vals if math.isfinite(v)]
    summary["samples_per_sec_grads_mean"] = float(np.mean(throughput_vals)) if throughput_vals else float("nan")

    def _auc_from_curves(curves: Dict[Tuple[int, int, int], List[Tuple[float, float]]]) -> Dict[str, float]:
        auc = {}
        for key, pts in curves.items():
            pts = [(x, y) for x, y in pts if math.isfinite(x) and math.isfinite(y)]
            if len(pts) < 2:
                auc[str(key)] = float("nan")
                continue
            pts.sort(key=lambda t: t[0])
            xs = np.array([p[0] for p in pts], dtype=float)
            ys = np.array([p[1] for p in pts], dtype=float)
            auc[str(key)] = float(np.trapz(ys, xs))
        return auc

    curves_if: Dict[Tuple[int, int, int], List[Tuple[float, float]]] = defaultdict(list)
    curves_ifc: Dict[Tuple[int, int, int], List[Tuple[float, float]]] = defaultdict(list)
    heatmap_if: Dict[str, Dict[str, float]] = {}
    heatmap_ifc: Dict[str, Dict[str, float]] = {}
    settings_summary_list = []
    for key, stats in setting_summaries.items():
        _, jl_val, c_val, umap_val, _ = key
        other = (jl_val, c_val, umap_val)
        if math.isfinite(stats.get("success_rate_if", float("nan"))):
            curves_if[other].append((stats["lambda"], stats["success_rate_if"]))
        if math.isfinite(stats.get("success_rate_ifc", float("nan"))):
            curves_ifc[other].append((stats["lambda"], stats["success_rate_ifc"]))
        heatmap_if[f"jl{jl_val}_C{c_val}_U{umap_val}"] = {
            "mean_d_loss_if": stats["mean_d_loss_if"],
            "success_rate_if": stats["success_rate_if"],
        }
        heatmap_ifc[f"jl{jl_val}_C{c_val}_U{umap_val}"] = {
            "mean_d_loss_ifc": stats["mean_d_loss_ifc"],
            "success_rate_ifc": stats["success_rate_ifc"],
        }
        settings_summary_list.append(stats)

    auc_if = _auc_from_curves(curves_if)
    auc_ifc = _auc_from_curves(curves_ifc)
    summary["auc_success_if"] = auc_if
    summary["auc_success_ifc"] = auc_ifc
    summary["heatmap_if"] = heatmap_if
    summary["heatmap_ifc"] = heatmap_ifc

    _log_structured(log, "environment_metadata_final", env_meta)
    _log_structured(log, "run_summary", summary)
    _log_structured(log, "settings_summary", settings_summary_list)

    output_dir = Path(outdir) / name
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Write nested structure to JSON for easy navigation
    per_fold_nested_json = output_dir / "per_fold_nested.json"
    with open(per_fold_nested_json, "w") as f:
        serialized_nested = {
            f"lambda={k[0]:.2e}_jl={k[1]}_C={k[2]}_U={k[3]}": [_serialize(r) for r in v]
            for k, v in all_rows_nested.items()
        }
        json.dump(serialized_nested, f, indent=2)
    
    # Write flat list for backward compatibility
    per_fold_json = output_dir / "per_fold.json"
    with open(per_fold_json, "w") as f:
        json.dump([_serialize(r) for r in all_rows], f, indent=2)
    settings_json = output_dir / "settings_summary.json"
    with open(settings_json, "w") as f:
        json.dump([_serialize(s) for s in settings_summary_list], f, indent=2)
    run_summary_json = output_dir / "run_summary.json"
    with open(run_summary_json, "w") as f:
        json.dump(_serialize({"summary": summary, "metadata": env_meta, "settings": settings_summary_list}), f, indent=2)
    # settings summary CSV
    if settings_summary_list:
        csv_path = output_dir / "settings_summary.csv"
        fieldnames = sorted({key for stats in settings_summary_list for key in stats.keys()})
        with open(csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(settings_summary_list)

    diagnostics = {
        "env": env_meta,
        "settings": settings_summary_list,
        "auc_success_if": auc_if,
        "auc_success_ifc": auc_ifc,
        "output_dir": str(output_dir),
    }

    _log_structured(log, "diagnostics", diagnostics)

    return {
        "table": all_rows,
        "table_nested": all_rows_nested,
        "summary": summary,
        "metadata": env_meta,
        "settings": settings_summary_list,
        "diagnostics": diagnostics
    }
