"""Produce a combined Gemini 3.1 Pro + Gemini 3 Flash trajectory figure.

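Shares the speedup-trajectory computation with `trajectory.py` (same per-benchmark
speedup floor at 1x, same geomean aggregation). The figure has two panels:
geomean speedup over XLA versus cumulative samples (left) and fast@N, the
percentage of benchmarks beating XLA (right). Within each model, only benchmarks
covered by every active method are plotted. Pro lines are drawn solid; Flash
lines are drawn dashed in the same method color. The figure is written to
``<out>/trajectory_pro.pdf``.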

Usage:
    python -m autocomp.baselines.trajectory_overlay \
        --pro-iterative         output/baselines/iterative \
        --pro-iterative-context output/baselines/iterative_context \
        --pro-best-of-n         output/baselines/best_of_n \
        --pro-autocomp          output/jaxbench-sweep-pro \
        --flash-iterative       output/baselines-flash/iterative \
        --flash-iterative-context output/baselines-flash/iterative_context \
        --flash-best-of-n       output/baselines-flash/best_of_n \
        --flash-autocomp        output/jaxbench-sweep-flash \
        --out                   /path/to/paper-figures/Figures

``--*-iterative-context`` is optional; omit to drop the curve from the overlay.
"""
from __future__ import annotations

import argparse
import pathlib
from collections import defaultdict

from autocomp.baselines.trajectory import (
    collect_iterative, collect_best_of_n, collect_autocomp,
    speedup_trajectory, fast_trajectory, _strip_autocomp_suffix,
)


def _build_trajectories(iterative_dir, iterative_context_dir, best_of_n_dir, autocomp_dir, max_samples=None):
    """Collect one model's per-method runs and turn them into trajectories.

    Each ``*_dir`` is a ``pathlib.Path`` (or ``None``) whose immediate
    subdirectories are individual benchmark runs. A ``None`` or non-existent
    directory simply drops that method. A run directory is skipped when its
    (stripped) name is empty or contains "smoke", or when its collector
    returns no samples.

    Args:
        iterative_dir: Runs for the "Iterative" method.
        iterative_context_dir: Runs for "Iterative+ctx" (optional).
        best_of_n_dir: Runs for "Best-of-N".
        autocomp_dir: Runs for "Autocomp". Its ``*_translate`` subdirectories
            are additionally merged in front of the matching benchmark's
            samples.
        max_samples: Lower bound on the ``max_samples`` passed to
            ``speedup_trajectory`` / ``fast_trajectory``. The longest shared
            run is used instead when it is larger, or when this is falsy.

    Returns:
        ``(trajectories, fast_trajectories, method_x_ends)``. The first two map
        method name -> the value returned by ``speedup_trajectory`` /
        ``fast_trajectory`` (callers unpack it as ``(xs, ys)``).
        ``method_x_ends`` maps method name -> the largest real sample count
        among its shared benchmarks. All three are empty when no method has
        any runs.
    """
    method_runs: dict[str, dict[str, list]] = defaultdict(dict)
    method_baselines: dict[str, dict[str, float | None]] = defaultdict(dict)

    def _ingest(method, root, collector, strip=lambda s: s):
        """Add every usable run directory under *root* to *method*'s runs."""
        if not root or not root.is_dir():
            return
        for run_dir in sorted(p for p in root.iterdir() if p.is_dir()):
            bench = strip(run_dir.name)
            if not bench or "smoke" in bench.lower():
                continue
            samples, baseline = collector(run_dir)
            if not samples:
                continue
            # If several run dirs map to the same benchmark (e.g. via
            # *strip*), their samples are concatenated in sorted dir order.
            method_runs[method].setdefault(bench, []).extend(samples)
            # Keep the first non-empty baseline seen for this benchmark.
            if baseline and not method_baselines[method].get(bench):
                method_baselines[method][bench] = baseline

    _ingest("Iterative",      iterative_dir,         collect_iterative)
    _ingest("Iterative+ctx",  iterative_context_dir, collect_iterative)
    _ingest("Best-of-N",      best_of_n_dir,         collect_best_of_n)
    _ingest("Autocomp",       autocomp_dir,          collect_autocomp, strip=_strip_autocomp_suffix)

    # Autocomp "<bench>[_baseline]_translate" runs are placed in front of the
    # benchmark's other Autocomp samples.
    # NOTE(review): assumes `_strip_autocomp_suffix` yields a falsy name for
    # these dirs so `_ingest` above skipped them; otherwise their samples are
    # counted twice — confirm.
    if autocomp_dir and autocomp_dir.is_dir():
        for run_dir in sorted(p for p in autocomp_dir.iterdir() if p.is_dir()):
            if not run_dir.name.endswith("_translate"):
                continue
            bench = run_dir.name[: -len("_translate")].removesuffix("_baseline")
            samples, baseline = collect_autocomp(run_dir)
            if not samples:
                continue
            existing = method_runs["Autocomp"].get(bench, [])
            method_runs["Autocomp"][bench] = samples + existing
            if baseline and not method_baselines["Autocomp"].get(bench):
                method_baselines["Autocomp"][bench] = baseline

    # Apples-to-apples baseline pool: every method is scored against the same
    # per-benchmark baseline, taken from the first method (in this order)
    # that reported a non-empty one.
    all_baselines: dict[str, float] = {}
    for m in ["Iterative", "Iterative+ctx", "Best-of-N", "Autocomp"]:
        for b, v in method_baselines.get(m, {}).items():
            if v and b not in all_baselines:
                all_baselines[b] = v
    # Iterate method_runs, not method_baselines. A method whose collector
    # never reported a baseline has no entry in method_baselines, and would
    # otherwise be scored against None instead of the pooled baseline.
    for m, runs in method_runs.items():
        method_baselines[m] = {b: all_baselines.get(b) for b in runs}

    active = [m for m in ["Best-of-N", "Iterative", "Iterative+ctx", "Autocomp"] if method_runs.get(m)]
    if not active:
        return {}, {}, {}
    # Restrict to the set of benchmarks every active method covers.
    shared = set.intersection(*(set(method_runs[m]) for m in active))
    shared_runs = {
        m: {b: v for b, v in method_runs[m].items() if b in shared}
        for m in active
    }

    local_max = max(
        (len(v) for runs in shared_runs.values() for v in runs.values()),
        default=0,
    )
    target = max(local_max, max_samples) if max_samples else local_max

    trajectories = {}
    fast_trajectories = {}
    method_x_ends: dict[str, int] = {}
    for m, runs in shared_runs.items():
        baselines = {b: method_baselines[m].get(b) for b in runs}
        trajectories[m] = speedup_trajectory(runs, baselines, max_samples=target)
        fast_trajectories[m] = fast_trajectory(runs, baselines, max_samples=target)
        # Real (un-extended) run length, used by the plot to clip each curve.
        method_x_ends[m] = max((len(v) for v in runs.values()), default=0)
    return trajectories, fast_trajectories, method_x_ends


def main():
    """CLI entry point: build Pro and Flash trajectories and plot the overlay.

    Parses the per-model run directories and builds trajectories for both
    models using a common ``max_samples`` so their x-ranges line up. It prints
    each method's final speedup and writes a two-panel figure (geomean
    speedup, fast@N) to ``<out>/trajectory_pro.pdf``. The output directory is
    created if it does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--pro-iterative", type=pathlib.Path, required=True)
    ap.add_argument("--pro-iterative-context", type=pathlib.Path, default=None)
    ap.add_argument("--pro-best-of-n", type=pathlib.Path, required=True)
    ap.add_argument("--pro-autocomp",  type=pathlib.Path, required=True)
    ap.add_argument("--flash-iterative", type=pathlib.Path, required=True)
    ap.add_argument("--flash-iterative-context", type=pathlib.Path, default=None)
    ap.add_argument("--flash-best-of-n", type=pathlib.Path, required=True)
    ap.add_argument("--flash-autocomp",  type=pathlib.Path, required=True)
    ap.add_argument("--out", type=pathlib.Path, required=True)
    args = ap.parse_args()

    # Positional order matches _build_trajectories' directory parameters.
    pro_dirs = (args.pro_iterative, args.pro_iterative_context,
                args.pro_best_of_n, args.pro_autocomp)
    flash_dirs = (args.flash_iterative, args.flash_iterative_context,
                  args.flash_best_of_n, args.flash_autocomp)

    # First pass only finds the longest x-axis across BOTH models. The second
    # pass rebuilds each model with that shared max_samples.
    pro, _, _ = _build_trajectories(*pro_dirs)
    flash, _, _ = _build_trajectories(*flash_dirs)
    common_max = max(
        (xs[-1] for traj in (pro, flash) for xs, _ in traj.values() if len(xs)),
        default=0,
    )
    pro, pro_fast, pro_ends = _build_trajectories(*pro_dirs, max_samples=common_max)
    flash, flash_fast, flash_ends = _build_trajectories(*flash_dirs, max_samples=common_max)

    def _print_finals(label, trajs, ends):
        """Print one line per method: final speedup, x-extent and real end."""
        print(f"{label} final speedups:")
        # Widen the name column to fit the longest label (e.g. "Iterative+ctx").
        width = max(10, max((len(m) for m in trajs), default=0))
        for m, (xs, ys) in trajs.items():
            print(f"  {m:{width}s} {ys[-1]:.3f}x  (n={xs[-1]}, real_end={ends[m]})")

    _print_finals("Pro", pro, pro_ends)
    _print_finals("Flash", flash, flash_ends)

    # Imported lazily so the non-interactive Agg backend is selected before
    # pyplot is loaded.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    colors = {
        "Best-of-N": "#EA4335",
        "Iterative": "#4285F4",
        "Iterative+ctx": "#FBBC04",
        "Autocomp": "#34A853",
    }
    method_order = ["Best-of-N", "Iterative", "Iterative+ctx", "Autocomp"]

    def _draw(ax, pro_data, flash_data, ylabel, ylim=None, hline=None):
        """Plot Pro (solid) and Flash (dashed) curves per method on *ax*.

        Each curve is clipped at the method's real sample count
        (``pro_ends`` / ``flash_ends``), so the extended tail is not drawn.
        """
        for name in method_order:
            if name in pro_data:
                xs, ys = pro_data[name]
                end = pro_ends.get(name, len(xs) - 1)
                xs, ys = xs[: end + 1], ys[: end + 1]
                ax.plot(xs, ys, color=colors[name], linewidth=1.8, linestyle="-")
            if name in flash_data:
                xs, ys = flash_data[name]
                end = flash_ends.get(name, len(xs) - 1)
                xs, ys = xs[: end + 1], ys[: end + 1]
                ax.plot(xs, ys, color=colors[name], linewidth=1.4,
                        linestyle="--", alpha=0.85)
        if hline is not None:
            ax.axhline(hline, color="gray", linestyle=":", linewidth=0.8, alpha=0.6)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.set_xlabel("Cumulative samples")
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3, linewidth=0.5)

    fig, (axL, axR) = plt.subplots(1, 2, figsize=(9.5, 3.2))
    _draw(axL, pro, flash, "Geomean speedup over XLA", hline=1.0)
    pro_fast_pct = {m: (xs, [y * 100 for y in ys]) for m, (xs, ys) in pro_fast.items()}
    flash_fast_pct = {m: (xs, [y * 100 for y in ys]) for m, (xs, ys) in flash_fast.items()}
    _draw(axR, pro_fast_pct, flash_fast_pct, "fast@N (% beating XLA)", ylim=(-5, 105))

    # Two separate legends so "model = linestyle" reads as an encoding, not
    # another baseline. Methods (color) on the left; model (linestyle) on
    # the right, both above the axes with small titles.
    method_handles = [Line2D([0], [0], color=colors[m], lw=1.8, label=m)
                      for m in method_order]
    model_handles = [
        Line2D([0], [0], color="black", lw=1.8, linestyle="-",  label="Gemini 3.1 Pro"),
        Line2D([0], [0], color="black", lw=1.4, linestyle="--", label="Gemini 3 Flash"),
    ]
    fig.tight_layout(rect=[0, 0, 1, 0.9])
    # Figure.legend appends to fig.legends, so both legends coexist without
    # fig.add_artist (that trick is only needed for Axes.legend, which
    # replaces the previous legend). Adding it again would draw it twice.
    fig.legend(
        handles=method_handles, loc="upper center", ncol=len(method_handles),
        frameon=False, fontsize=9, bbox_to_anchor=(0.32, 0.99),
        title="Method", title_fontsize=9,
    )
    fig.legend(
        handles=model_handles, loc="upper center", ncol=len(model_handles),
        frameon=False, fontsize=9, bbox_to_anchor=(0.78, 0.99),
        title="Model", title_fontsize=9,
    )

    # savefig does not create missing parent directories.
    args.out.mkdir(parents=True, exist_ok=True)
    out_path = args.out / "trajectory_pro.pdf"
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_path}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits
    # with status 0, just like falling off the end of the module.
    raise SystemExit(main())
