# logic/experiment.py
import os, sys, copy, importlib
from typing import Dict, Any, Tuple, List, Optional, Callable
import numpy as np
import torch
import matplotlib.pyplot as plt

try:
    from utils.csv_utils import sort_and_save  # for CSV output
except Exception:
    sort_and_save = None

try:
    from utils.csv_utils import save_results_for_c
except Exception:
    save_results_for_c = None

try:
    import seaborn as sns
    _HAS_SNS = True
except Exception:
    _HAS_SNS = False

if __package__ in (None, ""):
    THIS_DIR = os.path.dirname(__file__)
    PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
    if PROJECT_ROOT not in sys.path:
        sys.path.insert(0, PROJECT_ROOT)
    from logic.optimization import optimize_bilevel
else:
    from .optimization import optimize_bilevel

# ---- Experiment configuration, read once from the environment at import ----
# Name of the setting module under logic.settings; "SETTINGS" wins over "settings".
SETTING = (os.environ.get("SETTINGS") or os.environ.get("settings", "imperfect_signal")).strip()
PARAM_TO_VARY = os.environ.get("PARAM", "alpha").strip()
RESULTS_DIR = os.environ.get("RESULTS_DIR", "results_csv").strip()

# Comma-separated sweep values, e.g. GRID="0.1,0.5,2". Blank tokens (trailing or
# doubled commas, a whitespace-only value) are skipped instead of crashing in
# float(""); if nothing usable remains, the default grid is used.
_grid = os.environ.get("GRID", "")
GRID_VALUES = [float(tok) for tok in _grid.split(",") if tok.strip()] or [0.1, 0.25, 0.5, 1, 1.5, 2]

# Normalize flag for regrets/distances
NORMALIZE = bool(int(os.environ.get("NORMALIZE", "1")))

# Eval sampler config (fixed, held-out)
EVAL_NS = int(os.environ.get("EVAL_NS", "8192"))
EVAL_SEED = int(os.environ.get("EVAL_SEED", "12345"))
EVAL_ANTITHETIC = bool(int(os.environ.get("EVAL_ANTITHETIC", "1")))

# Training CRN config
CRN_REFRESH = int(os.environ.get("CRN_REFRESH", "20"))  # reuse same z for 20 steps
CRN_ANTITHETIC = bool(int(os.environ.get("CRN_ANTITHETIC", "1")))
BASE_SEED = int(os.environ.get("BASE_SEED", "0"))

# Outer steps override: unset or blank means "no override" (None).
_outer_steps_raw = os.environ.get("OUTER_STEPS", "").strip()
OUTER_STEPS = int(_outer_steps_raw) if _outer_steps_raw else None

# LaTeX symbol for legend
PARAM_SYMBOLS = {
    "r": r"r", "c": r"c", "sigma": r"\sigma", "sigma1": r"\sigma_1", "sigma2": r"\sigma_2",
    "alpha": r"\alpha", "beta": r"\beta", "gamma": r"\gamma", "rho": r"\rho",
    "v": r"v", "U_res": r"U_{\mathrm{res}}", "ell": r"\ell",
    "s": r"s", "a0": r"a_0", "w_min": r"w_{\min}",
}
# Parameters without a LaTeX entry are shown under their plain name.
SYMBOL = PARAM_SYMBOLS.get(PARAM_TO_VARY, PARAM_TO_VARY)

# Axis labels shared by every figure.
X_LABEL = "Iteration"
Y_LABEL_U1 = r"$u_1$ relative gap"
Y_LABEL_U2 = r"$u_2$ relative gap"
Y_LABEL_A  = r"$a$ relative distance"
Y_LABEL_T  = r"$t$ relative distance"
# Floor used to keep denominators and log-scale curves strictly positive.
_EPS = 1e-12

# Seed torch's global RNG as soon as the module is imported.
torch.manual_seed(0)

