# peptide/plot_compare_unique_good_peptides.py
from __future__ import annotations

import argparse
import json
import os
import re
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# Python 3.11+
import tomllib

from .peptide_env import Sequences, Policy
from .peptide_reward import LogReward

try:
    from tqdm import tqdm  # type: ignore
except Exception:
    tqdm = None

try:
    from scipy.signal import savgol_filter  # type: ignore
except Exception:
    savgol_filter = None


# -----------------------------
# Matplotlib preamble (ICML style)
# -----------------------------
# Global plot styling, applied as an import-time side effect of this module.
matplotlib.rcParams.update({
    # --- fonts and lines ---
    "font.family": "serif",
    "font.size": 14.0,
    "lines.linewidth": 2,
    "lines.antialiased": True,
    # --- axes frame (colours are hex strings without the leading '#') ---
    "axes.facecolor": "fdfdfd",
    "axes.edgecolor": "777777",
    "axes.linewidth": 1,
    "axes.titlesize": "medium",
    "axes.labelsize": "medium",
    "axes.axisbelow": True,
    # --- ticks: x ticks point inwards, y tick marks have zero length ---
    "xtick.color": "333333",
    "xtick.labelsize": "medium",
    "xtick.direction": "in",
    "ytick.major.size": 0,
    "ytick.minor.size": 0,
    "ytick.major.pad": 6,
    "ytick.minor.pad": 6,
    "ytick.color": "333333",
    "ytick.labelsize": "medium",
    "ytick.direction": "in",
    # --- light background grid ---
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linewidth": 1,
    # --- legend ---
    "legend.fancybox": True,
    # NOTE(review): capitalised "Small"; the other size names in this dict are
    # lowercase -- confirm matplotlib accepts this spelling on the target version.
    "legend.fontsize": "Small",
    # --- figure ---
    "figure.facecolor": "1.0",
    "figure.edgecolor": "0.5",
    "hatch.linewidth": 0.1,
    # All text is rendered through LaTeX, so every label/legend string used in
    # this file must be valid LaTeX.
    "text.usetex": True,
    # Font type 42 for PDF/PS output.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})
# LaTeX packages loaded for text rendering (Times fonts plus AMS math).
plt.rcParams["text.latex.preamble"] = r"\usepackage{times, amsmath, amssymb}"


def my_formatter(x, pos):
    """Tick formatter: compact ``%g`` text, without the first "0" for |x| in (0, 1).

    ``pos`` is the tick position supplied by matplotlib's ``FuncFormatter``
    protocol; it is not used.

    Note that for fractional values the *first* "0" character of the text is
    removed, wherever it is (e.g. ``0.5 -> ".5"``, but also ``1e-05 -> "1e-5"``).
    """
    text = format(x, "g")
    is_proper_fraction = 0 < np.abs(x) < 1
    return text.replace("0", "", 1) if is_proper_fraction else text


# Shared FuncFormatter wrapping my_formatter; installed on axes by apply_format().
major_formatter = FuncFormatter(my_formatter)


def apply_format(ax):
    """Install the module-level ``major_formatter`` on both the x and y axis of ``ax``."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(major_formatter)


def millions(x, pos):
    """Tick formatter: express ``x`` in millions with an "M" suffix (``pos`` is unused)."""
    in_millions = x / 1e6
    return "{:g}M".format(in_millions)


# -----------------------------
# Names + colors (ICML style)
# -----------------------------
# Legend label per method, keyed by the method directory name used in main().
# Labels are LaTeX source (the module enables text.usetex): "\\" is a LaTeX
# line break, so the citation is placed on a second line of the label.
field_to_name = {
    "tb": r"Trajectory Balance\\(Malkin et al., NeurIPS 2022)",
    "dtb": r"Divergent (\textbf{Ours})",
    "random": r"Random",
    "teacher_student": r"Adaptive Teacher\\(Kim et al., ICLR 2025)",
    "sa": r"Sibling Augmented\\(Madan et al., ICLR 2025)",
}

# Line/band colour per method (same keys as field_to_name). Methods missing
# from this table fall back to matplotlib's colour cycle in main().
field_to_color = {
    "tb": "#b22222",
    "dtb": "#1f77b4",
    "random": "#808080",
    "teacher_student": "#911eb4",
    "sa": "#f58231",
}


# -----------------------------
# Legend helper (legend saved as a separate figure)
# -----------------------------
def save_legend(handles, labels, out_path: str | Path, ncol=2, fontsize=10):
    """Render a legend-only figure and write it to ``out_path``.

    Args:
        handles: Artists to show in the legend.
        labels: One label per handle.
        out_path: Destination file passed to ``Figure.savefig``.
        ncol: Number of legend columns.
        fontsize: Legend font size.
    """
    legend_options = dict(
        loc="center",
        ncol=ncol,
        frameon=False,
        fontsize=fontsize,
        handlelength=2.2,
        columnspacing=1.4,
    )
    # A short, wide canvas; bbox_inches="tight" crops it to the legend on save.
    legend_fig = plt.figure(figsize=(6.2, 0.55))
    legend_fig.legend(handles, labels, **legend_options)
    legend_fig.canvas.draw()
    legend_fig.savefig(out_path, bbox_inches="tight", pad_inches=0.0, dpi=300)
    plt.close(legend_fig)


# -------------------------
# Smoothing helper (SavGol)
# -------------------------
def _make_odd_window(w: int, n: int) -> int:
    w = int(w)
    if w < 3:
        w = 3
    if w % 2 == 0:
        w += 1
    if w > n:
        w = n if (n % 2 == 1) else (n - 1)
    if w < 3:
        w = 3 if n >= 3 else n
    return w


def smooth_curve(
    y: np.ndarray,
    *,
    window: int,
    poly: int,
    monotone: bool = False,
    floor: Optional[float] = None,
) -> np.ndarray:
    """Return a Savitzky-Golay smoothed copy of a 1D curve (for plotting only).

    Args:
        y: Curve values; converted to a float array.
        window: Requested filter window; adjusted by ``_make_odd_window``.
        poly: Requested polynomial order; lowered if it does not fit the window.
        monotone: If true, make the result non-decreasing (running maximum).
        floor: If given, clip the result to be ``>= floor``.

    Returns:
        A new array. When scipy is unavailable or the curve has fewer than
        5 points, no filtering is done and only ``floor``/``monotone`` apply.
    """
    values = np.asarray(y, dtype=float)
    n_points = values.size

    if savgol_filter is not None and n_points >= 5:
        width = _make_odd_window(window, n_points)
        order = int(poly)
        if order >= width:
            # The polynomial order must stay below the window length.
            order = max(1, width - 2)
        smoothed = savgol_filter(values, window_length=width, polyorder=order, mode="interp")
    else:
        smoothed = values.copy()

    # The floor is applied before the running maximum, as a clipped value may
    # raise the level that later points must not fall below.
    if floor is not None:
        smoothed = np.maximum(smoothed, floor)
    if monotone:
        smoothed = np.maximum.accumulate(smoothed)
    return smoothed


# -------------------------
# TOML helpers (batch_size)
# -------------------------
def find_experiments_toml() -> Path:
    """Locate ``experiments.toml`` relative to the current working directory.

    The candidate locations are tried in order and the first existing one is
    returned.

    Raises:
        FileNotFoundError: If none of the candidate paths exists.
    """
    search_order = (
        Path("peptide/experiments.toml"),
        Path("experiments.toml"),
        Path("peptide/experiments/experiments.toml"),
    )
    found = next((path for path in search_order if path.exists()), None)
    if found is None:
        raise FileNotFoundError(
            "Não encontrei peptide/experiments.toml (ou fallback). "
            "Coloque em peptide/experiments.toml ou ajuste os candidates em find_experiments_toml()."
        )
    return found


def load_batch_size_from_toml(toml_path: Path, run_id: str) -> int:
    """
    Busca batch_size para um run_id.

    Regras:
      1) se [[runs]] que casa com (run_id ou id) tiver batch_size, usa
      2) senão, usa [defaults].batch_size
      3) se nenhum existir, erro

    Nota: casa com r["run_id"] OU r["id"] para compatibilidade.
    """
    with open(toml_path, "rb") as f:
        cfg = tomllib.load(f)

    defaults = cfg.get("defaults", {}) or {}
    runs = cfg.get("runs", []) or []

    for r in runs:
        rid = str(r.get("run_id", "")) if r.get("run_id", None) is not None else ""
        rid2 = str(r.get("id", "")) if r.get("id", None) is not None else ""
        if rid == run_id or rid2 == run_id:
            if "batch_size" in r:
                return int(r["batch_size"])
            if "batch_size" in defaults:
                return int(defaults["batch_size"])
            raise KeyError(
                f"run_id={run_id} encontrado, mas batch_size não existe nem no [[runs]] nem em [defaults]."
            )

    if "batch_size" in defaults:
        return int(defaults["batch_size"])

    raise KeyError(f"run_id={run_id} não encontrado em {toml_path} e defaults.batch_size ausente.")


# -------------------------
# Filesystem helpers
# -------------------------
def read_json(path: Path) -> dict:
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content."""
    raw_text = Path(path).read_text(encoding="utf-8")
    return json.loads(raw_text)


