from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Tuple
import math
import matplotlib.pyplot as plt
import matplotlib as mpl

import pandas as pd

# Robust import whether run as a module or a script
try:
    from scripts.fast_times.plot_results import plot_all_methods
except ModuleNotFoundError:
    # Ensure project root is on sys.path when run as a plain script
    import sys
    from pathlib import Path as _Path
    sys.path.append(str(_Path(__file__).resolve().parents[2]))
    from scripts.fast_times.plot_results import plot_all_methods


def _load_methods(path: str) -> Dict[str, pd.DataFrame]:
    """Load per-method timing tables from a results JSON file.

    The file is expected to hold a top-level ``"methods"`` object mapping a
    method label to a record with the list-valued fields ``Nc``,
    ``num_samples``, ``mean_time`` and ``std_time``.

    Args:
        path: Location of the JSON file.

    Returns:
        One DataFrame per method label, with the four columns above in that
        order. A field missing from a record becomes an empty column; a file
        without a ``"methods"`` entry yields an empty dict.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: raised by pandas when the fields of one record have
            different lengths.
    """
    columns = ("Nc", "num_samples", "mean_time", "std_time")
    # Pin the encoding: JSON is UTF-8, and the platform default may differ.
    with open(path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    return {
        label: pd.DataFrame({col: record.get(col, []) for col in columns})
        for label, record in payload.get("methods", {}).items()
    }


def _save(fig, out_path: str) -> None:
    """Write ``fig`` to ``out_path``, plus a PDF twin for paper plots.

    The parent directory is created if needed. When the normalized path
    contains ``paper_plots``, a second copy with a ``.pdf`` suffix is written
    next to the primary output. The PDF twin is best-effort: a failure is
    reported on stdout and does not abort the run.

    Note: this sets the global ``pdf.fonttype`` / ``ps.fonttype`` rcParams as
    a side effect.

    Args:
        fig: Object exposing ``savefig`` (a matplotlib figure).
        out_path: Destination file path for the primary output.
    """
    # Type 42 (TrueType) font embedding for paper-ready vector outputs.
    mpl.rcParams["pdf.fonttype"] = 42
    mpl.rcParams["ps.fonttype"] = 42
    Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")
    print(f"Saved plot: {out_path}")

    # Only outputs that target paper_plots get a PDF twin.
    if "paper_plots" not in os.path.normpath(out_path):
        return
    pdf_path = Path(out_path).with_suffix('.pdf')
    if pdf_path == Path(out_path):
        # The primary output already is the PDF; do not write it twice.
        return
    out_pdf = str(pdf_path)
    try:
        fig.savefig(out_pdf, bbox_inches="tight", format="pdf")
    except Exception as exc:
        # Best-effort twin: keep going, but do not hide the failure.
        print(f"Warning: could not save PDF twin {out_pdf}: {exc}")
    else:
        print(f"Saved plot: {out_pdf}")


def _apply_legend(fig, axes, order: List[str]) -> None:
    """Attach a single figure-level legend whose entries follow ``order``.

    Handles are gathered from every axes in ``axes``; for each non-empty
    label the first handle encountered is kept. Only labels listed in
    ``order`` are shown, in that order. Nothing is drawn when none of them
    is present.

    Args:
        fig: Object exposing ``legend`` (a matplotlib figure).
        axes: Iterable of objects exposing ``get_legend_handles_labels``.
        order: Canonical label order for the legend entries.
    """
    first_handle = {}
    for axis in axes:
        handles, labels = axis.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            if label:
                # setdefault keeps the first handle seen for a label.
                first_handle.setdefault(label, handle)

    shown = [name for name in order if name in first_handle]
    if not shown:
        return

    fig.legend(
        [first_handle[name] for name in shown],
        shown,
        loc='lower center',
        bbox_to_anchor=(0.5, 0.03),
        ncol=min(5, len(shown)),  # at most five entries per legend row
        frameon=False,
    )


def _collect_methods(
    specs: List[Tuple[Dict[str, pd.DataFrame], str, str]],
    context: str,
) -> Tuple[List[pd.DataFrame], List[str]]:
    """Resolve ``(source, key, label)`` specs into parallel data/label lists.

    Raises:
        KeyError: if ``key`` is absent from its source; ``context`` is
            appended to the message to identify the plot.
    """
    data: List[pd.DataFrame] = []
    labels: List[str] = []
    for src, key, label in specs:
        if key not in src:
            raise KeyError(f"Missing method '{key}' {context}")
        data.append(src[key])
        labels.append(label)
    return data, labels


def _render_grid(
    data: List[pd.DataFrame],
    labels: List[str],
    colors: List[int],
    *,
    title: str,
    yscale: str,
    legend_order: List[str],
    out_path: str,
    **plot_kwargs,
) -> None:
    """Draw one 1x4 comparison figure, add the shared legend and save it.

    ``plot_kwargs`` are forwarded unchanged to ``plot_all_methods``.
    """
    fig = plot_all_methods(
        data,
        method_labels=labels,
        color_indices=colors,
        grid_shape=(1, 4),
        figsize=(22, 8),
        suptitle=title,
        suptitle_size=30.0,
        suptitle_y=0.96,
        preferred_font="TeX Gyre Termes",
        yscale=yscale,
        legend='none',
        finalize_layout=False,
        **plot_kwargs,
    )
    # Rename x-axis from N_c to N for consistency
    for ax in fig.axes:
        ax.set_xlabel('$N$')
    _apply_legend(fig, fig.axes, legend_order)
    # Leave room for the legend below and the suptitle above the panels.
    fig.tight_layout(rect=[0, 0.12, 1, 0.90])
    _save(fig, out_path)


def _pick_two_b_values(
    dfs: List[pd.DataFrame],
    preferred: Tuple[int, ...] = (128, 1024),
) -> List[int]:
    """Choose the two ``num_samples`` values shown in the composite figure.

    Preference order:
      1. the first two ``preferred`` values that are present in the data;
      2. a single present preferred value plus the far end of the available
         range (ascending);
      3. the smallest and largest available values.
    When only one value exists in the data it is returned twice.

    Raises:
        ValueError: if the frames contain no ``num_samples`` values at all.
    """
    avail = sorted(set().union(*[set(df['num_samples'].unique()) for df in dfs]))
    if not avail:
        raise ValueError("No 'num_samples' values available for the composite plot")
    picks = [v for v in preferred if v in avail]
    if len(picks) >= 2:
        return picks[:2]
    if len(avail) == 1:
        return [avail[0], avail[0]]
    if len(picks) == 1:
        # Add far end not equal to the single pick
        far_end = avail[-1] if avail[-1] != picks[0] else avail[0]
        return sorted([picks[0], far_end])
    return [avail[0], avail[-1]]


def main() -> None:
    """Generate the timing-comparison figures from a results directory.

    Reads ``compiled_sampling.json``, ``baseline_sampling.json``,
    ``triton_sampling.json``, ``compiled_ll.json`` and ``baseline_ll.json``
    (plus ``triton_ll.json`` when present) from ``--results_dir`` and writes
    ``plot1`` ... ``plot6`` PNG files into ``--out_dir``.

    Raises:
        FileNotFoundError: if any required result file is missing.
        KeyError: if a result file lacks a method needed by a plot.
        ValueError: if plot 6 has no ``num_samples`` value to display.
    """
    ap = argparse.ArgumentParser(description="Generate five requested comparisons from results directory.")
    ap.add_argument("--results_dir", type=str, default="outputs/fast_times/results",
                    help="Directory containing result JSONs: baseline/compiled/triton for sampling and LL")
    ap.add_argument("--out_dir", type=str, default="outputs/fast_times/plots",
                    help="Directory to write output plots")
    ap.add_argument("--linear-y", action="store_true", help="Use linear y-axis instead of log for timings (deprecated; use --yscale)")
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"], help="Y-axis scale for timings")
    args = ap.parse_args()

    rd = args.results_dir
    # Expected files (names used by our runners)
    sampling_comp = os.path.join(rd, "compiled_sampling.json")
    sampling_base = os.path.join(rd, "baseline_sampling.json")
    sampling_trit = os.path.join(rd, "triton_sampling.json")
    ll_comp = os.path.join(rd, "compiled_ll.json")
    ll_base = os.path.join(rd, "baseline_ll.json")
    # Optional triton LL (not always present)
    ll_trit = os.path.join(rd, "triton_ll.json")

    required = [sampling_comp, sampling_base, sampling_trit, ll_comp, ll_base]
    missing = [p for p in required if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(f"Missing expected result files: {missing}")

    comp_s = _load_methods(sampling_comp)
    base_s = _load_methods(sampling_base)
    trit_s = _load_methods(sampling_trit)
    comp_ll = _load_methods(ll_comp)
    base_ll = _load_methods(ll_base)
    trit_ll = _load_methods(ll_trit) if os.path.exists(ll_trit) else {}

    # The deprecated --linear-y flag still wins over --yscale when given.
    yscale = "linear" if args.linear_y else args.yscale

    # Consistent color mapping across plots
    # Use seaborn colorblind palette indices consistent with training plots (0..4, reserve 2)
    color_map = {
        "Ours": 0,
        "Ours (w/o Triton)": 0,
        "Ours (w/ Triton)": 0,
        "TNP-D-Ind": 1,
        "TNP-D-AR": 2,  # reserved green used here for AR to match other figures
        "TNP-A": 4,
        "TNP-ND": 3,
    }
    sampling_title = r"Sample generation time ($M=16$)"
    ll_title = r"Log-likelihood evaluation time ($M=16$)"
    mixed_legend = ["Ours", "TNP-D-Ind", "TNP-D-AR", "TNP-A", "TNP-ND"]
    # (base, compiled) pairs share a color; the baseline is drawn lighter.
    pair_colors = [1, 1, 2, 2, 4, 4, 3, 3]
    pair_alphas = [0.30, 1.00, 0.30, 1.00, 0.30, 1.00, 0.30, 1.00]
    # Pair-wise legend order per method: base, compiled
    pair_legend = [
        "TNP-D-Ind (base)", "TNP-D-Ind (compiled)",
        "TNP-D-AR (base)",  "TNP-D-AR (compiled)",
        "TNP-A (base)",     "TNP-A (compiled)",
        "TNP-ND (base)",    "TNP-ND (compiled)",
    ]

    # 1) Sampling: M1,M2 from compiled; M3 (triton); M4,M5 from compiled
    # Baselines first, Ours last
    data1, labels1 = _collect_methods(
        [
            (comp_s, "M1 TNP-D Indep", "TNP-D-Ind"),
            (comp_s, "M2 TNP-D AR", "TNP-D-AR"),
            (comp_s, "M4 TNPA AR", "TNP-A"),
            (comp_s, "M5 TNP-ND", "TNP-ND"),
            (trit_s, "TR M3 Ours AR Buffer", "Ours"),
        ],
        "in sampling sources for plot 1",
    )
    colors1 = [color_map[label] for label in labels1]
    _render_grid(
        data1, labels1, colors1,
        title=sampling_title,
        yscale=yscale,
        # Legend with Ours first, then baselines
        legend_order=mixed_legend,
        out_path=os.path.join(args.out_dir, "plot1_sampling_mixed.png"),
    )

    # 2) Sampling: M3 compiled vs M3 triton
    data2, labels2 = _collect_methods(
        [
            (comp_s, "M3 Ours AR Buffer", "Ours (w/o Triton)"),
            (trit_s, "TR M3 Ours AR Buffer", "Ours (w/ Triton)"),
        ],
        "for plot 2",
    )
    colors2 = [color_map[label] for label in labels2]
    _render_grid(
        data2, labels2, colors2,
        # Title same as plot 1 to match style
        title=sampling_title,
        yscale=yscale,
        # Put Triton first for clarity
        legend_order=["Ours (w/ Triton)", "Ours (w/o Triton)"],
        out_path=os.path.join(args.out_dir, "plot2_sampling_m3_comp_vs_triton.png"),
        num_samples_values=[128, 256, 512, 1024],
        # Make non-Triton (compiled) visibly lighter than Triton
        method_alphas=[0.30, 1.00],
    )

    # 3) Sampling: baselines vs compiled sampling
    data3, labels3 = _collect_methods(
        [
            (base_s, "TNPD-Independent", "TNP-D-Ind (base)"), (comp_s, "M1 TNP-D Indep", "TNP-D-Ind (compiled)"),
            (base_s, "TNPD-AR", "TNP-D-AR (base)"),          (comp_s, "M2 TNP-D AR", "TNP-D-AR (compiled)"),
            (base_s, "TNPA", "TNP-A (base)"),                (comp_s, "M4 TNPA AR", "TNP-A (compiled)"),
            (base_s, "TNP-ND", "TNP-ND (base)"),             (comp_s, "M5 TNP-ND", "TNP-ND (compiled)"),
        ],
        "for plot 3",
    )
    _render_grid(
        data3, labels3, pair_colors,
        title=sampling_title,
        yscale=yscale,
        legend_order=pair_legend,
        out_path=os.path.join(args.out_dir, "plot3_sampling_baseline_vs_compiled.png"),
        # Pairwise: baseline lighter, compiled solid
        method_alphas=pair_alphas,
    )

    # 4) LL: same as (1) but for LL
    # Prefer triton LL for M3 if present; else compiled LL
    if "TR LL M3 Ours" in trit_ll:
        m3_src, m3_key = trit_ll, "TR LL M3 Ours"
    else:
        m3_src, m3_key = comp_ll, "LL M3 Ours AR Buffer"
    # For mixed LL plot, label Ours without Triton qualifiers
    m3_label = "Ours"
    # Baselines first, Ours last (same method order as plot 1, so both halves
    # of the composite figure below list the methods identically)
    data4, labels4 = _collect_methods(
        [
            (comp_ll, "LL M1 TNP-D Indep", "TNP-D-Ind"),
            (comp_ll, "LL M2 TNP-D AR", "TNP-D-AR"),
            (comp_ll, "LL M4 TNPA AR", "TNP-A"),
            (comp_ll, "LL M5 TNP-ND", "TNP-ND"),
            (m3_src, m3_key, m3_label),
        ],
        "for plot 4",
    )
    colors4 = [color_map[label] for label in labels4]
    _render_grid(
        data4, labels4, colors4,
        title=ll_title,
        yscale=yscale,
        legend_order=mixed_legend,
        out_path=os.path.join(args.out_dir, "plot4_ll_mixed.png"),
    )

    # 5) LL: baselines vs compiled LL
    # Labels must match pair_legend exactly, otherwise the entry is left out
    # of the figure legend.
    data5, labels5 = _collect_methods(
        [
            (base_ll, "TNPD-Independent", "TNP-D-Ind (base)"), (comp_ll, "LL M1 TNP-D Indep", "TNP-D-Ind (compiled)"),
            (base_ll, "TNPD-AR", "TNP-D-AR (base)"),          (comp_ll, "LL M2 TNP-D AR", "TNP-D-AR (compiled)"),
            (base_ll, "TNPA", "TNP-A (base)"),                (comp_ll, "LL M4 TNPA AR", "TNP-A (compiled)"),
            (base_ll, "TNP-ND", "TNP-ND (base)"),             (comp_ll, "LL M5 TNP-ND", "TNP-ND (compiled)"),
        ],
        "for plot 5",
    )
    _render_grid(
        data5, labels5, pair_colors,
        title=ll_title,
        yscale=yscale,
        legend_order=pair_legend,
        out_path=os.path.join(args.out_dir, "plot5_ll_baseline_vs_compiled.png"),
        # Pairwise: baseline lighter, compiled solid
        method_alphas=pair_alphas,
    )

    # 6) Composite: two B values for sampling (plot 1 methods) and LL
    #    (plot 4 methods) in a single 1x4 figure
    b_vals = _pick_two_b_values(data1)
    # Filter method data for sampling group (plot 1 ordering)
    data1_small = [df[df['num_samples'].isin(b_vals)] for df in data1]
    # Filter for LL group (plot 4 ordering)
    data4_small = [df[df['num_samples'].isin(b_vals)] for df in data4]

    # Wider and slightly shorter; share y-axis across all four panels
    fig6, ax_grid = plt.subplots(1, 4, figsize=(24, 7), sharey=True, squeeze=False)
    ax_grid = ax_grid.flatten()

    # Left pair: sampling times; right pair: LL times (both for B in b_vals)
    for group_data, group_labels, group_colors, group_axes in (
        (data1_small, labels1, colors1, ax_grid[:2]),
        (data4_small, labels4, colors4, ax_grid[2:]),
    ):
        plot_all_methods(
            group_data,
            method_labels=group_labels,
            color_indices=group_colors,
            grid_shape=(1, 2),
            figsize=(22, 8),
            preferred_font="TeX Gyre Termes",
            fig=fig6,
            axes=group_axes,
            num_samples_values=b_vals,
            legend='none',
            finalize_layout=False,
            yscale=yscale,
        )

    # Remove duplicated y-axis label on the third axis (keep only the leftmost one overall)
    ax_grid[2].set_ylabel("")

    # Group titles: one above first two, one above last two (explicit Nt=16)
    fig6.text(0.25, 0.965, r"Sample generation time ($N_t=16$)", ha='center', va='bottom', fontsize=26, fontweight='bold')
    fig6.text(0.75, 0.965, r"Log-likelihood evaluation time ($N_t=16$)", ha='center', va='bottom', fontsize=26, fontweight='bold')

    # Combined legend across all panels (deduplicate by label) with canonical order
    _apply_legend(fig6, ax_grid, mixed_legend)

    # Unify y-limits across all four panels. Best-effort: a rejected limit is
    # reported instead of aborting the figure.
    try:
        limits = [ax.get_ylim() for ax in ax_grid]
        ylo = min(lo for lo, _ in limits)
        yhi = max(hi for _, hi in limits)
        for ax in ax_grid:
            ax.set_ylim(ylo, yhi)
    except ValueError as exc:
        print(f"Warning: could not unify y-limits: {exc}")

    # Tidy layout with space for group titles
    # Allow a bit more bottom space if multiple legend rows
    present = {lab for ax in ax_grid for lab in ax.get_legend_handles_labels()[1] if lab}
    n_entries = sum(1 for lab in mixed_legend if lab in present)
    if n_entries:
        rows = math.ceil(n_entries / min(5, n_entries))
        bottom_margin = 0.06 + rows * 0.06
    else:
        bottom_margin = 0.06
    fig6.tight_layout(rect=[0, bottom_margin, 1, 0.88])
    _save(fig6, os.path.join(args.out_dir, f"plot6_two_groups_B_{b_vals[0]}_{b_vals[1]}.png"))


# Script entry point: generate the plots only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