# ---------- Eval samplers (Sobol QMC + antithetic) ----------
def _sobol_u(nsamples: int, device, dtype, seed: int) -> torch.Tensor:
    from torch.quasirandom import SobolEngine
    sob = SobolEngine(dimension=1, scramble=True, seed=seed)
    u = sob.draw(nsamples).squeeze(1).to(device=device, dtype=dtype)
    return u.clamp_(1e-7, 1 - 1e-7)

def _eval_z_logistic(ns: int, dtype, device, seed: int, antithetic: bool) -> torch.Tensor:
    """Build a fixed evaluation batch by pushing Sobol points through the logit.

    z = log(u) - log(1 - u), i.e. the inverse CDF of the standard logistic
    distribution.  With `antithetic=True` only ns // 2 points are drawn and the
    batch is [z, -z], so the result has 2 * (ns // 2) elements (one fewer than
    `ns` when `ns` is odd).
    """
    n_draw = ns // 2 if antithetic else ns
    uniforms = _sobol_u(n_draw, device, dtype, seed)
    draws = torch.log(uniforms) - torch.log1p(-uniforms)
    if not antithetic:
        return draws
    return torch.cat([draws, -draws], dim=0)

def _eval_z_laplace(ns: int, dtype, device, seed: int, antithetic: bool) -> torch.Tensor:
    """Build a fixed evaluation batch by pushing Sobol points through the Laplace inverse CDF.

    For u <= 0.5 the draw is log(2u), otherwise -log(2(1 - u)) (unit scale,
    zero location).  With `antithetic=True` only ns // 2 points are drawn and
    the batch is [z, -z], so the result has 2 * (ns // 2) elements (one fewer
    than `ns` when `ns` is odd).
    """
    n_draw = ns // 2 if antithetic else ns
    uniforms = _sobol_u(n_draw, device, dtype, seed)
    lower_tail = torch.log(2 * uniforms)
    upper_tail = -torch.log(2 * (1 - uniforms))
    draws = torch.where(uniforms <= 0.5, lower_tail, upper_tail)
    if not antithetic:
        return draws
    return torch.cat([draws, -draws], dim=0)

def _eval_z_normal(ns: int, dtype, device, seed: int, antithetic: bool) -> torch.Tensor:
    """Build a fixed evaluation batch by pushing Sobol points through the normal inverse CDF.

    z = sqrt(2) * erfinv(2u - 1).  With `antithetic=True` only ns // 2 points
    are drawn and the batch is [z, -z], so the result has 2 * (ns // 2)
    elements (one fewer than `ns` when `ns` is odd).
    """
    n_draw = ns // 2 if antithetic else ns
    uniforms = _sobol_u(n_draw, device, dtype, seed)
    root_two = torch.sqrt(torch.tensor(2.0, device=device, dtype=dtype))
    draws = root_two * torch.special.erfinv(2 * uniforms - 1)
    if not antithetic:
        return draws
    return torch.cat([draws, -draws], dim=0)

def _pick_eval_sampler(mod) -> Callable[[int, torch.dtype, torch.device, int, bool], torch.Tensor]:
    """Choose the evaluation sampler that matches the sampler exposed by `mod`.

    A module exposing `sample_laplace_z` gets the Laplace sampler, one exposing
    `sample_normal_z` gets the normal sampler, and anything else falls back to
    the logistic sampler.

    NOTE(review): main() probes for the training sampler in the order
    logistic, laplace, normal, whereas this probes laplace first; a module
    exposing several samplers could end up with different train/eval
    distributions -- confirm which order is intended.
    """
    by_attribute = (
        ("sample_laplace_z", _eval_z_laplace),
        ("sample_normal_z", _eval_z_normal),
    )
    for attribute, sampler in by_attribute:
        if hasattr(mod, attribute):
            return sampler
    return _eval_z_logistic

# ================== Helpers ==================
def _norm(x: torch.Tensor) -> float:
    return float(torch.norm(x).item()) if isinstance(x, torch.Tensor) else float(abs(x))