def get_config(seed_dir: Path) -> dict:
    """Load the run configuration that applies to ``seed_dir``.

    ``seed_dir/config.json`` is preferred; if it does not exist the parent
    directory's ``config.json`` (i.e. ``run_dir/config.json``) is used.

    Raises:
        FileNotFoundError: If neither file exists.
    """
    for folder in (seed_dir, seed_dir.parent):
        candidate = folder / "config.json"
        if candidate.exists():
            return read_json(candidate)
    raise FileNotFoundError(f"config.json not found under {seed_dir} or {seed_dir.parent}")


def list_seed_dirs(run_dir: Path) -> List[Path]:
    """Return the ``seed_*`` sub-directories of ``run_dir``, sorted by name.

    A non-existent ``run_dir`` yields an empty list.
    """
    seeds: List[Path] = []
    if run_dir.exists():
        for child in run_dir.iterdir():
            if child.name.startswith("seed_") and child.is_dir():
                seeds.append(child)
        seeds.sort(key=lambda child: child.name)
    return seeds


def find_checkpoints(seed_dir: Path) -> List[Path]:
    """Collect every ``*.pt`` checkpoint file belonging to a seed directory.

    Searched locations:
      - ``seed_dir/*.pt``
      - ``seed_dir/checkpoints/*.pt``
      - ``seed_dir/checkpoints/**/*.pt``

    Returns:
        The de-duplicated paths, sorted by their POSIX string form.
    """
    found = set(seed_dir.glob("*.pt"))

    ckpt_root = seed_dir / "checkpoints"
    if ckpt_root.exists():
        for pattern in ("*.pt", "**/*.pt"):
            found.update(ckpt_root.glob(pattern))

    return sorted(found, key=lambda p: p.as_posix())


