from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib as mpl

# Robust import whether run as a module or a script
try:
    from scripts.fast_times.plot_results import plot_all_methods
except ModuleNotFoundError:
    import sys
    from pathlib import Path as _Path
    sys.path.append(str(_Path(__file__).resolve().parents[2]))
    from scripts.fast_times.plot_results import plot_all_methods


def _load_json(path: str) -> Dict:
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _load_methods(path: str) -> Dict[str, pd.DataFrame]:
    """Load per-method timing tables from a results JSON file.

    The file's top-level ``"methods"`` mapping (empty if absent) is turned
    into one DataFrame per method label with the columns ``Nc``,
    ``num_samples``, ``mean_time`` and ``std_time``; a field missing from an
    entry is read as an empty list.

    Returns:
        Mapping from method label to its timing DataFrame, in file order.
    """
    columns = ("Nc", "num_samples", "mean_time", "std_time")
    payload = _load_json(path)
    return {
        name: pd.DataFrame({col: record.get(col, []) for col in columns})
        for name, record in payload.get("methods", {}).items()
    }


def _prep_theme():
    """Configure Matplotlib/seaborn for the figure and return the palette.

    Sets the PDF and PostScript font type to 42, then applies a white,
    grid-free seaborn theme with a serif font at double font scale.

    Returns:
        The 10-colour seaborn "colorblind" palette.
    """
    # Paper-ready vector outputs: same font type for both vector backends.
    for rc_key in ("pdf.fonttype", "ps.fonttype"):
        mpl.rcParams[rc_key] = 42

    palette = sns.color_palette("colorblind", n_colors=10)

    # Plain white canvas with the grid switched off.
    overrides = {
        "axes.facecolor": "#ffffff",
        "figure.facecolor": "#ffffff",
        "axes.grid": False,
        "grid.linestyle": "",
    }
    sns.set_theme(style="white", font="serif", font_scale=2.0, rc=overrides)
    return palette


def _combined_legend(fig, axes):
    """Draw a single figure-level legend shared by all ``axes``.

    Legend entries are gathered from every axis in order and de-duplicated by
    label; the first handle seen for a label is the one kept, and entries
    with an empty label are ignored. Nothing is drawn when no labelled entry
    exists. The legend is placed at the bottom centre of ``fig`` without a
    frame, using at most five columns.
    """
    first_handle = {}
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            if not label:
                continue
            # setdefault keeps the earliest handle for a repeated label.
            first_handle.setdefault(label, handle)

    if not first_handle:
        return

    fig.legend(
        list(first_handle.values()),
        list(first_handle),
        loc='lower center',
        bbox_to_anchor=(0.5, 0.03),
        ncol=min(5, len(first_handle)),
        frameon=False,
    )


def _panel_sampling(fig, ax, sampling_comp: Dict[str, pd.DataFrame], sampling_trit: Dict[str, pd.DataFrame], *, yscale: str) -> None:
    """Draw the "Sample generation time" panel on ``ax``.

    Four baseline methods are read from ``sampling_comp`` and our method from
    ``sampling_trit``; the selected tables are handed to ``plot_all_methods``
    restricted to ``num_samples_values=[256]``, with its own legend and
    layout finalisation turned off.

    Args:
        fig: Figure that owns ``ax``; forwarded to ``plot_all_methods``.
        ax: Axis to draw on.
        sampling_comp: Method-label -> timing table for the baselines.
        sampling_trit: Method-label -> timing table holding our method.
        yscale: Y-axis scale name forwarded to ``plot_all_methods``.

    Raises:
        KeyError: if any expected method key is absent from its source.
    """
    # Colorblind-palette slot per displayed label (slot 2 is kept for AR).
    palette_slot = {"Ours": 0, "TNP-D-Ind": 1, "TNP-D-AR": 2, "TNP-ND": 3, "TNP-A": 4}

    # (source mapping, key in that mapping, displayed label).
    # Baselines come first so that "Ours" is the last entry.
    entries = [
        (sampling_comp, "M1 TNP-D Indep", "TNP-D-Ind"),
        (sampling_comp, "M2 TNP-D AR", "TNP-D-AR"),
        (sampling_comp, "M4 TNPA AR", "TNP-A"),
        (sampling_comp, "M5 TNP-ND", "TNP-ND"),
        (sampling_trit, "TR M3 Ours AR Buffer", "Ours"),
    ]

    # Fail on the first missing method before anything is drawn.
    for source, json_key, _ in entries:
        if json_key not in source:
            raise KeyError(f"Missing method '{json_key}' in sampling sources")

    frames: List[pd.DataFrame] = [source[json_key] for source, json_key, _ in entries]
    shown: List[str] = [name for _, _, name in entries]
    slots: List[int] = [palette_slot[name] for name in shown]

    # Single panel, num_samples fixed to 256.
    plot_all_methods(
        frames,
        method_labels=shown,
        color_indices=slots,
        grid_shape=(1, 1),
        figsize=(8, 5),
        fig=fig,
        axes=[ax],
        num_samples_values=[256],
        legend='none',
        finalize_layout=False,
        yscale=yscale,
        suptitle=None,
    )

    ax.set_title("Sample generation time")
    # Best-effort relabel of the x axis; a failure here is ignored.
    try:
        ax.set_xlabel('$N$')
    except Exception:
        pass


