from __future__ import annotations

"""
Plot 2 (M3 compiled vs M3 Triton) using *_big.json files.

Inputs (defaults follow outputs/fast_times/results naming):
  --compiled  path to compiled_sampling_big.json
  --triton    path to triton_sampling_big.json

Generates the second plot from plot_requested.py but only requires the two
sampling JSONs.
"""

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List

import pandas as pd

try:
    from scripts.fast_times.plot_results import plot_all_methods
except ModuleNotFoundError:
    import sys
    from pathlib import Path as _Path
    sys.path.append(str(_Path(__file__).resolve().parents[2]))
    from scripts.fast_times.plot_results import plot_all_methods


def _load_methods(path: str) -> Dict[str, pd.DataFrame]:
    """Load per-method timing tables from a ``*_big.json`` results file.

    The file is expected to hold a top-level ``"methods"`` object mapping a
    method label to a dict of lists. Each label becomes a DataFrame with the
    columns ``Nc``, ``num_samples``, ``mean_time`` and ``std_time``; keys
    missing from a method's dict default to empty lists.

    Args:
        path: Path to the JSON results file.

    Returns:
        Mapping from method label to its timing DataFrame, in file order.
        Empty when the file has no ``"methods"`` entry or it is ``null``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if a method's lists have mismatched lengths (raised by
            the pandas DataFrame constructor).
    """
    # JSON is UTF-8 by spec (RFC 8259); do not rely on the platform's locale
    # encoding, which differs e.g. on many Windows setups.
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)
    # Treat a missing or explicit-null "methods" entry as "no methods".
    methods = obj.get("methods") or {}
    columns = ("Nc", "num_samples", "mean_time", "std_time")
    return {
        label: pd.DataFrame({col: d.get(col, []) for col in columns})
        for label, d in methods.items()
    }


def _save(fig, out_path: str) -> None:
    """Write *fig* to *out_path*, creating its parent directory on demand.

    The figure is saved with ``bbox_inches="tight"``, and the destination
    path is printed to stdout once the save call returns.
    """
    parent_dir = Path(os.path.dirname(out_path))
    # exist_ok lets repeated runs reuse an already-created output folder.
    parent_dir.mkdir(exist_ok=True, parents=True)
    fig.savefig(out_path, bbox_inches="tight")
    message = f"Saved plot: {out_path}"
    print(message)


def main(argv: List[str] | None = None) -> None:
    """Build Plot 2 (M3 compiled vs M3 Triton) from the two sampling JSONs.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets other code drive the entry point directly.

    Raises:
        KeyError: if an expected method label is absent from its JSON file.
            The message names the offending file and the labels it does have.
    """
    ap = argparse.ArgumentParser(description="Plot 2 (M3 compiled vs Triton) from *_big.json files")
    ap.add_argument("--compiled", type=str, default="outputs/fast_times/results/compiled_sampling_big.json")
    ap.add_argument("--triton", type=str, default="outputs/fast_times/results/triton_sampling_big.json")
    ap.add_argument("--out", type=str, default="outputs/fast_times/plots/plot2_sampling_m3_compare_big.png")
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"])
    args = ap.parse_args(argv)

    comp = _load_methods(args.compiled)
    trit = _load_methods(args.triton)

    # Each entry: (loaded methods, the file they came from, method key in the
    # JSON, legend label, color index forwarded to plot_all_methods).
    # Carrying the path alongside the data avoids recovering it by identity.
    order = [
        (comp, args.compiled, "M3 Ours AR Buffer", "Ours (w/o Triton)", 4),
        (trit, args.triton, "TR M3 Ours AR Buffer", "Ours (w/ Triton)", 5),
    ]
    data: List[pd.DataFrame] = []
    labels: List[str] = []
    colors: List[int] = []
    for src, src_path, key, label, color_idx in order:
        if key not in src:
            # List what the file does contain so a renamed label is easy to spot.
            available = ", ".join(repr(k) for k in src) or "<none>"
            raise KeyError(f"Missing method '{key}' in {src_path} (available: {available})")
        data.append(src[key])
        labels.append(label)
        colors.append(color_idx)

    fig = plot_all_methods(
        data,
        method_labels=labels,
        color_indices=colors,
        grid_shape=(1, 4),
        figsize=(22, 8),
        suptitle=r"M3 (compiled vs Triton), $N_t=16$",
        suptitle_size=30.0,
        suptitle_y=0.84,
        preferred_font="TeX Gyre Termes",
        yscale=args.yscale,
    )
    _save(fig, args.out)


if __name__ == "__main__":
    main()