def ckpt_epoch_from_payload_or_name(path: Path, payload: dict) -> Optional[int]:
    """Determine the training epoch of a checkpoint.

    The payload's ``"epoch"`` entry wins when it is present and convertible to
    ``int``; otherwise the first run of digits in the file stem is used.

    Returns:
        The epoch, or ``None`` when neither source provides one.
    """

    def _to_int(raw) -> Optional[int]:
        try:
            return int(raw)
        except Exception:
            return None

    if isinstance(payload, dict) and "epoch" in payload:
        from_payload = _to_int(payload["epoch"])
        if from_payload is not None:
            return from_payload

    digits = re.search(r"(\d+)", path.stem)
    return _to_int(digits.group(1)) if digits else None


# -------------------------
# Ragged cache helpers
# -------------------------
def load_seed_cache_if_compatible(
    cache_path: Path,
    *,
    seq_size: int,
    cutoff: float,
    eps: float,
    count_floor: float,
) -> Optional[Tuple[np.ndarray, np.ndarray, set[Tuple[int, ...]]]]:
    """Load a per-seed cache written by ``save_seed_cache`` if it is still valid.

    The cache is accepted only when all expected arrays are present and the
    stored ``seq_size``/``cutoff``/``eps``/``count_floor`` match the requested
    values (floats compared with an absolute tolerance of 1e-12).

    Args:
        cache_path: Location of the ``.npz`` cache file.
        seq_size: Sequence length the cache must have been produced with.
        cutoff: Reward cutoff the cache must have been produced with.
        eps: Environment ``eps`` the cache must have been produced with.
        count_floor: Count floor the cache must have been produced with.

    Returns:
        ``(epochs, cum, seen)`` where ``seen`` is the set of cached sequences
        as tuples of ints, or ``None`` when the file is missing, unreadable,
        incomplete, or was produced with different settings.
    """
    import zipfile  # local import: only needed to recognise a corrupt archive

    if not cache_path.exists():
        return None

    required = (
        "epochs", "cum", "meta_seq_size", "meta_cutoff",
        "meta_eps", "meta_count_floor", "tokens", "offsets",
    )
    float_meta = (
        ("meta_cutoff", cutoff),
        ("meta_eps", eps),
        ("meta_count_floor", count_floor),
    )

    try:
        # The context manager closes the underlying zip file; all arrays are
        # therefore read inside the block.
        with np.load(cache_path, allow_pickle=False) as z:
            if any(key not in z.files for key in required):
                return None
            if int(z["meta_seq_size"]) != int(seq_size):
                return None
            for key, wanted in float_meta:
                if abs(float(z[key]) - float(wanted)) > 1e-12:
                    return None

            epochs = z["epochs"].astype(int, copy=False)
            cum = z["cum"].astype(float, copy=False)
            tokens = z["tokens"].astype(np.int32, copy=False)
            offsets = z["offsets"].astype(np.int32, copy=False)
    except (OSError, ValueError, EOFError, zipfile.BadZipFile):
        # A truncated or otherwise unreadable cache (e.g. an interrupted write)
        # is treated as a cache miss so the caller simply recomputes it.
        return None

    # Ragged layout: sequence i occupies tokens[offsets[i]:offsets[i + 1]].
    bounds = offsets.tolist()
    flat = tokens.tolist()
    seen: set[Tuple[int, ...]] = {
        tuple(flat[start:stop])
        for start, stop in zip(bounds, bounds[1:])
        if stop > start
    }

    return epochs, cum, seen


def save_seed_cache(
    cache_path: Path,
    *,
    epochs: np.ndarray,
    cum: np.ndarray,
    seen: set[Tuple[int, ...]],
    seq_size: int,
    cutoff: float,
    eps: float,
    count_floor: float,
):
    """Write the per-seed curve and the set of seen sequences to a compressed ``.npz``.

    The variable-length sequences are stored in a ragged layout: all tokens of
    the sorted sequences are concatenated into ``tokens`` (int16) and
    ``offsets`` (int32, one more entry than there are sequences) marks where
    each sequence starts and ends. The ``meta_*`` scalars record the settings
    the cache was produced with. Parent directories are created as needed.
    """
    ordered = sorted(seen)

    lengths = np.array([len(seq) for seq in ordered], dtype=np.int32)
    # offsets[i] is the start of sequence i; offsets[-1] is the total token count.
    offsets = np.concatenate(([0], np.cumsum(lengths))).astype(np.int32)
    tokens = np.array([tok for seq in ordered for tok in seq], dtype=np.int16)

    payload = dict(
        epochs=epochs.astype(int),
        cum=cum.astype(float),
        tokens=tokens,
        offsets=offsets,
        meta_seq_size=np.array(int(seq_size), dtype=np.int32),
        meta_cutoff=np.array(float(cutoff), dtype=np.float64),
        meta_eps=np.array(float(eps), dtype=np.float64),
        meta_count_floor=np.array(float(count_floor), dtype=np.float64),
    )

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(cache_path, **payload)