def _panel_ll(fig, ax, ll_comp: Dict[str, pd.DataFrame], ll_trit: Dict[str, pd.DataFrame], *, yscale: str) -> None:
    """Draw the "Log-likelihood evaluation time" panel on ``ax``.

    The four baselines are read from ``ll_comp``. Our method is taken from
    ``ll_trit`` when it contains ``"TR LL M3 Ours"`` and otherwise from
    ``ll_comp`` under ``"LL M3 Ours AR Buffer"``. The selected tables are
    handed to ``plot_all_methods`` restricted to ``num_samples_values=[256]``,
    with its own legend and layout finalisation turned off.

    Args:
        fig: Figure that owns ``ax``; forwarded to ``plot_all_methods``.
        ax: Axis to draw on.
        ll_comp: Method-label -> timing table for the baselines (and the
            fallback entry for our method).
        ll_trit: Method-label -> timing table preferred for our method; may
            be empty.
        yscale: Y-axis scale name forwarded to ``plot_all_methods``.

    Raises:
        KeyError: if any expected method key is absent from its source.
    """
    palette_slot = {"Ours": 0, "TNP-D-Ind": 1, "TNP-D-AR": 2, "TNP-ND": 3, "TNP-A": 4}

    # Prefer the triton timing for our method; fall back to the compiled one.
    triton_key = "TR LL M3 Ours"
    if triton_key in ll_trit:
        ours = (ll_trit, triton_key)
    else:
        ours = (ll_comp, "LL M3 Ours AR Buffer")

    # (source mapping, key in that mapping, displayed label); "Ours" is last.
    entries = [
        (ll_comp, "LL M1 TNP-D Indep", "TNP-D-Ind"),
        (ll_comp, "LL M2 TNP-D AR", "TNP-D-AR"),
        (ll_comp, "LL M4 TNPA AR", "TNP-A"),
        (ll_comp, "LL M5 TNP-ND", "TNP-ND"),
        (*ours, "Ours"),
    ]

    # Fail on the first missing method before anything is drawn.
    for source, json_key, _ in entries:
        if json_key not in source:
            raise KeyError(f"Missing method '{json_key}' in LL sources")

    frames: List[pd.DataFrame] = [source[json_key] for source, json_key, _ in entries]
    shown: List[str] = [name for _, _, name in entries]
    slots: List[int] = [palette_slot[name] for name in shown]

    plot_all_methods(
        frames,
        method_labels=shown,
        color_indices=slots,
        grid_shape=(1, 1),
        figsize=(8, 5),
        fig=fig,
        axes=[ax],
        num_samples_values=[256],
        legend='none',
        finalize_layout=False,
        yscale=yscale,
        suptitle=None,
    )

    ax.set_title("Log-likelihood evaluation time")
    # Best-effort relabel of the x axis; a failure here is ignored.
    try:
        ax.set_xlabel('$N$')
    except Exception:
        pass


