"""Build sample-by-sample trajectories for each method and benchmark.

For every sample (eval) in a run we record:
    (method, benchmark, sample_idx, latency_ms_or_None, failure_type)

From these we can compute best-so-far latency per benchmark, then normalize
by the XLA baseline to get speedup, then aggregate (geomean) across
benchmarks to produce a single trajectory per method.

Also emits an aggregate failure-type breakdown per method, pooled over
every sample of the benchmarks shared by all methods.

Usage:
    python -m autocomp.baselines.trajectory \
        --iterative output/baselines/iterative \
        --best-of-n output/baselines/best_of_n \
        --autocomp  output/jaxbench-sweep \
        --out       /path/to/paper-figures/figures

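Outputs in `--out`:
    trajectory.pdf          geomean speedup vs cumulative samples
    trajectory_fast.pdf     speedup and fast@N trajectories side by side
    failures.pdf            stacked failure-type bar per method
    per_benchmark.pdf       best speedup per benchmark and method
    trajectory_data.json    trajectories, failure fractions, baselines and
                            per-benchmark speedups, for reproducibility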
"""
from __future__ import annotations

import argparse
import json
import math
import pathlib
import re
from collections import Counter, defaultdict

from autocomp.baselines.classify_failures import _classify_result, _load_result_json


# ---------------------------------------------------------------------------
# Per-method, per-benchmark collectors
# ---------------------------------------------------------------------------

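# Each returns (samples, baseline): samples is a
# list[(latency_or_none, failure_type)] ordered by sample index, and
# baseline is the run's baseline latency, or None when it is unavailable.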

def collect_iterative(run_dir: pathlib.Path) -> tuple[list[tuple[float | None, str]], float | None]:
    """Collect the samples of one iterative-refinement run in turn-major order.

    Scans ``run_dir/chain_<i>/turn_<j>/result.json``. Turns without a
    ``result.json`` are skipped, as are directories that match the glob but
    carry no numeric index (e.g. ``chain_notes``).

    Args:
        run_dir: Directory holding one benchmark's ``chain_*`` sub-directories
            and its ``summary.json``.

    Returns:
        ``(samples, baseline)``. ``samples`` is a list of
        ``(latency_or_None, failure_type)``; the latency is only kept when
        the result is classified as ``"success"``. ``baseline`` is whatever
        ``_read_baseline`` finds in ``run_dir/summary.json``.
    """
    # Order samples turn-major, then chain, so the trajectory reflects how
    # chains actually advance in parallel: (turn 0, chains 0..17),
    # (turn 1, chains 0..17), ... This gives granularity 1 instead of a
    # chain-by-chain stairstep.
    per_turn: dict[int, list[tuple[int, tuple[float | None, str]]]] = {}
    for chain_dir in sorted(run_dir.glob("chain_*")):
        chain_match = re.search(r"chain_(\d+)", chain_dir.name)
        if chain_match is None:
            # Matched the glob but has no numeric index; it cannot be
            # placed in the ordering, so ignore it instead of crashing.
            continue
        chain_idx = int(chain_match.group(1))
        for turn_dir in sorted(chain_dir.glob("turn_*")):
            turn_match = re.search(r"turn_(\d+)", turn_dir.name)
            if turn_match is None:
                continue
            turn_idx = int(turn_match.group(1))
            rp = turn_dir / "result.json"
            if not rp.exists():
                continue
            res = _load_result_json(rp) or {}
            lbl = _classify_result(res)
            lat = res.get("latency") if lbl == "success" else None
            per_turn.setdefault(turn_idx, []).append((chain_idx, (lat, lbl)))

    samples: list[tuple[float | None, str]] = []
    for turn_idx in sorted(per_turn):
        # Sort on the chain index only. Comparing whole entries would fall
        # through to the (latency, label) payload on duplicate indices
        # (e.g. "chain_1" and "chain_01") and raise on None-vs-float.
        for _chain_idx, sample in sorted(per_turn[turn_idx], key=lambda entry: entry[0]):
            samples.append(sample)
    baseline = _read_baseline(run_dir / "summary.json")
    return samples, baseline


def collect_best_of_n(run_dir: pathlib.Path) -> tuple[list[tuple[float | None, str]], float | None]:
    """Collect the samples of one best-of-N run in generation order.

    Reads ``run_dir/eval/code_<k>_result.txt``, skipping any file whose name
    contains ``"_full"`` or carries no numeric index.

    Args:
        run_dir: Directory holding one benchmark's ``eval/`` folder and its
            ``summary.json``.

    Returns:
        ``(samples, baseline)``. ``samples`` is a list of
        ``(latency_or_None, failure_type)`` ordered by ``<k>``; the latency
        is only kept when the result is classified as ``"success"``.
        ``baseline`` is whatever ``_read_baseline`` finds in
        ``run_dir/summary.json``. ``samples`` is empty when ``eval/`` does
        not exist.
    """
    samples: list[tuple[float | None, str]] = []
    eval_dir = run_dir / "eval"
    if eval_dir.is_dir():
        indexed: list[tuple[int, pathlib.Path]] = []
        for p in eval_dir.glob("code_*_result.txt"):
            if "_full" in p.name:
                continue
            m = re.search(r"code_(\d+)_", p.name)
            if m is None:
                # e.g. "code_final_result.txt": matches the glob but has no
                # numeric index, so it has no place in the generation order.
                continue
            indexed.append((int(m.group(1)), p))
        # Sort by numeric index to preserve generation order; the file name
        # breaks ties so the result does not depend on glob order.
        indexed.sort(key=lambda entry: (entry[0], entry[1].name))
        for _idx, p in indexed:
            res = _load_result_json(p) or {}
            lbl = _classify_result(res)
            lat = res.get("latency") if lbl == "success" else None
            samples.append((lat, lbl))
    baseline = _read_baseline(run_dir / "summary.json")
    return samples, baseline


def collect_autocomp(run_dir: pathlib.Path) -> tuple[list[tuple[float | None, str]], float | None]:
    """Autocomp stores results under eval-results-iter-*/code_*_result.txt.

    We treat all evaluations (across the translate and optimize phases and
    across iterations) as a single ordered stream in generation order:
    iteration directories by their numeric index, then result files by
    theirs. Entries without a numeric index, and result files whose name
    contains ``"_full"``, are skipped.

    Args:
        run_dir: Directory of a single Autocomp run.

    Returns:
        ``(samples, baseline)``. ``samples`` is a list of
        ``(latency_or_None, failure_type)``; the latency is only kept when
        the result is classified as ``"success"``. ``baseline`` is whatever
        ``_read_autocomp_baseline`` extracts from ``run_dir``.
    """

    def _numbered(paths, pattern: str) -> list[pathlib.Path]:
        # Keep only the paths whose name carries a numeric index matching
        # `pattern`, ordered by that index (file name breaks ties).
        indexed = []
        for p in paths:
            m = re.search(pattern, p.name)
            if m is not None:
                indexed.append((int(m.group(1)), p.name, p))
        indexed.sort(key=lambda entry: entry[:2])
        return [p for _idx, _name, p in indexed]

    samples: list[tuple[float | None, str]] = []
    for it_dir in _numbered(run_dir.glob("eval-results-iter-*"), r"iter-(\d+)"):
        result_files = (p for p in it_dir.glob("code_*_result.txt") if "_full" not in p.name)
        for p in _numbered(result_files, r"code_(\d+)_"):
            res = _load_result_json(p) or {}
            lbl = _classify_result(res)
            lat = res.get("latency") if lbl == "success" else None
            samples.append((lat, lbl))
    # Autocomp doesn't write baseline_latency_ms into its own metadata, so
    # the only baseline available here is the one _read_autocomp_baseline
    # derives from the run directory; it may be None.
    baseline = _read_autocomp_baseline(run_dir)
    return samples, baseline


def _read_baseline(summary_path: pathlib.Path) -> float | None:
    """Return ``baseline_latency_ms`` from a run's ``summary.json``.

    Best-effort: any problem with the file yields ``None`` rather than an
    exception, so one broken run cannot abort the whole aggregation.

    Args:
        summary_path: Path to the JSON summary file.

    Returns:
        The numeric value stored under ``"baseline_latency_ms"``, or ``None``
        when the file is missing or unreadable, is not valid JSON, is not a
        JSON object, or holds no numeric value under that key.
    """
    try:
        summary = json.loads(summary_path.read_text())
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON or
        # undecodable bytes (JSONDecodeError and UnicodeDecodeError are both
        # ValueError subclasses).
        return None
    if not isinstance(summary, dict):
        return None
    value = summary.get("baseline_latency_ms")
    # bool is an int subclass; a JSON true/false is not a latency.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return value


def _read_autocomp_baseline(run_dir: pathlib.Path) -> float | None:
    """Return the score recorded for Autocomp's initial candidate, if any.

    Reads ``run_dir/candidates-iter-0/candidate_0.txt`` and parses the first
    line that starts with ``score=<number>``. There is no further fallback.

    Args:
        run_dir: Directory of a single Autocomp run.

    Returns:
        The parsed score as a float, or ``None`` when the file cannot be
        read, has no ``score=`` line, or the value is not a valid number.
    """
    cand0 = run_dir / "candidates-iter-0" / "candidate_0.txt"
    try:
        text = cand0.read_text()
    except OSError:
        return None
    m = re.search(r"^score=([\d.]+)", text, flags=re.MULTILINE)
    if m is None:
        return None
    try:
        return float(m.group(1))
    except ValueError:
        # "[\d.]+" also accepts malformed values such as "1.2.3" or ".".
        return None


# ---------------------------------------------------------------------------
# Aggregation
# ---------------------------------------------------------------------------

def _best_so_far(latencies: list[float | None]) -> list[float | None]:
    """Return the running minimum of ``latencies``, ignoring ``None`` entries.

    Element ``i`` of the result is the smallest non-``None`` latency among
    ``latencies[: i + 1]``, or ``None`` if none has been seen yet.
    """
    running: list[float | None] = []
    current: float | None = None
    for value in latencies:
        if value is None:
            pass  # a failed sample never changes the best
        elif current is None or value < current:
            current = value
        running.append(current)
    return running


def speedup_trajectory(
    method_runs: dict[str, list[tuple[float | None, str]]],
    method_baselines: dict[str, float | None],
    max_samples: int | None = None,
) -> tuple[list[int], list[float]]:
    """Compute geomean-speedup trajectory across benchmarks.

    Extends all benchmarks to `max_samples` (or, if None, the max length of
    any benchmark under this method) by carrying forward each benchmark's
    best-so-far latency once it runs out of samples. Speedup at step t =
    baseline / best_latency_so_far[t]; if no success yet for a benchmark at
    step t, that benchmark contributes 1.0 to the geomean. This lets us
    show Autocomp's line continuing past its early-stopped sample count on
    the same x-axis as the baselines' full 144-sample budget.

    Args:
        method_runs: ``{benchmark: [(latency_or_None, failure_type), ...]}``
            for a single method, samples in generation order.
        method_baselines: ``{benchmark: baseline latency}``. Benchmarks with
            a missing or falsy baseline are left out of the geomean.
        max_samples: Minimum trajectory length; the longest benchmark still
            wins if it is longer.

    Returns:
        ``(xs, ys)`` with ``xs = [0, 1, ..., n]`` and ``ys[0] == 1.0``;
        ``ys[t]`` is the geomean speedup after ``t`` samples. Both lists are
        empty when ``method_runs`` is empty.
    """
    benches = sorted(method_runs)
    if not benches:
        return [], []

    n_method = max(len(method_runs[b]) for b in benches)
    n = max(n_method, max_samples) if max_samples is not None else n_method

    # One pass per benchmark: log of the capped speedup after each sample.
    # Computing the running best once avoids rescanning the sample prefix at
    # every step.
    log_speedups: list[list[float]] = []
    for b in benches:
        baseline = method_baselines.get(b)
        if not baseline:
            continue
        best: float | None = None
        logs: list[float] = []
        for sample in method_runs[b]:
            lat = sample[0]
            # Only a strictly positive latency can be a real measurement.
            # Letting a 0/negative/NaN value become "best" would discard
            # earlier valid results and make the trajectory drop.
            if lat is not None and lat > 0 and (best is None or lat < best):
                best = lat
            speedup = baseline / best if best is not None else 1.0
            # Cap "slower than XLA" at 1x: the trajectory tracks cumulative
            # improvement, and a benchmark whose best correct sample is still
            # slower than XLA doesn't count as progress (we'd take XLA over it).
            if speedup < 1.0:
                speedup = 1.0
            logs.append(math.log(speedup))
        log_speedups.append(logs)

    xs = list(range(n + 1))
    ys: list[float] = [1.0]
    count = len(log_speedups)
    for t in range(n):
        log_sum = 0.0
        for logs in log_speedups:
            if logs:
                # Once a benchmark is exhausted its last value carries forward.
                log_sum += logs[min(t, len(logs) - 1)]
            # A benchmark without samples contributes log(1.0) == 0.
        ys.append(math.exp(log_sum / count) if count else 1.0)

    return xs, ys


def fast_trajectory(
    method_runs: dict[str, list[tuple[float | None, str]]],
    method_baselines: dict[str, float | None],
    max_samples: int | None = None,
) -> tuple[list[int], list[float]]:
    """Compute fast@N: fraction of benchmarks where best-so-far beats XLA.

    At each cumulative sample step t, a benchmark counts as "fast" if its
    best correct sample so far is strictly faster than the XLA baseline.
    Returns (xs, ys) where ys[t] in [0, 1]. Monotonically non-decreasing.

    Args:
        method_runs: ``{benchmark: [(latency_or_None, failure_type), ...]}``
            for a single method, samples in generation order.
        method_baselines: ``{benchmark: baseline latency}``. Benchmarks with
            a missing or falsy baseline are excluded from the denominator.
        max_samples: Minimum trajectory length; the longest benchmark still
            wins if it is longer.

    Returns:
        ``(xs, ys)`` with ``xs = [0, 1, ..., n]`` and ``ys[0] == 0.0``. Both
        lists are empty when ``method_runs`` is empty.
    """
    benches = sorted(method_runs)
    if not benches:
        return [], []

    n_method = max(len(method_runs[b]) for b in benches)
    n = max(n_method, max_samples) if max_samples is not None else n_method

    # One pass per benchmark: is it "fast" after each sample? Tracking the
    # running best once avoids rescanning the sample prefix at every step.
    fast_flags: list[list[bool]] = []
    for b in benches:
        baseline = method_baselines.get(b)
        if not baseline:
            continue
        best: float | None = None
        flags: list[bool] = []
        for sample in method_runs[b]:
            lat = sample[0]
            # Only a strictly positive latency can be a real measurement.
            # Letting a 0/negative/NaN value become "best" would flip a fast
            # benchmark back to slow and break monotonicity.
            if lat is not None and lat > 0 and (best is None or lat < best):
                best = lat
            flags.append(best is not None and baseline / best > 1.0)
        fast_flags.append(flags)

    xs = list(range(n + 1))
    ys: list[float] = [0.0]
    count = len(fast_flags)
    for t in range(n):
        # Once a benchmark is exhausted its last state carries forward; a
        # benchmark without samples is never fast.
        fast = sum(1 for flags in fast_flags if flags and flags[min(t, len(flags) - 1)])
        ys.append(fast / count if count else 0.0)

    return xs, ys


def failure_breakdown(
    method_runs: dict[str, list[tuple[float | None, str]]],
) -> dict[str, float]:
    """Return {failure_type: fraction} pooled across all samples.

    Every sample of every benchmark goes into one shared tally, so
    benchmarks with more samples weigh more. Per-benchmark rates are
    deliberately not averaged: sample counts can differ between benchmarks
    (e.g. Autocomp's early-stopping), and pooling matches the sample-count
    fractions reported in the results table (e.g. 265/720).

    Returns an empty dict when there are no samples at all.
    """
    tally: Counter = Counter(
        label
        for bench_samples in method_runs.values()
        for _, label in bench_samples
    )
    n_samples = sum(tally.values())
    if not n_samples:
        return {}
    return {label: hits / n_samples for label, hits in tally.items()}


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

# Group small failure categories under a small fixed palette.
# Ordered (display name, member failure types) pairs. The order is the
# stacking order of the failure bar chart.
FAILURE_BUCKETS = [
    ("success", ["success"]),
    ("wrong output", ["correctness_error"]),
    ("OOM", ["oom"]),
    (
        "API / runtime",
        [
            "pallas_api_error",
            "pallas_unsupported",
            "compile_error",
            "runtime_error",
            "import_error",
            "missing_workload",
            "no_code_extracted",
            "eval_exception",
            "no_output",
            "timeout",
        ],
    ),
]


def _bucketize(breakdown: dict[str, float]) -> dict[str, float]:
    """Collapse a per-failure-type breakdown into the FAILURE_BUCKETS groups.

    Every bucket is present in the result (0.0 when empty), in
    FAILURE_BUCKETS order. Failure types not listed in any bucket are
    counted under "API / runtime".
    """
    # Failure type -> bucket name; the first bucket listing a type wins.
    owner: dict[str, str] = {}
    for bucket, members in FAILURE_BUCKETS:
        for member in members:
            owner.setdefault(member, bucket)

    totals = dict.fromkeys((bucket for bucket, _ in FAILURE_BUCKETS), 0.0)
    for label, fraction in breakdown.items():
        totals[owner.get(label, "API / runtime")] += fraction
    return totals


def plot(trajectories, failures, out_dir: pathlib.Path,
         per_bench_speedups: dict[str, dict[str, float]] | None = None,
         method_x_ends: dict[str, int] | None = None,
         fast_trajectories: dict[str, tuple[list[int], list[float]]] | None = None):
    """Render the figures into ``out_dir`` (created if missing).

    Always writes ``trajectory.pdf`` and ``failures.pdf``. Writes
    ``trajectory_fast.pdf`` when ``fast_trajectories`` is non-empty and
    ``per_benchmark.pdf`` when ``per_bench_speedups`` is non-empty.

    Args:
        trajectories: ``{method: (xs, ys)}`` geomean-speedup trajectories.
        failures: ``{method: {bucket name: fraction}}`` keyed by the
            FAILURE_BUCKETS display names.
        out_dir: Output directory.
        per_bench_speedups: Optional ``{method: {benchmark: speedup}}`` for
            the per-benchmark dot plot.
        method_x_ends: Optional ``{method: last x index to draw}``; each
            method's line is truncated to ``xs[: end + 1]``.
        fast_trajectories: Optional ``{method: (xs, ys)}`` with ``ys`` as
            fractions; they are drawn as percentages.
    """
    # Imported lazily, and with the non-interactive Agg backend selected
    # before pyplot is loaded, so figures are only ever written to files.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    out_dir.mkdir(parents=True, exist_ok=True)

    # Okabe-Ito inspired, softened to Google-style tones.
    colors = {"Best-of-N":     "#EA4335",   # Google red
              "Iterative":     "#4285F4",   # Google blue
              "Iterative+ctx": "#FBBC04",   # Google yellow
              "Autocomp":      "#34A853"}   # Google green
    method_order = ["Best-of-N", "Iterative", "Iterative+ctx", "Autocomp"]

    def _plot_traj(ax, data, ylabel, ylim=None, hline=None, legend=True):
        # Draw one line per method present in `data`, in the fixed order.
        for name in method_order:
            if name not in data:
                continue
            xs, ys = data[name]
            end = method_x_ends.get(name) if method_x_ends else None
            if end is not None:
                xs = xs[: end + 1]
                ys = ys[: end + 1]
            ax.plot(xs, ys, label=name, color=colors.get(name), linewidth=1.8)
        if hline is not None:
            ax.axhline(hline, color="gray", linestyle="--", linewidth=0.8, alpha=0.6)
        ax.set_xlabel("Cumulative samples")
        ax.set_ylabel(ylabel)
        if ylim is not None:
            ax.set_ylim(*ylim)
        if legend:
            ax.legend(loc="lower right", frameon=False, fontsize=9)
        ax.grid(True, alpha=0.3, linewidth=0.5)

    # --- standalone trajectory (speedup only, kept for backwards compat) ---
    fig, ax = plt.subplots(figsize=(5.5, 3.2))
    _plot_traj(ax, trajectories, "Geomean speedup over XLA", hline=1.0)
    fig.tight_layout()
    fig.savefig(out_dir / "trajectory.pdf")
    plt.close(fig)

    # --- side-by-side speedup + fast@N ---
    if fast_trajectories:
        fig, (axL, axR) = plt.subplots(1, 2, figsize=(9.5, 3.2))
        _plot_traj(axL, trajectories, "Geomean speedup over XLA", hline=1.0, legend=False)
        # Render fast@N as percent. Shortened ylabel so it does not collide
        # with the shared legend strip above the axes.
        fast_pct = {m: (xs, [y * 100 for y in ys]) for m, (xs, ys) in fast_trajectories.items()}
        _plot_traj(axR, fast_pct, "fast@N (% beating XLA)",
                   ylim=(-5, 105), legend=False)
        # Shared legend above the axes, with enough headroom that neither
        # ylabel runs into it.
        handles = []
        labels = []
        for name in method_order:
            if name in trajectories:
                handles.append(plt.Line2D([0], [0], color=colors[name], linewidth=1.8))
                labels.append(name)
        fig.tight_layout(rect=[0, 0, 1, 0.93])
        fig.legend(handles, labels, loc="upper center", ncol=len(labels),
                   frameon=False, fontsize=9, bbox_to_anchor=(0.5, 0.99))
        fig.savefig(out_dir / "trajectory_fast.pdf", bbox_inches="tight")
        plt.close(fig)

    # --- stacked failure bar ---
    methods = [m for m in method_order if m in failures]
    bucket_names = [name for name, _ in FAILURE_BUCKETS]
    # Colors trace a most->least-correct spectrum: success (green) ->
    # wrong output (yellow, code ran but wrong) -> OOM (amber, compiled
    # but couldn't fit) -> API/runtime (red, never ran).
    palette = {
        "success":       "#34A853",   # Google green
        "wrong output":  "#FBBC04",   # Google yellow
        "OOM":           "#FA7B17",   # Google orange
        "API / runtime": "#EA4335",   # Google red
    }

    # Plot methods bottom-to-top in weakest->strongest order so the reader
    # sees Best-of-N at the top and Autocomp at the bottom (matplotlib barh
    # stacks from the bottom up).
    methods_plot = list(reversed(methods))
    fig, ax = plt.subplots(figsize=(5.5, 2.2))
    bottoms = [0.0] * len(methods_plot)
    for bname in bucket_names:
        vals = [failures[m].get(bname, 0.0) for m in methods_plot]
        ax.barh(methods_plot, vals, left=bottoms, color=palette[bname], label=bname,
                edgecolor="white", linewidth=0.6, height=0.5)
        bottoms = [b + v for b, v in zip(bottoms, vals)]
    ax.set_xlim(0, 1)
    # The fractions are pooled over every sample of every benchmark, not a
    # mean of per-benchmark rates, so the label must say so.
    ax.set_xlabel("Fraction of samples (pooled across benchmarks)")
    ax.tick_params(axis='y', labelsize=9)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.3),
              ncol=4, frameon=False, fontsize=9)
    fig.subplots_adjust(left=0.18, bottom=0.35, top=0.95)
    fig.savefig(out_dir / "failures.pdf", bbox_inches="tight")
    plt.close(fig)

    # --- per-benchmark dot plot ---
    if per_bench_speedups:
        def _best_across_methods(bench: str) -> float:
            # Sort key: the best speedup any method reached on `bench`.
            vals = [m.get(bench) for m in per_bench_speedups.values()]
            vals = [v for v in vals if v is not None]
            return max(vals) if vals else 1.0

        all_benches = sorted(
            {b for m in per_bench_speedups.values() for b in m},
            key=_best_across_methods,
        )
        dot_methods = [m for m in method_order if m in per_bench_speedups]
        markers = {"Best-of-N": "o", "Iterative": "s", "Iterative+ctx": "^", "Autocomp": "D"}
        # Small vertical offsets keep the methods' markers from overlapping
        # on a benchmark's row.
        offsets = {"Best-of-N": 0.12, "Iterative": 0.04, "Iterative+ctx": -0.04, "Autocomp": -0.12}

        n_bench = len(all_benches)
        fig_h = max(2.0, 0.15 * n_bench + 0.8)
        fig, ax = plt.subplots(figsize=(5.5, fig_h))

        for method in dot_methods:
            xs, ys = [], []
            for i, bench in enumerate(all_benches):
                sp = per_bench_speedups[method].get(bench)
                if sp is not None:
                    xs.append(sp)
                    ys.append(i + offsets[method])
            ax.scatter(xs, ys, color=colors[method], marker=markers[method],
                       s=36, label=method, zorder=3, edgecolors="white",
                       linewidths=0.5, alpha=0.9)

        ax.axvline(1.0, color="gray", linestyle="--", linewidth=0.8, alpha=0.6)
        ax.set_yticks(range(n_bench))
        ax.set_ylim(-0.5, n_bench - 0.5)
        # Drop everything up to the first underscore of the benchmark name
        # and show the remaining underscores as spaces.
        bench_labels = [b.split("_", 1)[1].replace("_", " ") if "_" in b else b
                        for b in all_benches]
        ax.set_yticklabels(bench_labels, fontsize=7)
        ax.set_xlabel("Best speedup over XLA")
        ax.legend(loc="lower right", frameon=False, fontsize=8)
        ax.grid(True, axis="x", alpha=0.3, linewidth=0.5)
        for i in range(n_bench):
            ax.axhline(i, color="gray", linewidth=0.3, alpha=0.4, zorder=0)
        fig.tight_layout()
        fig.savefig(out_dir / "per_benchmark.pdf", bbox_inches="tight")
        plt.close(fig)


# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------

def _strip_autocomp_suffix(name: str) -> str:
    """Map an Autocomp run-directory name to its benchmark name.

    5p_Flex_Attention_baseline            -> 5p_Flex_Attention
    5p_Flex_Attention_baseline_translate  -> "" (caller skips it)

    At most one trailing "_baseline" or "_plans" is removed; any other name
    is returned unchanged.
    """
    if name.endswith("_translate"):
        return ""
    for suffix in ("_baseline", "_plans"):
        trimmed = name.removesuffix(suffix)
        if trimmed != name:
            return trimmed
    return name


def main():
    """CLI entry point: collect every method's runs, aggregate, and plot.

    Reads the run directories given on the command line, builds
    per-benchmark sample lists for each method, computes the trajectories
    and failure breakdown over the benchmarks shared by all methods, then
    writes ``trajectory_data.json`` and the figures into ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--iterative", type=pathlib.Path)
    ap.add_argument("--iterative-context", type=pathlib.Path,
                    help="Iterative refinement with Autocomp context (iter+ctx ablation).")
    ap.add_argument("--best-of-n", type=pathlib.Path)
    ap.add_argument("--autocomp",  type=pathlib.Path)
    ap.add_argument("--out",       type=pathlib.Path, required=True)
    args = ap.parse_args()

    # method -> benchmark -> samples / baseline
    method_runs: dict[str, dict[str, list[tuple[float | None, str]]]] = defaultdict(dict)
    method_baselines: dict[str, dict[str, float | None]] = defaultdict(dict)

    def _ingest(method: str, root: pathlib.Path, collector, strip=lambda s: s):
        # Collect every run directory under `root` for `method`. `strip`
        # maps a directory name to a benchmark name ("" means skip).
        if not root or not root.is_dir():
            return
        for run_dir in sorted(p for p in root.iterdir() if p.is_dir()):
            bench = strip(run_dir.name)
            if not bench:
                continue
            if "smoke" in bench.lower():
                continue
            samples, baseline = collector(run_dir)
            if not samples:
                continue
            # merge duplicates (several run dirs mapping to one benchmark)
            method_runs[method].setdefault(bench, []).extend(samples)
            if baseline and not method_baselines[method].get(bench):
                method_baselines[method][bench] = baseline

    _ingest("Iterative", args.iterative, collect_iterative)
    _ingest("Iterative+ctx", args.iterative_context, collect_iterative)
    _ingest("Best-of-N", args.best_of_n, collect_best_of_n)
    _ingest("Autocomp",  args.autocomp,  collect_autocomp, strip=_strip_autocomp_suffix)

    # Also allow autocomp _translate_ dirs to contribute samples; they share the
    # benchmark name after stripping.
    if args.autocomp and args.autocomp.is_dir():
        for run_dir in sorted(p for p in args.autocomp.iterdir() if p.is_dir()):
            if not run_dir.name.endswith("_translate"):
                continue
            bench = run_dir.name[: -len("_translate")].removesuffix("_baseline")
            # Apply the same filters as _ingest so smoke-test runs and
            # empty names never become benchmarks.
            if not bench or "smoke" in bench.lower():
                continue
            samples, baseline = collect_autocomp(run_dir)
            if not samples:
                continue
            # Prepend translate-phase samples so optimize comes after.
            existing = method_runs["Autocomp"].get(bench, [])
            method_runs["Autocomp"][bench] = samples + existing
            if baseline and not method_baselines["Autocomp"].get(bench):
                method_baselines["Autocomp"][bench] = baseline

    method_names = ["Best-of-N", "Iterative", "Iterative+ctx", "Autocomp"]

    # Report per-method coverage.
    print("=" * 60)
    print("Coverage:")
    for m in method_names:
        runs = method_runs.get(m, {})
        print(f"  {m:14s}: {len(runs)} benchmarks",
              ", ".join(sorted(runs.keys())))
    print("=" * 60)

    # Fill in baselines from the shared pool. We use the same XLA baseline
    # per benchmark for every method so speedups are apples-to-apples, even
    # if different methods measured XLA on slightly different days and got
    # slightly different numbers. Iterative's measurement is preferred.
    all_baselines: dict[str, float] = {}
    for m in ["Iterative", "Iterative+ctx", "Best-of-N", "Autocomp"]:
        for b, v in method_baselines.get(m, {}).items():
            if v and b not in all_baselines:
                all_baselines[b] = v
    # Iterate over the methods that have *runs*: a method whose own runs
    # recorded no baseline at all has no entry in method_baselines yet, and
    # it is exactly the one that needs the shared pool.
    for m in list(method_runs):
        for b in method_runs[m]:
            method_baselines[m][b] = all_baselines.get(b)

    # Use only benchmarks covered by every method that has runs, so the
    # geomean is apples-to-apples.
    active_methods = [m for m in method_names if method_runs.get(m)]
    if active_methods:
        shared = set.intersection(*(set(method_runs[m].keys()) for m in active_methods))
    else:
        shared = set()
    print(f"Shared benchmarks (used for geomean): {sorted(shared)}")

    trajectories = {}
    fast_trajectories = {}
    failures = {}
    # Pad each method's trajectory to the global max across active methods so
    # early-stopped methods (e.g. Autocomp) are drawn on the same x-axis as
    # the full-budget baselines, but truncate each method's line at its own
    # real sample count (stored in x_end). Plotting code uses xs[:x_end+1].
    global_max = 0
    for m in active_methods:
        for b in method_runs[m]:
            if b in shared:
                global_max = max(global_max, len(method_runs[m][b]))
    method_x_ends: dict[str, int] = {}
    for m in active_methods:
        runs = {b: v for b, v in method_runs[m].items() if b in shared}
        baselines = {b: method_baselines[m].get(b) for b in runs}
        trajectories[m] = speedup_trajectory(runs, baselines, max_samples=global_max)
        fast_trajectories[m] = fast_trajectory(runs, baselines, max_samples=global_max)
        method_x_ends[m] = max((len(runs[b]) for b in runs), default=0)
        failures[m] = _bucketize(failure_breakdown(runs))

    # Compute per-benchmark final speedup for the dot plot.
    # This uses ALL benchmarks per method (not just the shared set).
    per_bench_speedups: dict[str, dict[str, float]] = {}
    for m in active_methods:
        per_bench_speedups[m] = {}
        for b, samples in method_runs[m].items():
            baseline = method_baselines[m].get(b)
            if not baseline:
                continue
            best_lat = None
            for lat, _lbl in samples:
                # Only strictly positive latencies are real measurements; a
                # bogus 0 must not hide an earlier valid result.
                if lat is not None and lat > 0 and (best_lat is None or lat < best_lat):
                    best_lat = lat
            if best_lat is not None:
                per_bench_speedups[m][b] = baseline / best_lat

    # Dump raw data.
    args.out.mkdir(parents=True, exist_ok=True)
    (args.out / "trajectory_data.json").write_text(json.dumps({
        "trajectories": trajectories,
        "fast_trajectories": fast_trajectories,
        "failures": failures,
        "baselines": all_baselines,
        "n_benchmarks": {m: len(r) for m, r in method_runs.items()},
        "per_bench_speedups": per_bench_speedups,
    }, indent=2))

    plot(trajectories, failures, args.out, per_bench_speedups=per_bench_speedups,
         method_x_ends=method_x_ends, fast_trajectories=fast_trajectories)
    print(f"Wrote figures to {args.out}")


if __name__ == "__main__":
    main()