def cache_path_for_seed(seed_dir: Path, n_samples: int, good_threshold: float) -> Path:
    """Return the cache file path for a seed, creating ``seed_dir/samples`` if needed.

    The file name encodes ``n_samples`` and ``good_threshold`` so that runs
    with different sampling settings use different cache files.
    """
    filename = f"unique_good_cache__n{n_samples}__thr{good_threshold}.npz"
    folder = seed_dir.joinpath("samples")
    folder.mkdir(parents=True, exist_ok=True)
    return folder.joinpath(filename)


# -------------------------
# Sampling helpers
# -------------------------
@torch.no_grad()
def forward_sampling(env: Sequences, forward_net: torch.nn.Module):
    """Fill ``env.state`` in place, one position at a time, using ``forward_net``.

    At each of the ``env.seq_size`` positions the rows still flagged in
    ``env.alive`` are passed through the network, actions are drawn via
    ``env.get_actions(..., training=False)`` and written into the current
    position. A row that draws action 0 is marked not alive and is skipped at
    later positions; sampling stops early once no row is alive.
    """
    for position in range(env.seq_size):
        live_rows = env.alive.nonzero(as_tuple=True)[0]
        if live_rows.numel() == 0:
            break

        partial_states = env.state.index_select(0, live_rows)
        chosen = env.get_actions(forward_net(partial_states), training=False)

        env.state[live_rows, position] = chosen
        # Action 0 ends a sequence.
        env.alive[live_rows] = chosen != 0


def state_row_to_tuple(row: torch.Tensor) -> Tuple[int, ...]:
    """Convert one state row into a tuple of its tokens up to the first 0.

    The 0 itself and everything after it are dropped.
    """
    prefix: List[int] = []
    # map() is lazy, so values after the first 0 are never converted.
    for token in map(int, row.tolist()):
        if token == 0:
            break
        prefix.append(token)
    return tuple(prefix)


def pick_forward_state_dict(payload: dict) -> Dict[str, torch.Tensor]:
    """Select the forward-policy state dict from a checkpoint payload.

    Resolution order:
      1. the first well-known key (see ``preferred``) whose value is a dict;
      2. otherwise, if a banned net (``teacher_fnet``/``div_fnet``) is present
         as a dict, give up with a ``KeyError``;
      3. otherwise the first non-banned, non-empty dict whose first key looks
         like a parameter name (contains ``"."``, ``"weight"`` or ``"bias"``).

    Raises:
        KeyError: If no suitable state dict can be identified.
    """
    banned = {"teacher_fnet", "div_fnet"}
    preferred = [
        "fnet",
        "student_fnet",
        "st_fnet",
        "forward_fnet",
        "pf",
        "pf_net",
        "policy",
        "model",
        "sa_fnet",
    ]

    # 1) Well-known names, in priority order.
    for name in (key for key in preferred if key not in banned):
        candidate = payload.get(name, None)
        if isinstance(candidate, dict):
            return candidate

    # 2) Never guess when one of the banned nets is stored in the checkpoint.
    present_dict_keys = [name for name, value in payload.items() if isinstance(value, dict)]
    has_banned = not banned.isdisjoint(present_dict_keys)
    has_preferred = not set(preferred).isdisjoint(present_dict_keys)
    if has_banned and not has_preferred:
        raise KeyError(
            "Checkpoint contains only banned sampling nets (teacher_fnet/div_fnet) "
            f"and no allowed net. dict keys present: {present_dict_keys}"
        )

    # 3) Heuristic: anything that looks like a module state dict.
    markers = (".", "weight", "bias")
    for name, value in payload.items():
        if name in banned or not isinstance(value, dict) or len(value) == 0:
            continue
        first_key = next(iter(value.keys()))
        if isinstance(first_key, str) and any(marker in first_key for marker in markers):
            return value

    raise KeyError(f"Could not find forward policy state dict in payload keys: {list(payload.keys())}")


def build_policy_from_cfg(cfg: dict, device: str) -> Policy:
    """Instantiate a ``Policy`` on ``device`` using the sizes found in ``cfg``.

    Only the keys ``emb_dim``, ``hidden``, ``pos_dim`` and ``window`` are
    forwarded (as ints), and only when present and not ``None``. If that
    raises ``TypeError``, a default-constructed ``Policy`` is returned instead.
    """
    size_keys = ("emb_dim", "hidden", "pos_dim", "window")
    overrides = {
        name: int(cfg[name])
        for name in size_keys
        if name in cfg and cfg[name] is not None
    }
    try:
        return Policy(**overrides).to(device)
    except TypeError:
        return Policy().to(device)