def _style_for_label(lbl: str) -> str | tuple:
    """Return the Matplotlib linestyle for a method label.

    Labels are matched by prefix, so raw keys such as ``"TNPD"`` and
    canonical names such as ``"TNP-D-Ind"`` resolve to the same style. Rules
    are checked in order with the most specific prefix first: the AR variant
    has to be tested before the generic ``TNPD`` prefix, otherwise it would
    be drawn solid while the legend proxies built in ``main`` use ``'--'``
    for ``TNP-D-AR``.

    Args:
        lbl: Method label (raw JSON key or canonical display name).

    Returns:
        A linestyle string, or an ``(offset, (on, off))`` dash tuple for
        ``TNP-ND``. Unrecognised labels get a solid line (``'-'``).
    """
    rules = (
        (("Ours", "ACE"), '-.'),
        (("TNPD-AR", "TNP-D-AR"), '--'),
        (("TNPD", "TNP-D"), '-'),
        (("TNPA", "TNP-A"), ':'),
        (("TNP-ND",), (0, (5, 2))),
    )
    for prefixes, style in rules:
        # str.startswith accepts a tuple of alternatives.
        if lbl.startswith(prefixes):
            return style
    return '-'


def _fwb_frame(entry: Dict) -> pd.DataFrame:
    """Tabulate one forward+backward timing entry from the results JSON.

    ``Nc``, ``Nt``, ``B`` and ``mean_time`` are always present as columns
    (read as empty lists when missing). ``std_time`` is optional: the column
    is added only when the entry provides a non-empty value, so an entry
    without it can still be plotted (without error bars).
    """
    columns = {name: entry.get(name, []) for name in ("Nc", "Nt", "B", "mean_time")}
    std = entry.get("std_time")
    if std:
        columns["std_time"] = std
    return pd.DataFrame(columns)


def _canonical_label(lbl: str) -> str:
    """Map a raw method key to the display name used in the legend."""
    if lbl.startswith(("Ours", "ACE")):
        return "Ours"
    # The AR variant must be tested before the generic TNPD prefix.
    if lbl.startswith(("TNPD-AR", "TNP-D-AR")):
        return "TNP-D-AR"
    if lbl.startswith(("TNPD", "TNP-D-Ind")):
        return "TNP-D-Ind"
    if lbl.startswith(("TNPA", "TNP-A")):
        return "TNP-A"
    if lbl.startswith("TNP-ND"):
        return "TNP-ND"
    return lbl


def _style_training_axes(ax, yscale: str) -> None:
    """Apply scales, ticks, labels, title and spine styling to the panel."""
    if yscale == "log":
        ax.set_yscale('log', base=10)
    elif yscale == "log2":
        ax.set_yscale('log', base=2)
    else:
        ax.set_yscale('linear')
    ax.set_xscale('log', base=2)
    x_ticks = [32, 64, 128, 256, 512, 1024]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([str(v) for v in x_ticks])
    ax.set_xlabel('$N$')
    ax.set_ylabel('Time (s)')
    ax.set_title("Training step time")

    # Keep only thick black left/bottom spines.
    for spine in ['left', 'bottom']:
        ax.spines[spine].set_linewidth(2)
        ax.spines[spine].set_color('black')
    for spine in ['right', 'top']:
        ax.spines[spine].set_visible(False)