def _extract_t_vector(entry: Dict[str, float], length: int) -> Optional[np.ndarray]:
    """Reconstruct t from 't[i]' keys in trace entry with expected length; returns None if missing."""
    vals: List[float] = []
    for i in range(length):
        key = f"t[{i}]"
        if key not in entry:
            return None
        vals.append(float(entry[key]))
    return np.asarray(vals, dtype=float)

def _setup_figure():
    """Open a new 9x6 matplotlib figure, applying the seaborn 'ticks'/'talk' theme when seaborn is available."""
    if _HAS_SNS:
        sns.set_style("ticks")
        sns.set_context("talk")
    plt.figure(figsize=(9, 6))

def _style_and_save(ylabel: str, filename: str):
    """Apply the shared log-log styling to the current axes, save to `filename`, then show the figure.

    Args:
        ylabel: Label for the y axis (the x axis always uses X_LABEL).
        filename: Output path handed to `plt.savefig`.
    """
    axes = plt.gca()

    # Axis labels and log-log scaling.
    axes.set_xlabel(X_LABEL, fontsize=24)
    axes.set_ylabel(ylabel, fontsize=24)
    axes.set_xscale("log")
    axes.set_yscale("log")

    # Dashed major grid only.
    axes.grid(True, which="major", linestyle="--", linewidth=0.8, alpha=0.6)
    axes.grid(False, which="minor")

    # Opaque legend box with a black border.
    legend_frame = axes.legend(fontsize=16, frameon=True).get_frame()
    legend_frame.set_edgecolor("black")
    legend_frame.set_linewidth(1.2)
    legend_frame.set_alpha(1.0)

    # Black frame around the plotting area.
    for spine in axes.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.5)
        spine.set_color("black")
    axes.patch.set_edgecolor("black")
    axes.patch.set_linewidth(1.5)
    axes.tick_params(width=1.2)

    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight", pad_inches=0.05)
    print(f"✅ Saved plot to {filename}")
    plt.show()

def _to_param_tensor(x, fallback: List[float]) -> torch.Tensor:
    """Coerce DEFAULT_* into a 1D float32 tensor."""
    if isinstance(x, torch.Tensor):
        return x.detach().clone().to(dtype=torch.float32)
    return torch.tensor(fallback if isinstance(x, (list, tuple)) else [float(x)], dtype=torch.float32)