def make_sampling_env(cfg: dict, seq_size: int, n_samples: int, seed: int) -> Sequences:
    """Build a ``Sequences`` environment that holds ``n_samples`` rows for sampling.

    The reward cutoff is read from ``cfg["cut_off"]`` (or ``cfg["cutoff"]``,
    default 0.94) and ``eps`` from ``cfg["eps"]`` (default 0.0).
    """
    raw_cutoff = cfg["cut_off"] if "cut_off" in cfg else cfg.get("cutoff", 0.94)
    cutoff = float(raw_cutoff)
    eps = float(cfg.get("eps", 0.0))
    return Sequences(
        seq_size=seq_size,
        batch_size=n_samples,
        log_reward=LogReward(cutoff=cutoff),
        eps=eps,
        seed=seed,
    )


# -------------------------
# Per-seed curve with cache
# -------------------------
def compute_curve_for_seed_fast_cached(
    seed_dir: Path,
    *,
    n_samples: int,
    good_threshold: float,
    count_floor: float,
    ckpt_stride: int = 1,
    refresh_cache: bool = False,
    verbose: bool = False,
) -> Tuple[str, np.ndarray, np.ndarray]:
    """Compute the cumulative count of unique "good" sequences for one seed.

    For every checkpoint (ordered by epoch) ``n_samples`` sequences are
    sampled with the checkpoint's forward policy; sequences whose log-reward
    is ``>= good_threshold`` are added to a set that persists across
    checkpoints, and the size of that set is recorded per epoch. Results are
    cached on disk and resumed from the last cached epoch on later calls.

    Args:
        seed_dir: Directory holding the checkpoints (and, directly or in its
            parent, ``config.json``).
        n_samples: Number of sequences sampled per checkpoint.
        good_threshold: Minimum log-reward for a sequence to be counted.
        count_floor: Not used in the computation here; only stored in (and
            compared against) the cache metadata.
        ckpt_stride: Use every k-th checkpoint; values below 1 are treated as 1.
        refresh_cache: Ignore an existing cache and recompute everything.
        verbose: Print progress roughly every 10 checkpoints.

    Returns:
        ``(seed_dir.name, epochs, cum)``; both arrays are empty when no usable
        checkpoint exists.

    Raises:
        ValueError: If the config lacks ``seq_size``.
    """
    cfg_json = get_config(seed_dir)
    if "seq_size" not in cfg_json or cfg_json["seq_size"] is None:
        raise ValueError(f"config.json in {seed_dir} missing seq_size")
    seq_size = int(cfg_json["seq_size"])

    cutoff = float(cfg_json.get("cut_off", cfg_json.get("cutoff", 0.94)))
    eps = float(cfg_json.get("eps", 0.0))

    ckpts = find_checkpoints(seed_dir)
    if not ckpts:
        return seed_dir.name, np.array([]), np.array([])

    # First pass: load every checkpoint only to learn its epoch. Checkpoints
    # that fail to load, or for which no epoch can be determined, are skipped
    # silently. (Each remaining checkpoint is loaded a second time below.)
    items: List[Tuple[int, Path]] = []
    for p in ckpts:
        try:
            payload = torch.load(p, map_location="cpu")
        except Exception:
            continue
        ep = ckpt_epoch_from_payload_or_name(p, payload)
        if ep is None:
            continue
        items.append((ep, p))
    items.sort(key=lambda t: t[0])

    if not items:
        return seed_dir.name, np.array([]), np.array([])

    if ckpt_stride is None or ckpt_stride < 1:
        ckpt_stride = 1
    if ckpt_stride > 1:
        items = items[::ckpt_stride]

    # NOTE(review): the cache file name depends only on n_samples and
    # good_threshold, and the compatibility check below does not look at
    # ckpt_stride, so a cache written with a different stride is reused as-is.
    cache_path = cache_path_for_seed(seed_dir, n_samples=n_samples, good_threshold=good_threshold)

    cached_epochs: np.ndarray = np.array([], dtype=int)
    cached_cum: np.ndarray = np.array([], dtype=float)
    seen: set[Tuple[int, ...]] = set()
    last_cached_epoch: Optional[int] = None

    if (not refresh_cache) and cache_path.exists():
        loaded = load_seed_cache_if_compatible(
            cache_path,
            seq_size=seq_size,
            cutoff=cutoff,
            eps=eps,
            count_floor=count_floor,
        )
        if loaded is not None:
            cached_epochs, cached_cum, seen = loaded
            if cached_epochs.size > 0:
                last_cached_epoch = int(cached_epochs[-1])

    # Resume at the first checkpoint newer than the last cached epoch. The
    # for/else returns the cached curve unchanged when there is none.
    start_idx = 0
    if last_cached_epoch is not None:
        for i, (ep, _) in enumerate(items):
            if ep > last_cached_epoch:
                start_idx = i
                break
        else:
            return seed_dir.name, cached_epochs, cached_cum

    device = cfg_json.get("device", "cpu")
    seed = int(cfg_json.get("seed", 0))

    # One environment and one network are reused for all checkpoints; only the
    # weights are swapped per checkpoint.
    env = make_sampling_env(cfg_json, seq_size=seq_size, n_samples=n_samples, seed=seed)
    net = build_policy_from_cfg(cfg_json, device=device)
    net.eval()

    # Start from the cached curve (if any) and append the new points to it.
    epochs: List[int] = list(cached_epochs.tolist()) if cached_epochs.size > 0 else []
    cum: List[float] = list(cached_cum.tolist()) if cached_cum.size > 0 else []

    t0 = time.time()

    for k, (ep, ckpt_path) in enumerate(items[start_idx:], start=start_idx):
        payload = torch.load(ckpt_path, map_location="cpu")
        net.load_state_dict(pick_forward_state_dict(payload))
        net.eval()

        env.reset()
        forward_sampling(env, net)

        logR = env.log_reward()
        good_idx = (logR >= good_threshold).nonzero(as_tuple=True)[0]

        if good_idx.numel() > 0:
            states_good = env.state.index_select(0, good_idx)
            for j in range(states_good.shape[0]):
                tup = state_row_to_tuple(states_good[j])
                # Empty sequences are never counted.
                if tup:
                    seen.add(tup)

        # `seen` is never reset, so this is the cumulative number of distinct
        # good sequences found up to and including this checkpoint.
        epochs.append(ep)
        cum.append(float(len(seen)))

        if verbose and ((k - start_idx) % 10 == 0 or k == len(items) - 1):
            dt = time.time() - t0
            print(f"[{seed_dir.name}] {k+1}/{len(items)} ckpts | last_epoch={ep} | cum={len(seen)} | {dt:.1f}s")

    epochs_arr = np.array(epochs, dtype=int)
    cum_arr = np.array(cum, dtype=float)

    # Persist the extended curve together with the full set of seen sequences
    # so a later call can resume from here.
    save_seed_cache(
        cache_path,
        epochs=epochs_arr,
        cum=cum_arr,
        seen=seen,
        seq_size=seq_size,
        cutoff=cutoff,
        eps=eps,
        count_floor=count_floor,
    )

    return seed_dir.name, epochs_arr, cum_arr