def _panel_training(ax, flex_json: str, baseline_json: str, *, yscale: str,
                    batch_size: int = 128, num_targets: int = 256) -> None:
    """Draw the "Training step time" panel on ``ax``.

    Plots ``mean_time`` (with ``std_time`` error bars when available) against
    ``Nc`` for our method and the baselines, keeping only rows whose ``B``
    equals ``batch_size``, whose ``Nt`` equals ``num_targets`` and whose
    ``mean_time`` is non-negative. A method with no matching rows is skipped.

    Args:
        ax: Axis to draw on.
        flex_json: Path to the JSON holding our method; the first entry of
            its ``"methods"`` mapping is used and labelled "Ours".
        baseline_json: Path to the JSON holding the baselines. The exact keys
            ``TNPD``, ``TNPA`` and ``TNP-ND`` are used when present;
            otherwise every key starting with one of those prefixes is used.
        yscale: ``"log"`` (base 10), ``"log2"`` (base 2), anything else is
            linear.
        batch_size: Value of ``B`` to select (default 128).
        num_targets: Value of ``Nt`` to select (default 256).

    Raises:
        ValueError: if the flex JSON contains no methods.
    """
    # Use the same palette as other panels without resetting the theme
    cb = sns.color_palette("colorblind", n_colors=10)
    # Palette slot per canonical label. Keying on the canonical label gives an
    # AR baseline its own colour instead of sharing TNPD's.
    color_idx = {"Ours": 0, "TNP-D-Ind": 1, "TNP-D-AR": 2, "TNP-ND": 3, "TNP-A": 4}
    markers = ['o', 's', '^', 'D', 'v', 'p']

    flex_methods = _load_json(flex_json).get("methods", {})
    baseline_methods = _load_json(baseline_json).get("methods", {})
    if not flex_methods:
        raise ValueError("No methods found in flex JSON")

    # Flex: assume a single method and present it as "Ours".
    series: List[Tuple[str, pd.DataFrame]] = [
        ("Ours", _fwb_frame(next(iter(flex_methods.values())))),
    ]

    # Baselines: exact keys first, prefix matches only as a fallback.
    known = ("TNPD", "TNPA", "TNP-ND")
    keep = [name for name in known if name in baseline_methods]
    if not keep:
        keep = [key for key in baseline_methods if key.startswith(known)]
    series.extend((key, _fwb_frame(baseline_methods[key])) for key in keep)

    for idx, (lbl, df) in enumerate(series):
        # Select the requested (B, Nt) slice and drop negative times.
        mask = (
            (df['B'].astype(int) == batch_size)
            & (df['Nt'].astype(int) == num_targets)
            & (np.array(df['mean_time']) >= 0)
        )
        data = df[mask].sort_values('Nc')
        if data.empty:
            continue
        canon = _canonical_label(lbl)
        yerr = data['std_time'].astype(float).values if 'std_time' in data else None
        ax.errorbar(
            data['Nc'].astype(int).values,
            data['mean_time'].astype(float).values,
            yerr=yerr,
            # The marker follows the position in `series`, so it stays stable
            # even when an earlier method was skipped.
            marker=markers[idx % len(markers)], markersize=12, linewidth=7.0, capsize=6,
            label=canon, elinewidth=4.0, markeredgecolor='white', markeredgewidth=1.5,
            alpha=0.8, color=cb[color_idx.get(canon, 1)], linestyle=_style_for_label(lbl)
        )

    _style_training_axes(ax, yscale)


