from __future__ import annotations

"""
Plot 1 (sampling mixed) using *_big.json files.

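Inputs (defaults follow outputs/fast_times/results naming):
  --compiled  path to compiled_sampling_big.json
  --triton    path to triton_sampling_big.json

Options:
  --out       output image path
              (default: outputs/fast_times/plots/plot1_sampling_mixed_big.png)
  --yscale    y-axis scale: log (default), log2, or linear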

Generates the first plot from plot_requested.py, but only requires the two
sampling JSONs (no baseline or LL files).
"""

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List

import pandas as pd

try:
    from scripts.fast_times.plot_results import plot_all_methods
except ModuleNotFoundError:
    import sys
    from pathlib import Path as _Path
    sys.path.append(str(_Path(__file__).resolve().parents[2]))
    from scripts.fast_times.plot_results import plot_all_methods


def _load_methods(path: str) -> Dict[str, pd.DataFrame]:
    with open(path, "r") as f:
        obj = json.load(f)
    methods = obj.get("methods", {})
    out: Dict[str, pd.DataFrame] = {}
    for label, d in methods.items():
        out[label] = pd.DataFrame({
            "Nc": d.get("Nc", []),
            "num_samples": d.get("num_samples", []),
            "mean_time": d.get("mean_time", []),
            "std_time": d.get("std_time", []),
        })
    return out


def _save(fig, out_path: str) -> None:
    Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")
    print(f"Saved plot: {out_path}")


def main() -> None:
    """Command-line entry point: build plot 1 from the two *_big.json files.

    Loads the compiled and Triton sampling results, selects the five methods
    shown in the figure (in legend order, each with a fixed colour index),
    passes them to ``plot_all_methods`` and saves the figure to ``--out``.

    Raises:
        KeyError: If an expected method label is missing from its JSON file;
            the message names the method and the file it was looked up in.
    """
    parser = argparse.ArgumentParser(description="Plot 1 (sampling mixed) from *_big.json files")
    parser.add_argument("--compiled", type=str, default="outputs/fast_times/results/compiled_sampling_big.json")
    parser.add_argument("--triton", type=str, default="outputs/fast_times/results/triton_sampling_big.json")
    parser.add_argument("--out", type=str, default="outputs/fast_times/plots/plot1_sampling_mixed_big.png")
    parser.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"])
    cli = parser.parse_args()

    # Pair each loaded table set with the path it came from, so a missing
    # method is reported against the right file.
    compiled = (cli.compiled, _load_methods(cli.compiled))
    triton = (cli.triton, _load_methods(cli.triton))

    # (source, method key in the JSON, legend label, colour index), in plot order.
    selection = (
        (compiled, "M1 TNP-D Indep", "TNPD-Ind", 1),
        (compiled, "M2 TNP-D AR", "TNPD-AR", 3),
        (triton, "TR M3 Ours AR Buffer", "Ours", 5),
        (compiled, "M4 TNPA AR", "TNP-A", 7),
        (compiled, "M5 TNP-ND", "TNP-ND", 9),
    )

    # Stop at the first missing method, before anything is plotted.
    for (source_path, tables), key, _label, _color in selection:
        if key not in tables:
            raise KeyError(f"Missing method '{key}' in {source_path}")

    data: List[pd.DataFrame] = [tables[key] for (_, tables), key, _, _ in selection]
    labels: List[str] = [label for _, _, label, _ in selection]
    colors: List[int] = [color for *_, color in selection]

    figure = plot_all_methods(
        data,
        method_labels=labels,
        color_indices=colors,
        grid_shape=(1, 4),
        figsize=(22, 8),
        suptitle=r"Sample generation time for $N_t=16$",
        suptitle_size=30.0,
        suptitle_y=0.84,
        preferred_font="TeX Gyre Termes",
        yscale=cli.yscale,
    )
    _save(figure, cli.out)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