def _worker_compute_seed(
    seed_dir_str: str,
    n_samples: int,
    good_threshold: float,
    count_floor: float,
    ckpt_stride: int,
    refresh_cache: bool,
    verbose: bool,
):
    """Process-pool entry point: positional-argument adapter around
    ``compute_curve_for_seed_fast_cached`` that takes the seed directory as a
    string and returns that function's result unchanged."""
    options = dict(
        n_samples=n_samples,
        good_threshold=good_threshold,
        count_floor=count_floor,
        ckpt_stride=ckpt_stride,
        refresh_cache=refresh_cache,
        verbose=verbose,
    )
    return compute_curve_for_seed_fast_cached(Path(seed_dir_str), **options)


# -------------------------
# Aggregation (log-symmetric bands)
# -------------------------
def align_and_aggregate_curves(
    curves: List[Tuple[np.ndarray, np.ndarray]],
    *,
    log_space: bool = False,
    count_floor: float = 1.0,
):
    """Aggregate per-seed ``(epochs, values)`` curves on the union of their epochs.

    For every epoch that occurs in at least one curve, statistics are taken
    over the curves that contain that epoch. With ``log_space=True`` values
    are first clipped to ``count_floor`` and converted to ``log10``.

    Returns:
        Six arrays ``(epochs, mean, std, min, max, count)``; ``std`` is the
        population standard deviation and ``count`` the number of curves
        contributing to each epoch. All six are empty if there are no epochs.

    Raises:
        ValueError: If ``log_space`` is set and ``count_floor <= 0``.
    """
    epoch_sets = [set(e.tolist()) for e, _ in curves if e.size > 0]
    grid = sorted(set().union(*epoch_sets))
    if not grid:
        empty = np.array([])
        return (empty,) * 6

    # One epoch -> value lookup per curve.
    lookups = [{int(ep): float(val) for ep, val in zip(e, v)} for e, v in curves]

    epochs, mean, std, vmin, vmax, count = ([] for _ in range(6))
    columns = (epochs, mean, std, vmin, vmax, count)

    for ep in grid:
        vals = np.array([lookup[ep] for lookup in lookups if ep in lookup], dtype=float)
        if vals.size == 0:
            continue

        if log_space:
            if count_floor <= 0:
                raise ValueError(f"count_floor must be > 0 when log_space=True, got {count_floor}")
            vals = np.log10(np.maximum(vals, count_floor))

        stats = (ep, vals.mean(), vals.std(ddof=0), vals.min(), vals.max(), vals.size)
        for column, value in zip(columns, stats):
            column.append(value)

    return tuple(np.array(column) for column in columns)