def main() -> None:
    """Command-line entry point: build and save the three-panel timing figure.

    Reads ``compiled_sampling.json``, ``triton_sampling.json``,
    ``compiled_ll.json`` and (if it exists) ``triton_ll.json`` from
    ``--results_dir`` plus the two forward+backward JSON files, draws the
    sampling, log-likelihood and training panels side by side, adds one
    shared legend, and writes the figure to ``--out``. A PDF copy with the
    same stem is also written on a best-effort basis; a failure there is
    reported on stderr and does not abort the run.
    """
    import sys  # local import: at module level sys is only imported on the fallback path

    ap = argparse.ArgumentParser(description="Three-panel figure: B=256 sampling, B=256 LL, and training (B=128, M=256)")
    ap.add_argument("--results_dir", type=str, default="outputs/fast_times/results",
                    help="Directory containing sampling/LL JSON results (baseline/compiled/triton)")
    ap.add_argument("--flex_json", type=str, default="outputs/fast_times/results/flex_forward_backward.json",
                    help="Path to flex (Ours) forward+backward JSON")
    ap.add_argument("--baseline_fwb_json", type=str, default="outputs/fast_times/results/forward_backward_times.json",
                    help="Path to baseline forward+backward JSON (TNPD/TNPA/TNP-ND)")
    ap.add_argument("--out", type=str, default="outputs/fast_times/new_plots/three_panel_b256_training_m256.png")
    ap.add_argument("--figsize", type=str, default="20x7")
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"])
    args = ap.parse_args()

    # Validate --figsize before doing any work so a typo yields a usage error
    # instead of a traceback (ap.error prints the message and exits).
    try:
        W, H = (float(x) for x in args.figsize.lower().split("x"))
    except ValueError:
        ap.error(f"--figsize must look like WIDTHxHEIGHT (e.g. 20x7), got {args.figsize!r}")
    if W <= 0 or H <= 0:
        ap.error(f"--figsize dimensions must be positive, got {args.figsize!r}")

    # Theme and palette
    _prep_theme()

    # Load sampling and LL JSONs from results_dir; the triton LL file is optional.
    comp_s = _load_methods(os.path.join(args.results_dir, "compiled_sampling.json"))
    trit_s = _load_methods(os.path.join(args.results_dir, "triton_sampling.json"))
    comp_ll = _load_methods(os.path.join(args.results_dir, "compiled_ll.json"))
    ll_trit_path = os.path.join(args.results_dir, "triton_ll.json")
    trit_ll = _load_methods(ll_trit_path) if os.path.exists(ll_trit_path) else {}

    # One row of three panels
    fig, axes = plt.subplots(1, 3, figsize=(W, H), sharey=False)
    axes = np.array(axes).flatten().tolist()

    # Panel 1: Sampling @ B=256
    _panel_sampling(fig, axes[0], comp_s, trit_s, yscale=args.yscale)

    # Panel 2: Log-likelihood @ B=256
    _panel_ll(fig, axes[1], comp_ll, trit_ll, yscale=args.yscale)

    # Panel 3: Training (flex vs baselines) @ B=128, M=256
    _panel_training(axes[2], args.flex_json, args.baseline_fwb_json, yscale=args.yscale)

    # Shared legend built from the training panel's handles, in canonical
    # order; methods not drawn there get a proxy line so the legend always
    # shows the complete set.
    handles, labels = axes[2].get_legend_handles_labels()
    present = dict(zip(labels, handles))
    order = ["Ours", "TNP-D-Ind", "TNP-D-AR", "TNP-A", "TNP-ND"]
    cb = sns.color_palette("colorblind", n_colors=10)
    color_idx = {"Ours": 0, "TNP-D-Ind": 1, "TNP-D-AR": 2, "TNP-ND": 3, "TNP-A": 4}
    style_map = {"TNP-D-Ind": '-', "TNP-D-AR": '--', "TNP-A": ':', "TNP-ND": (0, (5, 2)), "Ours": '-.'}
    markers = ['o', 's', '^', 'D', 'v', 'p']
    ordered_h = []
    for pos, lab in enumerate(order):
        handle = present.get(lab)
        if handle is None:
            handle = Line2D([0], [0], color=cb[color_idx[lab]], linestyle=style_map[lab], linewidth=7.0,
                            marker=markers[pos % len(markers)], markersize=12,
                            markeredgecolor='white', markeredgewidth=1.5)
        ordered_h.append(handle)
    fig.legend(ordered_h, order, loc='lower center', bbox_to_anchor=(0.5, 0.03), ncol=min(5, len(order)), frameon=False)
    # Leave a strip at the bottom of the figure for the legend.
    fig.tight_layout(rect=[0, 0.08, 1, 0.98])

    # Save
    out_path = args.out
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")

    # Also save a PDF twin alongside the requested output, unless the
    # requested output already is that PDF. Best-effort: report, don't abort.
    out_pdf = os.path.splitext(out_path)[0] + ".pdf"
    if out_pdf != out_path:
        try:
            fig.savefig(out_pdf, bbox_inches="tight", format="pdf")
        except Exception as exc:
            print(f"Warning: could not save PDF twin {out_pdf}: {exc}", file=sys.stderr)

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved plot: {out_path}")


# Build the figure only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