def _to_mask_tensor(mask_like, length: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """
    Coerce a provided mask (tensor/list/scalar) to a 1D float tensor of given length.
    Defaults to all-ones if None.
    """
    if isinstance(mask_like, torch.Tensor):
        m = mask_like.detach().flatten().to(dtype=dtype)
    elif mask_like is None:
        m = torch.ones(length, dtype=dtype)
    elif isinstance(mask_like, (list, tuple, np.ndarray)):
        m = torch.tensor(mask_like, dtype=dtype).flatten()
    else:
        # treat as scalar
        m = torch.full((length,), float(mask_like), dtype=dtype)
    if m.numel() < length:
        m = torch.nn.functional.pad(m, (0, length - m.numel()), value=1.0)
    elif m.numel() > length:
        m = m[:length]
    if device is not None:
        m = m.to(device=device)
    return m

# ================== Main ==================
def _results_root() -> str:
    """Return the root directory under which CSV results are written.

    RESULTS_DIR is honoured when the environment variable is explicitly set to
    a non-blank value; otherwise the historical hard-coded archive location is
    kept so existing runs continue to write to the same place.
    """
    if os.environ.get("RESULTS_DIR", "").strip():
        return RESULTS_DIR
    return "/scratch/user/aaryabookseller/Research/archive/csv"


def _slug_float(x) -> str:
    """Render a grid value for use in a directory name (2.0 -> '2', 0.25 -> '0p25')."""
    try:
        xf = float(x)
        return str(int(xf)) if xf.is_integer() else str(xf).replace(".", "p")
    except Exception:
        # Non-numeric (or non-finite) values fall back to their plain string form.
        return str(x).replace(".", "p")


def _write_rows_csv(rows: List[Dict[str, Any]], path: str, sort_by: Optional[str]) -> None:
    """Write `rows` to `path`.

    Delegates to utils.csv_utils.sort_and_save when it was importable; otherwise
    writes an unsorted CSV with csv.DictWriter, using the sorted union of the
    row keys as the header.
    """
    if sort_and_save is not None:
        sort_and_save(rows, file_name=path, sort_by=sort_by, ascending=True)
        return
    import csv
    fieldnames = sorted({key for row in rows for key in row})
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _a_distance_series(trace, a_star) -> Optional[np.ndarray]:
    """Per-step distance of `a` from `a_star`, read from the optimizer trace.

    Each step uses its "err_a_l2" value when present, else "err_a_abs"; steps
    with neither contribute NaN.  With NORMALIZE the error is divided by the
    norm of `a_star` (floored at _EPS).  Returns None when `a_star` is not a
    tensor or when no step carries an error value.
    """
    if not isinstance(a_star, torch.Tensor):
        return None
    # For a one-element a_star the norm equals |a_star|, so a single denominator
    # serves both error keys and a multi-element a_star no longer raises.
    denom = max(_norm(a_star), _EPS)
    series: List[float] = []
    has_any = False
    for entry in trace:
        if "err_a_l2" in entry:
            err = float(entry["err_a_l2"])
        elif "err_a_abs" in entry:
            err = float(entry["err_a_abs"])
        else:
            series.append(np.nan)
            continue
        series.append(err / denom if NORMALIZE else err)
        has_any = True
    return np.asarray(series, dtype=float) if has_any else None


def _t_distance_series(trace, t_star, metric_mask: torch.Tensor) -> Optional[np.ndarray]:
    """Per-step masked distance of `t` from `t_star`, read from the optimizer trace.

    A step's "err_t_l2" value is used when present; otherwise the distance is
    recomputed from the step's 't[i]' entries with `metric_mask` applied.  Steps
    with neither contribute NaN.  With NORMALIZE the error is divided by the
    norm of the masked `t_star` (floored at _EPS).  Returns None when `t_star`
    is not a tensor or when no step yields a value.
    """
    if not isinstance(t_star, torch.Tensor):
        return None
    t_star_np = t_star.detach().cpu().numpy().reshape(-1)
    mask_np = metric_mask.detach().cpu().numpy().reshape(-1)
    denom = max(float(np.linalg.norm(t_star_np * mask_np)), _EPS)
    series: List[float] = []
    has_any = False
    for entry in trace:
        if "err_t_l2" in entry:
            # Prefer the error value the optimizer already recorded.
            err = float(entry["err_t_l2"])
        else:
            t_vec = _extract_t_vector(entry, length=len(t_star_np))
            if t_vec is None:
                series.append(np.nan)
                continue
            err = float(np.linalg.norm((t_vec - t_star_np) * mask_np))
        series.append(err / denom if NORMALIZE else err)
        has_any = True
    return np.asarray(series, dtype=float) if has_any else None


def _plot_gap_sweep(star_values: Dict[float, float], traces: Dict[float, np.ndarray],
                    ylabel: str, stem: str) -> None:
    """Plot |u* - u| per iteration for every grid value and save it as sweep_<stem>_....pdf.

    With NORMALIZE the gap is divided by max(|u*|, floor), where floor is 1e-6
    times the median |u*| over the grid (at least _EPS).  When u* is NaN the
    raw |u| trace is plotted instead.  Curves are floored at _EPS so they stay
    drawable on a log axis.
    """
    _setup_figure()
    denom_floor = 1.0
    if NORMALIZE:
        finite = np.array([v for v in star_values.values() if not np.isnan(v)], dtype=float)
        if finite.size:
            denom_floor = max(_EPS, 1e-6 * np.median(np.abs(finite)))
    for val in GRID_VALUES:
        star = star_values[val]
        vals = traces[val]
        if np.isnan(star):
            curve = np.maximum(np.abs(vals), _EPS)
        else:
            gap_abs = np.abs(star - vals)
            if NORMALIZE:
                gap_abs = gap_abs / max(abs(star), denom_floor)
            curve = np.maximum(gap_abs, _EPS)
        iters = np.arange(1, len(curve) + 1, dtype=float)
        plt.plot(iters, curve, linewidth=2, label=fr"${SYMBOL}={val:g}$")
    out_pdf = f"sweep_{stem}_{SETTING}_{PARAM_TO_VARY}{'_rel' if NORMALIZE else ''}.pdf"
    _style_and_save(ylabel, out_pdf)


def _plot_distance_sweep(dist_traces: Dict[float, Optional[np.ndarray]],
                         ylabel: str, stem: str, skip_message: str) -> None:
    """Plot the per-iteration distance curves and save them as sweep_<stem>_....pdf.

    Grid values without a curve are skipped; if none has one, `skip_message`
    is printed and no figure is created.  Finite values are floored at _EPS,
    non-finite ones are left as NaN gaps.
    """
    if not any(dist_traces[val] is not None for val in GRID_VALUES):
        print(skip_message)
        return
    _setup_figure()
    for val in GRID_VALUES:
        dist = dist_traces[val]
        if dist is None:
            continue
        curve = np.asarray(dist, dtype=float)
        curve = np.where(np.isfinite(curve), np.maximum(curve, _EPS), np.nan)
        iters = np.arange(1, len(curve) + 1, dtype=float)
        plt.plot(iters, curve, linewidth=2, label=fr"${SYMBOL}={val:g}$")
    out_pdf = f"sweep_{stem}_{SETTING}_{PARAM_TO_VARY}{'_rel' if NORMALIZE else ''}.pdf"
    _style_and_save(ylabel, out_pdf)


# ================== Main ==================
def main():
    """Sweep PARAM_TO_VARY over GRID_VALUES for the configured setting.

    For every grid value this:
      1. copies the setting module's DEFAULT_PARAMS and overrides the swept
         parameter (a warning is printed if the parameter is not in the dict);
      2. initialises `t` / `a` from the module defaults, projected through the
         module's optional `project_t_box_default` / `project_a_box`;
      3. evaluates u1 at the theoretical optimum (when the module provides
         `get_theoretical_optimum`) on one fixed held-out evaluation batch;
      4. runs `optimize_bilevel` and records the u1/u2 traces, the distance
         traces to (a*, t*), and per-step / final-step CSV rows.

    Afterwards it writes four PDF figures into the working directory and the
    aggregate and per-value CSV files below the results root (see
    `_results_root`).  If the results directory cannot be created, a warning is
    printed and the CSV step is skipped.
    """
    from collections import defaultdict

    module_path = f"logic.settings.{SETTING}"
    mod = importlib.import_module(module_path)

    # ------ Initial defaults from each module ------
    base_t_def = getattr(mod, "DEFAULT_T_INIT", torch.tensor([0.5], dtype=torch.float32))
    base_a_def = getattr(mod, "DEFAULT_A_INIT", torch.tensor([0.0], dtype=torch.float32))
    base_t = _to_param_tensor(base_t_def, [0.5]).requires_grad_(True)
    base_a = _to_param_tensor(base_a_def, [0.0]).requires_grad_(True)
    t_len = int(base_t.numel())

    base_params = copy.deepcopy(getattr(mod, "DEFAULT_PARAMS", {}))

    # Per-setting masks (training + metric); default to "all ones" (train & measure all)
    t_train_mask = _to_mask_tensor(getattr(mod, "T_TRAIN_MASK", None), t_len, device=base_t.device)
    t_metric_mask = _to_mask_tensor(getattr(mod, "T_METRIC_MASK", None), t_len, device=base_t.device)

    # Optional labels for nicer logs (e.g., ["s","b","d"])
    t_labels = getattr(mod, "T_LABELS", None)
    if isinstance(t_labels, (list, tuple)) and len(t_labels) != t_len:
        t_labels = None  # ignore mismatched labels

    # Prefer a single-arg projection wrapper if provided
    project_t_default = getattr(mod, "project_t_box_default", None)
    proj_t = project_t_default if callable(project_t_default) else None
    project_a_box = getattr(mod, "project_a_box", None)
    proj_a = project_a_box if callable(project_a_box) else None

    # Detect training sampler function on the module (for CRN)
    sample_train = None
    for name in ("sample_logistic_z", "sample_laplace_z", "sample_normal_z"):
        if hasattr(mod, name):
            sample_train = getattr(mod, name)
            break

    # Pick eval sampler consistent with the setting
    eval_sampler = _pick_eval_sampler(mod)

    # Storage for traces per grid value
    u1_traces: Dict[float, np.ndarray] = {}
    u2_traces: Dict[float, np.ndarray] = {}
    a_dist_traces: Dict[float, Optional[np.ndarray]] = {}
    t_dist_traces: Dict[float, Optional[np.ndarray]] = {}
    u1_star_values: Dict[float, float] = {}
    u2_star_values: Dict[float, float] = {}

    # ---- CSV accumulators ----
    rows_all: List[Dict[str, Any]] = []
    summary_rows: List[Dict[str, Any]] = []

    for val in GRID_VALUES:
        sp = copy.deepcopy(base_params)

        # vary requested parameter if present
        if PARAM_TO_VARY in sp:
            sp[PARAM_TO_VARY] = float(val)
        else:
            print(f"⚠️ PARAM '{PARAM_TO_VARY}' not in setting_parameters; skipping variation.")

        # init vars (apply projection without breaking grads)
        t = base_t.clone().detach().requires_grad_(True)
        if proj_t is not None:
            with torch.no_grad():
                projected = proj_t(t.detach().clone())
                if isinstance(projected, torch.Tensor) and projected.shape == t.shape:
                    t.copy_(projected)
        a = base_a.clone().detach().requires_grad_(True)
        if proj_a is not None:
            with torch.no_grad():
                projected_a = proj_a(a.detach().clone(), setting_parameters=sp)
                if isinstance(projected_a, torch.Tensor) and projected_a.shape == a.shape:
                    a.copy_(projected_a)

        # Build per-run masks on the right device/dtype
        t_train_mask_run = t_train_mask.to(device=t.device, dtype=t.dtype)
        t_metric_mask_run = t_metric_mask.to(device=t.device, dtype=t.dtype)

        # ==== Build single held-out eval batch (z_eval) consistent with the setting ====
        z_eval = eval_sampler(EVAL_NS, dtype=t.dtype, device=t.device, seed=EVAL_SEED, antithetic=EVAL_ANTITHETIC)

        # ==== Theoretical optimum (a*, t*) ====
        a_star, t_star = None, None
        if hasattr(mod, "get_theoretical_optimum"):
            try:
                res = mod.get_theoretical_optimum(t=t.detach(), setting_parameters=sp)
                if isinstance(res, (tuple, list)) and len(res) >= 2:
                    a_star, t_star = res[0], res[1]
            except Exception as e:
                print(f"ℹ️ get_theoretical_optimum failed for {SETTING} @ {val}: {e}")
                a_star, t_star = None, None

        # ==== u1* on the SAME eval batch ====
        try:
            star_params = copy.deepcopy(sp)
            star_params["nsamples"] = EVAL_NS
            if a_star is not None and t_star is not None:
                u1_star = float(mod.u1(
                    torch.as_tensor(a_star, dtype=t.dtype, device=t.device),
                    t_star, z=z_eval, **star_params
                ))
            else:
                u1_star = float("nan")
        except Exception as e:
            print(f"ℹ️ u1* eval failed for {SETTING} @ {val}: {e}")
            u1_star = float("nan")
        u1_star_values[val] = u1_star

        # ==== Build optimizer kwargs ====
        run_kwargs = dict(
            u1=mod.u1, u2=mod.u2, t=t, a=a,
            setting_parameters=sp,
            a_star=a_star, t_star=t_star,
            # masks
            t_train_mask=t_train_mask_run,
            t_metric_mask=t_metric_mask_run,
            t_labels=t_labels,
            # — training CRN —
            use_crn=sample_train is not None,
            make_z=(lambda ns, settings, dtype, device: sample_train(ns, dtype, device)) if sample_train else None,
            crn_refresh=CRN_REFRESH,
            crn_antithetic=CRN_ANTITHETIC,
            base_seed=BASE_SEED,
            log_with_eval=True,
            # The eval batch is always regenerated with the fixed EVAL_* settings.
            eval_make_z=lambda ns, settings, dtype, device: eval_sampler(EVAL_NS, dtype, device, seed=EVAL_SEED, antithetic=EVAL_ANTITHETIC),
            eval_nsamples=EVAL_NS,
            eval_seed=EVAL_SEED,
            outer_lr=5e-4,
            inner_max_steps=200,
            inner_grad_tol=1e-4,
        )
        if proj_t is not None:
            def _proj_t_safe(T: torch.Tensor) -> torch.Tensor:
                # Project a detached copy so the projection never enters the autograd graph.
                with torch.no_grad():
                    return proj_t(T.detach().clone())
            run_kwargs["project_t"] = _proj_t_safe
        if proj_a is not None:
            def _proj_a_safe(A: torch.Tensor, _sp=sp) -> torch.Tensor:
                # `_sp` pins this iteration's parameters to the closure.
                with torch.no_grad():
                    return proj_a(A.detach().clone(), setting_parameters=_sp)
            run_kwargs["project_a"] = _proj_a_safe
        if OUTER_STEPS is not None:
            run_kwargs["outer_steps"] = int(OUTER_STEPS)

        # ==== Run ====
        t_final, a_final, u1_vals, u2_vals, trace = optimize_bilevel(**run_kwargs)
        u1_vals = np.asarray(u1_vals, dtype=float)
        u2_vals = np.asarray(u2_vals, dtype=float)
        u1_traces[val] = u1_vals
        u2_traces[val] = u2_vals

        # ==== u2* on SAME eval batch, but with FINAL t ====
        try:
            star_params = copy.deepcopy(sp)
            star_params["nsamples"] = EVAL_NS
            if a_star is not None:
                u2_star = float(mod.u2(
                    torch.as_tensor(a_star, dtype=t_final.dtype, device=t_final.device),
                    t_final.detach(), z=z_eval, **star_params
                ))
            else:
                u2_star = float("nan")
        except Exception as e:
            print(f"ℹ️ u2* eval failed for {SETTING} @ {val}: {e}")
            u2_star = float("nan")
        u2_star_values[val] = u2_star

        # ==== Distances from trace ====
        a_dist_traces[val] = _a_distance_series(trace, a_star)
        t_dist_traces[val] = _t_distance_series(trace, t_star, t_metric_mask_run)

        # ----- CSV rows: full per-step trace -----
        for entry in trace:
            row = dict(entry)  # copy so the optimizer's trace entries are not mutated
            row.update({
                "setting": SETTING,
                "param": PARAM_TO_VARY,
                "param_value": float(val),
                "u1_star": float(u1_star),
                "u2_star": float(u2_star),
            })
            rows_all.append(row)

        # ----- CSV rows: final-step summary for this grid value -----
        if len(trace) > 0:
            last = dict(trace[-1])
            try:
                u1_final = float(u1_vals[-1])
                u2_final = float(u2_vals[-1])
            except IndexError:
                # Empty value arrays: fall back to what the last trace entry recorded.
                u1_final = last.get("u1", np.nan)
                u2_final = last.get("u2", np.nan)

            # Relative gaps at the final step (NaN when the reference value is unavailable).
            u1_gap_rel = (abs(u1_star - u1_final) / max(abs(u1_star), _EPS)) if not np.isnan(u1_star) else np.nan
            u2_gap_rel = (abs(u2_star - u2_final) / max(abs(u2_star), _EPS)) if not np.isnan(u2_star) else np.nan

            last.update({
                "setting": SETTING,
                "param": PARAM_TO_VARY,
                "param_value": float(val),
                "u1_final": float(u1_final),
                "u2_final": float(u2_final),
                "u1_star": float(u1_star),
                "u2_star": float(u2_star),
                "u1_gap_rel_final": float(u1_gap_rel),
                "u2_gap_rel_final": float(u2_gap_rel),
            })
            summary_rows.append(last)

    # ==== Figures 1-2: u1 / u2 gaps (u2 against u2(a*, t_final)) ====
    _plot_gap_sweep(u1_star_values, u1_traces, Y_LABEL_U1, "gap_u1")
    _plot_gap_sweep(u2_star_values, u2_traces, Y_LABEL_U2, "gap_u2")

    # ==== Figures 3-4: distances to a* and (masked) t* ====
    _plot_distance_sweep(
        a_dist_traces, Y_LABEL_A, "dist_a",
        "ℹ️ Skipping a-distance plot (no a* / err_a traces available).",
    )
    _plot_distance_sweep(
        t_dist_traces, Y_LABEL_T, "dist_t",
        "ℹ️ Skipping t-distance plot (no t* / trace values available).",
    )

    # ====== Save CSVs (aggregate + per-value) ======
    ROOT = _results_root()
    base_dir = os.path.join(ROOT, SETTING, f"{PARAM_TO_VARY}_rel")

    try:
        os.makedirs(base_dir, exist_ok=True)
    except OSError as e:
        print(f"⚠️ Could not create results dir '{base_dir}': {e}")
        return

    # ---- Aggregate (all values together) ----
    trace_path = os.path.join(base_dir, "trace_all.csv")
    summary_path = os.path.join(base_dir, "summary_all.csv")

    if sort_and_save is None:
        # Plain-csv fallback: skip empty outputs and report what was written.
        if rows_all:
            _write_rows_csv(rows_all, trace_path, sort_by="param_value")
            print(f"✅ Saved aggregate trace CSV to {trace_path}")
        if summary_rows:
            _write_rows_csv(summary_rows, summary_path, sort_by="param_value")
            print(f"✅ Saved aggregate summary CSV to {summary_path}")
    else:
        _write_rows_csv(rows_all, trace_path, sort_by="param_value")
        _write_rows_csv(summary_rows, summary_path, sort_by="param_value")

    # ---- Per-value CSVs ----
    by_val = defaultdict(list)
    for r in rows_all:
        by_val[r.get("param_value", float("nan"))].append(r)

    summary_by_val = defaultdict(list)
    for r in summary_rows:
        summary_by_val[r.get("param_value", float("nan"))].append(r)

    use_c_writer = PARAM_TO_VARY == "c" and save_results_for_c is not None
    for val, rows in by_val.items():
        # At most one summary row per value: the most recent one.
        last_summary = summary_by_val.get(val, [])[-1:]

        if use_c_writer:
            # full per-step trace for this c
            save_results_for_c(
                results=rows,
                c_value=val,
                root_dir=ROOT,
                setting=SETTING,
                file_prefix="trace",
                sort_by="step",   # assumes 'step' in trace entries
                ascending=True,
            )
            # final-step summary
            save_results_for_c(
                results=last_summary,
                c_value=val,
                root_dir=ROOT,
                setting=SETTING,
                file_prefix="summary",
                sort_by=None,
                ascending=True,
            )
            continue

        # fallback: one directory per grid value
        val_dir = os.path.join(base_dir, f"{PARAM_TO_VARY}={_slug_float(val)}")
        os.makedirs(val_dir, exist_ok=True)
        _write_rows_csv(rows, os.path.join(val_dir, "trace.csv"), sort_by="step")
        if last_summary:
            _write_rows_csv(last_summary, os.path.join(val_dir, "summary.csv"), sort_by=None)

# Script entry point: seed torch's global RNG again (the module-level code also
# seeds it with 0 at import time) and run the full sweep.
if __name__ == "__main__":
    torch.manual_seed(0)
    main()