# -------------------------
# Args
# -------------------------
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with run selection, sampling, band,
        y-scale, epoch-range, parallelism, caching, smoothing and export
        options.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # which runs to read
    add("--run_id", required=True, type=str)
    add("--root", default="runs", type=str)
    add("--exp", default="peptide", type=str)

    # sampling
    add("--n_samples", default=1000, type=int)
    add("--good_threshold", default=0.0, type=float, help="threshold on log_reward")

    # uncertainty band
    add("--band", default="std", choices=["std", "minmax"])
    add("--sigma", default=1.0, type=float)

    # y axis
    add("--y_scale", default="log10", choices=["log10", "linear"])
    add("--count_floor", default=1.0, type=float,
        help="pseudocount for log-space stats when y_scale=log10")

    # epoch range
    add("--min_epoch", default=None, type=int)
    add("--max_epoch", default=None, type=int)

    # execution
    add("--workers", default=6, type=int,
        help="0 = no multiprocessing; >0 parallelize across seeds")
    add("--seed_verbose", action="store_true",
        help="print per-seed progress (useful if no tqdm)")

    # checkpoints and cache
    add("--ckpt_stride", default=1, type=int, help="sample every k-th checkpoint (>=1)")
    add("--refresh_cache", action="store_true",
        help="ignore saved samples and re-sample everything")

    # smoothing (plot only)
    add("--smooth", action="store_true",
        help="apply Savitzky-Golay smoothing to plotted curves")
    add("--smooth_window", default=21, type=int, help="odd window length for SavGol")
    add("--smooth_poly", default=3, type=int, help="polyorder for SavGol")
    add("--smooth_monotone", action="store_true",
        help="enforce non-decreasing after smoothing (useful for cumulative curves)")

    # exporting
    add("--out", default=None, type=str,
        help="path for main figure (png/pdf inferred by extension)")
    add("--legend_out", default=None, type=str,
        help="path for legend-only figure (png/pdf)")
    add("--legend_ncol", default=2, type=int)
    add("--legend_fontsize", default=10, type=int)

    return parser.parse_args()


# -------------------------
# Main
# -------------------------
def _collect_seed_curves(seed_dirs, args, field):
    """Compute the (epochs, cum) curve of every seed dir, in parallel when enabled."""
    curves: List[Tuple[np.ndarray, np.ndarray]] = []

    if args.workers and args.workers > 0 and len(seed_dirs) > 1:
        n_workers = min(args.workers, len(seed_dirs), os.cpu_count() or args.workers)
        with ProcessPoolExecutor(max_workers=n_workers) as ex:
            futures = [
                ex.submit(
                    _worker_compute_seed,
                    sd.as_posix(),
                    args.n_samples,
                    args.good_threshold,
                    args.count_floor,
                    args.ckpt_stride,
                    args.refresh_cache,
                    args.seed_verbose,
                )
                for sd in seed_dirs
            ]

            fut_iter = as_completed(futures)
            if tqdm is not None:
                fut_iter = tqdm(fut_iter, total=len(futures), desc=f"{field}: seeds", leave=False)
            else:
                print(f"\n== {field}: {len(seed_dirs)} seeds (workers={n_workers}) ==")

            for fut in fut_iter:
                _, e, c = fut.result()
                if e.size > 0:
                    curves.append((e, c))
        return curves

    seed_iter = seed_dirs if tqdm is None else tqdm(seed_dirs, desc=f"{field}: seeds", leave=False)
    if tqdm is None:
        print(f"\n== {field}: {len(seed_dirs)} seeds ==")

    for sd in seed_iter:
        _, e, c = compute_curve_for_seed_fast_cached(
            sd,
            n_samples=args.n_samples,
            good_threshold=args.good_threshold,
            count_floor=args.count_floor,
            ckpt_stride=args.ckpt_stride,
            refresh_cache=args.refresh_cache,
            verbose=args.seed_verbose,
        )
        if e.size > 0:
            curves.append((e, c))
    return curves


def _center_and_band(args, mean, std, vmin, vmax):
    """Turn aggregated statistics into the plotted centre line and band edges.

    For ``y_scale == "log10"`` the inputs are log10 statistics: smoothing is
    done in log space and the result is converted back to counts and clipped
    to ``count_floor``. For the linear scale the inputs are plain counts.
    """
    log_scale = args.y_scale == "log10"
    use_std = args.band == "std"

    if args.smooth:
        smooth_kw = dict(window=args.smooth_window, poly=args.smooth_poly)
        # In log space neither a floor nor monotonicity is applied while
        # smoothing; both are enforced after converting back to counts.
        floor = None if log_scale else 0.0
        monotone = args.smooth_monotone and not log_scale

        mean = smooth_curve(mean, monotone=monotone, floor=floor, **smooth_kw)
        if use_std:
            std = smooth_curve(std, monotone=False, floor=0.0, **smooth_kw)
        else:
            vmin = smooth_curve(vmin, monotone=monotone, floor=floor, **smooth_kw)
            vmax = smooth_curve(vmax, monotone=monotone, floor=floor, **smooth_kw)

    if use_std:
        low = mean - args.sigma * std
        high = mean + args.sigma * std
    else:
        low, high = vmin, vmax
    center = mean

    if log_scale:
        center, low, high = (
            np.maximum(10 ** arr, args.count_floor) for arr in (center, low, high)
        )

    if args.smooth and args.smooth_monotone:
        center = np.maximum.accumulate(center)
        low = np.maximum.accumulate(low)
        high = np.maximum.accumulate(high)

    return center, low, high


