"""Generate MFU (MXU utilization) figures from trajectory_data.json.

Reads `trajectory_data.json` produced by `trajectory.py` and computes MFU for
each (method, benchmark) pair as:

    MFU = FLOPs / (best_latency_s * PEAK_BF16_FLOPS)

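where FLOPs per benchmark are read from the analytical `flops = ...` formula in
the corresponding JAXBench `baseline.py` (so the numbers stay in sync if those
ever change), falling back to the XLA cost-analysis count stored in
`benchmark/flops_cache.json` when no usable formula exists. The XLA baseline
MFU is computed the same way from its measured latency.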

For benchmarks a method failed to solve, we fall back to the XLA baseline
MFU as a floor (rather than 0%), to avoid penalising methods for missing
benchmarks and over-penalising memory-bound kernels that bottom out near 0.

Produces:
    mfu.pdf        per-benchmark horizontal dot plot (XLA + one dot per method)
    mfu_dist.pdf   compact summary: MFU distribution strip plot per method

Usage (matching trajectory.py):
    python -m autocomp.baselines.mfu \
        --out /path/to/paper-figures/Figures

The `--out` directory must already contain `trajectory_data.json`. Optional:
    --jaxbench-root  path to JAXBench repo (for FLOPs lookup)
    --peak-tflops    peak bf16 TFLOPS for the target TPU (default 918, v6e)
"""
from __future__ import annotations

import argparse
import ast
import json
import math
import pathlib
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np


# TPU v6e bf16 peak, same number used in eval harness.
DEFAULT_PEAK_TFLOPS = 918.0

# Placeholder checkout location; normally overridden via --jaxbench-root.
DEFAULT_JAXBENCH_ROOT = pathlib.Path("/path/to/JAXBench")

# Per-method (color, marker) pairs, kept in one table so the two lookups
# below can never drift apart. Keep in sync with autocomp.baselines.trajectory
# — same palette across all baseline-comparison figures in the paper.
_METHOD_STYLES = {
    "XLA":           ("#5F6368", "x"),  # Google gray
    "Best-of-N":     ("#EA4335", "o"),  # Google red
    "Iterative":     ("#4285F4", "s"),  # Google blue
    "Iterative+ctx": ("#FBBC04", "^"),  # Google yellow
    "Autocomp":      ("#34A853", "D"),  # Google green
}

METHOD_COLORS = {name: color for name, (color, _marker) in _METHOD_STYLES.items()}
METHOD_MARKERS = {name: marker for name, (_color, marker) in _METHOD_STYLES.items()}


# ---------------------------------------------------------------------------
# FLOPs lookup
# ---------------------------------------------------------------------------

def _parse_config_dict(src: str) -> dict:
    """Extract the module-level ``CONFIG`` dict literal from a JAXBench baseline.py.

    Accepts ``CONFIG = {...}`` and ``CONFIG: <annotation> = {...}`` at module
    top level; the first such assignment wins. Returns ``{}`` when the source
    does not parse, no such assignment exists, its value is not a pure
    literal, or the literal is not a dict (callers treat ``{}`` as "no CONFIG").
    """
    try:
        tree = ast.parse(src)
    except SyntaxError:
        return {}
    for node in tree.body:
        if isinstance(node, ast.Assign) and len(node.targets) == 1:
            target, value = node.targets[0], node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            target, value = node.target, node.value
        else:
            continue
        if not (isinstance(target, ast.Name) and target.id == "CONFIG"):
            continue
        try:
            cfg = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return {}
        # A non-dict CONFIG (e.g. a list) would later break `dict(cfg)`.
        return cfg if isinstance(cfg, dict) else {}
    return {}


def _iter_assigns(tree: ast.AST) -> Iterable[ast.stmt]:
    """Yield every assignment statement in *tree*, in source order.

    Covers plain (``x = ...``), annotated (``x: T = ...``) and augmented
    (``x += ...``) assignments at any nesting depth, including function and
    ``if``/``for`` bodies. ``ast.walk`` is breadth-first, so its output is
    re-sorted by (line, column); otherwise an assignment nested inside an
    ``if`` block would be visited after later, shallower statements that
    depend on it.
    """
    found = [node for node in ast.walk(tree)
             if isinstance(node, (ast.Assign, ast.AnnAssign, ast.AugAssign))]
    found.sort(key=lambda node: (node.lineno, node.col_offset))
    yield from found


def load_flops_for_bench(bench_dir: pathlib.Path) -> float | None:
    """Parse bench_dir/baseline.py and return its analytical FLOP count.

    Every assignment in the file (at any nesting depth) is replayed in source
    order inside a small namespace seeded with the CONFIG dict -- both as
    ``CONFIG`` and with its keys as bare names -- plus a few numeric builtins
    and ``math``. Statements whose right-hand side cannot be evaluated there
    (jax calls, unknown helpers, missing keys, ...) are skipped. The value
    bound to ``flops`` by the last successfully evaluated assignment is
    returned. This handles e.g.:
      - `B, S, D = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim']`
      - `C = CONFIG` followed by `B = C['batch']` or `C['shape'][0]`
      - multi-line helper assignments like `proj_flops = (B * S * ql * 2 + ...)`
      - accumulation: `flops = attn_flops` followed by `flops += mlp_flops`
      - simple helpers: `max(...)`, `sum(... for ...)`, `math.prod(...)`

    Returns None if baseline.py is missing or unparsable, has no literal
    CONFIG dict, or never binds a numeric ``flops``.
    """
    baseline_py = bench_dir / "baseline.py"
    if not baseline_py.is_file():
        return None
    src = baseline_py.read_text()
    cfg = _parse_config_dict(src)
    if not cfg:
        return None
    try:
        tree = ast.parse(src)
    except SyntaxError:
        return None

    # Evaluation namespace, passed to eval() as *globals* so names remain
    # visible inside comprehension / generator-expression scopes (a separate
    # locals mapping would hide them there).
    # NOTE(security): a restricted __builtins__ is NOT a sandbox. This is only
    # acceptable because baseline.py comes from a trusted JAXBench checkout;
    # do not point this at untrusted files.
    env: dict[str, object] = dict(cfg)
    env["__builtins__"] = {
        "abs": abs, "float": float, "int": int, "len": len, "max": max,
        "min": min, "pow": pow, "round": round, "sum": sum,
    }
    env["math"] = math
    env["CONFIG"] = cfg  # `C = CONFIG` then naturally aliases the same dict.
    result: float | None = None

    def evaluate(expr: ast.expr) -> object:
        """Compile and evaluate a single expression node against env."""
        code = compile(ast.fix_missing_locations(ast.Expression(body=expr)),
                       "<flops>", "eval")
        return eval(code, env)  # see NOTE(security) above

    def bind(target: ast.expr, value: object) -> None:
        """Bind *value* to a Name or (nested) Tuple/List target in env."""
        nonlocal result
        if isinstance(target, ast.Name):
            env[target.id] = value
            if target.id == "flops":
                try:
                    result = float(value)  # type: ignore[arg-type]
                except (TypeError, ValueError):
                    pass
        elif isinstance(target, (ast.Tuple, ast.List)):
            if any(isinstance(elt, ast.Starred) for elt in target.elts):
                return  # star-unpacking is not modelled
            try:
                seq = list(value)  # type: ignore[call-overload]
            except TypeError:
                return
            if len(seq) == len(target.elts):
                for elt, item in zip(target.elts, seq):
                    bind(elt, item)
        # Attribute / subscript targets (obj.x = ..., d[k] = ...) are ignored.

    for node in _iter_assigns(tree):
        if isinstance(node, ast.AugAssign):
            # Replay `x op= rhs` as `x = x op rhs`.
            if not isinstance(node.target, ast.Name):
                continue
            current = ast.copy_location(
                ast.Name(id=node.target.id, ctx=ast.Load()), node.target)
            expr = ast.copy_location(
                ast.BinOp(left=current, op=node.op, right=node.value), node)
            targets = [node.target]
        elif isinstance(node, ast.AnnAssign):
            if node.value is None:  # bare annotation binds nothing
                continue
            expr, targets = node.value, [node.target]
        else:
            expr, targets = node.value, node.targets
        try:
            value = evaluate(expr)
        except Exception:
            # Deliberately best-effort: anything not expressible with CONFIG
            # and plain arithmetic (jax calls, unknown names, ...) is skipped.
            continue
        for tgt in targets:
            bind(tgt, value)

    return result


def _load_xla_flops_cache(jaxbench_root: pathlib.Path) -> dict[str, int]:
    """Load precomputed XLA cost_analysis() FLOPs from
    JAXBench/benchmark/flops_cache.json.

    That file is generated by JAXBench/benchmark/cache_flops.py on a JAX-capable
    machine. It provides a FLOPs number for every JAXBench workload (derived
    from the compiled XLA HLO graph) and is used as a fallback when the
    benchmark does not expose an analytical `flops = ...` formula we can
    extract from its source (the 33 KernelBench L2 workloads).

    Expected layout: ``{bench_name: {"xla_flops": <number>, ...}, ...}``.
    A missing or null ``xla_flops`` maps to 0. Entries that are not objects,
    or whose value cannot be converted to ``int``, are skipped rather than
    aborting the whole load. Returns ``{}`` if the file is absent, unreadable,
    not valid JSON, or its top level is not an object.
    """
    cache_path = jaxbench_root / "benchmark" / "flops_cache.json"
    if not cache_path.is_file():
        return {}
    try:
        raw = json.loads(cache_path.read_text())
    except (OSError, ValueError):  # unreadable, bad encoding, or invalid JSON
        return {}
    if not isinstance(raw, dict):
        return {}
    table: dict[str, int] = {}
    for name, entry in raw.items():
        if not isinstance(entry, dict):
            continue
        try:
            table[name] = int(entry.get("xla_flops", 0) or 0)
        except (TypeError, ValueError, OverflowError):
            continue
    return table


def build_flops_table(jaxbench_root: pathlib.Path, benches: Iterable[str]) -> dict[str, float]:
    """Return {benchmark_name: flops} for every benchmark whose FLOPs resolve.

    Precedence per benchmark:
      1. Analytical `flops = ...` expression embedded in baseline.py (matches
         the operator's textbook FLOP count; hand-authored for the 17 priority
         kernels).
      2. XLA cost_analysis() FLOPs from flops_cache.json (covers the 33
         KernelBench L2 fused-op workloads, which do not publish an analytical
         formula).

    A non-positive (or NaN) analytical value is treated as missing and falls
    through to the cache. Otherwise it would pin the benchmark's MFU at 0
    regardless of latency. Benchmarks resolved by neither source are omitted,
    and a warning is printed.
    """
    table: dict[str, float] = {}
    bench_root = jaxbench_root / "benchmark"
    cache_path = bench_root / "flops_cache.json"
    xla_cache = _load_xla_flops_cache(jaxbench_root)
    for bench in benches:
        bench_dir = bench_root / bench
        flops = load_flops_for_bench(bench_dir)
        # `not flops > 0` also rejects NaN.
        if flops is None or not flops > 0:
            xla = xla_cache.get(bench, 0)
            flops = float(xla) if xla > 0 else None
        if flops is None:
            print(f"warning: no FLOPs for {bench} "
                  f"(looked in {bench_dir} and {cache_path})")
            continue
        table[bench] = flops
    return table


# ---------------------------------------------------------------------------
# MFU computation
# ---------------------------------------------------------------------------

def compute_mfu(latency_ms: float | None, flops: float, peak_tflops: float) -> float:
    """Return MFU as a fraction of peak (1.0 == 100%).

    MFU = flops / (latency_s * peak_flops_per_s).

    Args:
        latency_ms: measured latency in milliseconds. None, non-positive, or
            NaN values are not usable measurements and yield 0.0.
        flops: floating-point operations performed per call.
        peak_tflops: hardware peak throughput in TFLOP/s.
    """
    # `not latency_ms > 0` also rejects NaN, which `latency_ms <= 0` lets through
    # (NaN would otherwise poison every downstream mean and plot).
    if latency_ms is None or not latency_ms > 0:
        return 0.0
    return flops / (latency_ms * 1e-3 * peak_tflops * 1e12)


def compute_mfu_table(
    data: dict,
    flops_by_bench: dict[str, float],
    peak_tflops: float,
    methods: list[str],
) -> tuple[dict[str, float], dict[str, dict[str, float]], dict[str, dict[str, float]]]:
    """Turn XLA latencies and per-method speedups into MFU fractions.

    ``data`` is the parsed trajectory_data.json. This reads
    ``data["baselines"]`` ({bench: XLA latency in ms}) and
    ``data["per_bench_speedups"]`` ({method: {bench: speedup over XLA}}).
    Benchmarks absent from ``flops_by_bench`` are left out of every table.

    Returns ``(xla_mfu_by_bench, mfu_solved_only, mfu_with_xla_floor)``:

    - ``xla_mfu_by_bench[b]``: MFU of the XLA baseline on ``b``.
    - ``mfu_solved_only[m][b]``: set only when method ``m`` has a truthy,
      positive speedup on ``b``. Its latency is ``baseline_ms / speedup``.
    - ``mfu_with_xla_floor[m][b]``: the solved value when available, otherwise
      the XLA MFU for ``b``. Used for the distribution plot so failures don't
      get plotted at 0% and artificially depress the mean.
    """
    baselines = data["baselines"]
    speedups = data["per_bench_speedups"]

    xla_mfu: dict[str, float] = {}
    solved: dict[str, dict[str, float]] = {m: {} for m in methods}
    floored: dict[str, dict[str, float]] = {m: {} for m in methods}

    for bench, baseline_ms in baselines.items():
        flops = flops_by_bench.get(bench)
        if flops is None:
            continue
        base_mfu = compute_mfu(baseline_ms, flops, peak_tflops)
        xla_mfu[bench] = base_mfu
        for m in methods:
            speedup = speedups.get(m, {}).get(bench)
            if not (speedup and speedup > 0):
                # Unsolved: floor at the XLA baseline instead of 0%.
                floored[m][bench] = base_mfu
                continue
            method_mfu = compute_mfu(baseline_ms / speedup, flops, peak_tflops)
            solved[m][bench] = method_mfu
            floored[m][bench] = method_mfu

    return xla_mfu, solved, floored


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_per_benchmark(
    xla_mfu: dict[str, float],
    mfu_by_method: dict[str, dict[str, float]],
    methods: list[str],
    out_path: pathlib.Path,
    sort_key: dict[str, float] | None = None,
) -> None:
    """Horizontal dot plot: one row per benchmark, one dot per method, + XLA tick.

    Matches the style of `trajectory.py:per_benchmark.pdf` (same sizing,
    offsets, label formatting, row guidelines) so the two figures compose
    cleanly in the paper.

    Args:
        xla_mfu: XLA baseline MFU (fraction) per benchmark; its keys define
            the rows.
        mfu_by_method: per-method MFU (fraction) per benchmark. Benchmarks
            missing for a method are simply not drawn for that method.
        methods: methods to draw, in legend order. Each needs an entry in
            METHOD_COLORS and METHOD_MARKERS.
        out_path: destination file passed to ``fig.savefig``.
        sort_key: optional per-benchmark value used to order rows ascending
            from the bottom. Defaults to ``xla_mfu``.
    """
    # Sort by `sort_key` if given (typically the floored Autocomp MFU so
    # unsolved benchmarks land at their XLA value), else by XLA.
    order_by = sort_key if sort_key is not None else xla_mfu
    benches = sorted(xla_mfu.keys(), key=lambda b: order_by.get(b, 0.0))

    n = len(benches)
    fig_h = max(2.0, 0.15 * n + 0.8)
    fig, ax = plt.subplots(figsize=(5.5, fig_h))

    # Matches trajectory.py offsets for 3 methods. Otherwise spread evenly
    # within +/-0.12 of the row, and a lone method sits on the row itself.
    default_offsets = {"Best-of-N": 0.08, "Iterative": 0.0, "Autocomp": -0.08}
    if set(methods) == set(default_offsets) and len(methods) == 3:
        offsets = default_offsets
    elif len(methods) == 1:
        offsets = {methods[0]: 0.0}
    else:
        spread = 0.24
        offsets = {m: spread * (0.5 - i / max(1, len(methods) - 1))
                   for i, m in enumerate(methods)}

    # XLA tick per row (underneath the method dots).
    xla_xs = [xla_mfu.get(b, 0.0) * 100 for b in benches]
    ax.scatter(xla_xs, list(range(n)),
               color=METHOD_COLORS["XLA"], marker=METHOD_MARKERS["XLA"],
               s=36, label="XLA baseline", zorder=2, linewidths=1.2)

    plotted_xs = list(xla_xs)
    for m in methods:
        per_bench = mfu_by_method.get(m, {})
        xs, ys = [], []
        for i, bench in enumerate(benches):
            v = per_bench.get(bench)
            if v is not None:
                xs.append(v * 100)
                ys.append(i + offsets[m])
        plotted_xs.extend(xs)
        ax.scatter(xs, ys,
                   color=METHOD_COLORS[m], marker=METHOD_MARKERS[m],
                   s=36, label=m, zorder=3,
                   edgecolors="white", linewidths=0.5, alpha=0.9)

    ax.axvline(100, color="gray", linestyle="--", linewidth=0.8, alpha=0.6)
    ax.set_yticks(range(n))
    ax.set_ylim(-0.5, n - 0.5)
    bench_labels = [b.split("_", 1)[1].replace("_", " ") if "_" in b else b
                    for b in benches]
    ax.set_yticklabels(bench_labels, fontsize=7)
    # `%` starts a comment under LaTeX (usetex) and must be escaped there,
    # but an escaped `\%` renders with a literal backslash otherwise.
    pct = r"\%" if plt.rcParams["text.usetex"] else "%"
    ax.set_xlabel(f"MXU utilization ({pct} of peak bf16 TFLOPS)")

    # Leave a small margin on both sides so dots at 0% and 100% aren't clipped.
    # Widen the right edge if any point exceeds peak (e.g. over-counted FLOPs)
    # rather than silently hiding it.
    x_max = max([100.0, *plotted_xs])
    ax.set_xlim(-2, x_max + 2)

    ax.legend(loc="lower right", frameon=False, fontsize=8)
    ax.grid(True, axis="x", alpha=0.3, linewidth=0.5)
    for i in range(n):
        ax.axhline(i, color="gray", linewidth=0.3, alpha=0.4, zorder=0)
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_path}")


def plot_distribution(
    xla_mfu: dict[str, float],
    mfu_by_method: dict[str, dict[str, float]],
    methods: list[str],
    out_path: pathlib.Path,
) -> None:
    """Compact MFU-distribution strip plot: one column per (XLA + each method).

    Each column shows one horizontally jittered dot per benchmark, using a
    fixed RNG seed so the figure is reproducible. A short horizontal bar marks
    the column mean.

    Args:
        xla_mfu: XLA baseline MFU (fraction) per benchmark.
        mfu_by_method: per-method MFU (fraction) per benchmark. Every entry of
            ``methods`` must be a key.
        methods: methods to draw after the XLA column, left to right.
        out_path: destination file passed to ``fig.savefig``.
    """
    series = [("XLA", list(xla_mfu.values()))]
    for m in methods:
        series.append((m, list(mfu_by_method[m].values())))

    fig, ax = plt.subplots(figsize=(4.2, 2.6))
    rng = np.random.default_rng(0)

    y_max = 100.0
    for i, (label, values) in enumerate(series):
        if not values:
            continue
        pcts = [v * 100 for v in values]
        xs = i + (rng.random(len(values)) - 0.5) * 0.28
        ax.plot(
            xs,
            pcts,
            marker=METHOD_MARKERS[label],
            markersize=5,
            markerfacecolor=METHOD_COLORS[label],
            markeredgecolor=METHOD_COLORS[label],
            linestyle="None",
            alpha=0.85,
        )
        mean = sum(values) / len(values) * 100
        ax.plot(
            [i - 0.22, i + 0.22], [mean, mean],
            color=METHOD_COLORS[label], linewidth=2, solid_capstyle="butt",
        )
        y_max = max(y_max, *pcts)

    ax.set_xticks(range(len(series)))
    ax.set_xticklabels([lbl for lbl, _ in series], fontsize=9,
                       rotation=20, ha="right", rotation_mode="anchor")
    # `%` starts a comment under LaTeX (usetex) and must be escaped there.
    pct = r"\%" if plt.rcParams["text.usetex"] else "%"
    ax.set_ylabel(f"MXU utilization ({pct})")
    # Keep the 0-100% scale unless some point exceeds peak; then add headroom
    # instead of silently clipping it.
    ax.set_ylim(0, 100 if y_max <= 100 else y_max * 1.05)
    ax.axhline(100, color="black", linestyle="--", linewidth=0.8, alpha=0.4)
    ax.grid(axis="y", alpha=0.3)
    ax.set_axisbelow(True)
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Command-line entry point (see the module docstring for usage).

    Reads ``<out>/trajectory_data.json`` and resolves per-benchmark FLOPs from
    the JAXBench checkout. Prints per-benchmark and mean MFU tables to stdout,
    then writes ``mfu.pdf`` and ``mfu_dist.pdf`` into ``--out``.

    Raises:
        SystemExit: if trajectory_data.json is missing, or if no benchmark's
            FLOPs could be resolved.
    """
    ap = argparse.ArgumentParser(
        description="Generate MFU figures from trajectory_data.json.")
    ap.add_argument("--out", type=pathlib.Path, required=True,
                    help="Figures directory; must contain trajectory_data.json")
    ap.add_argument("--jaxbench-root", type=pathlib.Path, default=DEFAULT_JAXBENCH_ROOT,
                    help="path to the JAXBench repo (for FLOPs lookup)")
    ap.add_argument("--peak-tflops", type=float, default=DEFAULT_PEAK_TFLOPS,
                    help="peak bf16 TFLOPS for the target TPU")
    args = ap.parse_args()

    data_path = args.out / "trajectory_data.json"
    if not data_path.is_file():
        raise SystemExit(
            f"{data_path} not found; run trajectory.py with the same --out first."
        )
    data = json.loads(data_path.read_text())

    methods = [m for m in ["Best-of-N", "Iterative", "Iterative+ctx", "Autocomp"]
               if m in data.get("per_bench_speedups", {})]
    benches = list(data.get("baselines", {}).keys())

    flops_by_bench = build_flops_table(args.jaxbench_root, benches)
    if not flops_by_bench:
        raise SystemExit(
            f"No FLOPs resolved; check --jaxbench-root ({args.jaxbench_root}) "
            f"or the benchmark names in {data_path}."
        )

    xla_mfu, mfu_solved, mfu_floored = compute_mfu_table(
        data, flops_by_bench, args.peak_tflops, methods,
    )

    # Columns widen to fit long names (e.g. "Iterative+ctx", 13 chars) so the
    # tables stay aligned. The defaults reproduce the original layout.
    bench_w = max([28, *(len(b) for b in xla_mfu)])
    col_w = max([10, *(len(m) for m in methods)])

    print("Per-benchmark MFU (percent, solved only; `-` = failed):")
    print(f"  {'benchmark':<{bench_w}s} {'XLA':>6s}  "
          + "  ".join(f"{m:>{col_w}s}" for m in methods))
    for bench in sorted(xla_mfu, key=lambda b: xla_mfu[b]):
        row = [f"{xla_mfu[bench]*100:>6.1f}"]
        for m in methods:
            v = mfu_solved[m].get(bench)
            row.append(f"{v*100:>{col_w}.1f}" if v is not None else f"{'-':>{col_w}s}")
        print(f"  {bench:<{bench_w}s} " + "  ".join(row))

    def _mean(d: dict[str, float]) -> float:
        """Arithmetic mean of the dict's values; NaN when empty."""
        return sum(d.values()) / len(d) if d else float("nan")

    print("\nMean MFU (percent, with XLA floor for failures; used in fig:mfu-dist):")
    print(f"  {'XLA:':<{col_w}s} {_mean(xla_mfu)*100:5.1f}")
    for m in methods:
        print(f"  {m:<{col_w}s} {_mean(mfu_floored[m])*100:5.1f}")

    print("\nMean MFU (percent, over benchmarks each method solved):")
    for m in methods:
        if mfu_solved[m]:
            print(f"  {m:<{col_w}s} {_mean(mfu_solved[m])*100:5.1f}  (n={len(mfu_solved[m])})")
        else:
            print(f"  {m:<{col_w}s} (no solves)")

    plot_per_benchmark(
        xla_mfu, mfu_solved, methods, args.out / "mfu.pdf",
        sort_key=mfu_floored.get("Autocomp", xla_mfu),
    )
    plot_distribution(xla_mfu, mfu_floored, methods, args.out / "mfu_dist.pdf")


if __name__ == "__main__":
    main()