def main():
    """CLI entry point: compute per-method curves, plot them, save figure and legend.

    For every method directory under ``<root>/<exp>`` that contains the
    requested ``run_id``, the per-seed cumulative curves are computed (or read
    from cache), aggregated across seeds and drawn with an uncertainty band.
    The main figure and a separate legend-only figure are written to disk.

    Raises:
        SystemExit: If the experiment directory does not exist or nothing
            could be plotted.
    """
    args = parse_args()

    if args.workers and args.workers > 0:
        # Avoid oversubscription: each worker process uses a single thread.
        torch.set_num_threads(1)

    root = Path(args.root)
    exp_dir = root / args.exp
    if not exp_dir.exists():
        raise SystemExit(f"exp dir not found: {exp_dir}")

    # batch_size from peptide/experiments.toml for THIS run_id
    # (x-axis: trajectories = epochs * batch_size)
    toml_path = None
    try:
        toml_path = find_experiments_toml()
        batch_size = load_batch_size_from_toml(toml_path, args.run_id)
        batch_size_from_toml = True
    except Exception as e:
        batch_size = 1
        batch_size_from_toml = False
        print(f"[WARN] Could not load batch_size from peptide/experiments.toml for run_id={args.run_id}. "
              f"Falling back to batch_size=1. Error: {e}")

    method_dirs = sorted([p for p in exp_dir.iterdir() if p.is_dir()])
    method_iter = method_dirs if tqdm is None else tqdm(method_dirs, desc="methods", leave=True)

    fig, ax = plt.subplots()

    legend_handles = []
    legend_labels = []

    for method_dir in method_iter:
        print(method_dir)
        field = method_dir.name  # dtb, tb, sa, teacher_student, ...
        run_dir = method_dir / args.run_id
        if not run_dir.exists():
            continue

        # Runs without seed_* sub-directories are treated as a single seed.
        seed_dirs = list_seed_dirs(run_dir) or [run_dir]

        curves = _collect_seed_curves(seed_dirs, args, field)
        if not curves:
            continue

        epochs, mean, std, vmin, vmax, count = align_and_aggregate_curves(
            curves,
            log_space=(args.y_scale == "log10"),
            count_floor=args.count_floor,
        )
        if epochs.size == 0:
            continue

        # Restrict to the requested epoch range.
        keep = np.ones(epochs.shape, dtype=bool)
        if args.min_epoch is not None:
            keep &= epochs >= args.min_epoch
        if args.max_epoch is not None:
            keep &= epochs <= args.max_epoch
        epochs, mean, std, vmin, vmax = epochs[keep], mean[keep], std[keep], vmin[keep], vmax[keep]
        if epochs.size == 0:
            continue

        # x-axis: trajectories
        x = epochs.astype(np.int64) * int(batch_size)

        center, low, high = _center_and_band(args, mean, std, vmin, vmax)

        label_name = field_to_name.get(field, field)
        line_kwargs = {"label": label_name}
        color = field_to_color.get(field, None)
        if color is not None:
            line_kwargs["color"] = color

        (line,) = ax.plot(x, center, **line_kwargs)
        # Reuse the line colour so the band matches even for unlisted methods.
        ax.fill_between(x, low, high, alpha=0.2, color=line.get_color())

        legend_handles.append(line)
        legend_labels.append(label_name)

    if not legend_handles:
        raise SystemExit(
            f"No curves plotted. Ensure checkpoints exist under "
            f"{exp_dir.as_posix()}/<method>/{args.run_id}/seed_*/checkpoints/"
        )

    ax.set_xlabel(r"Sampled trajectories")
    # The label goes through LaTeX (text.usetex), so it must not contain an
    # unbalanced "$".
    ax.set_ylabel("cumulative unique effective peptides")

    if args.y_scale == "log10":
        try:
            ax.set_yscale("log", base=10)
        except TypeError:
            # Fallback for matplotlib versions that reject the `base` keyword.
            ax.set_yscale("log")

    # apply_format() sets both axes; the x axis is then switched to the "M"
    # (millions) formatter. This must happen after set_yscale().
    apply_format(ax)
    ax.xaxis.set_major_formatter(FuncFormatter(millions))
    ax.xaxis.get_offset_text().set_visible(False)

    fig.tight_layout()

    # -------------------------
    # Save main fig + legend fig
    # -------------------------
    if args.out is None:
        out_dir = root / args.exp / "plots"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"unique_good_cum_{args.band}__{args.run_id}__yscale{args.y_scale}__traj.png"
    else:
        out_path = Path(args.out)
        out_path.parent.mkdir(parents=True, exist_ok=True)

    fig.savefig(out_path, dpi=300, bbox_inches="tight", pad_inches=0.0)
    plt.close(fig)
    print(f"Saved main fig: {out_path}")

    if args.legend_out is None:
        leg_path = out_path.with_name(f"legend__{out_path.stem}{out_path.suffix}")
    else:
        leg_path = Path(args.legend_out)
        leg_path.parent.mkdir(parents=True, exist_ok=True)

    save_legend(
        legend_handles,
        legend_labels,
        leg_path,
        ncol=args.legend_ncol,  # honour --legend_ncol (was hard-coded before)
        fontsize=args.legend_fontsize,
    )
    print(f"Saved legend: {leg_path}")

    # Report where the batch size came from, independent of its value.
    if batch_size_from_toml:
        print(f"Using batch_size={batch_size} (from {toml_path})")
    else:
        print("Using batch_size=1 (fallback)")

    if args.smooth:
        if savgol_filter is None:
            print("[WARN] --smooth requested but scipy is not available. Curves were not smoothed.")
        else:
            print(f"Applied SavGol smoothing: window={args.smooth_window}, poly={args.smooth_poly}, "
                  f"monotone={args.smooth_monotone}")


if __name__ == "__main__":
    main()