from __future__ import annotations

import json
import argparse
from typing import List, Dict, Any, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mticker
import os


def _load_methods(json_path: str) -> Tuple[Dict[str, Dict[str, list]], Dict[str, Any]]:
    with open(json_path, "r") as f:
        obj = json.load(f)
    return obj.get("methods", {}), obj.get("metadata", {})


def _unique_sorted(vals: list) -> list:
    return sorted(list({int(x) for x in vals}))


def _prep_theme(preferred_font: str | None = None):
    """Apply the figure theme globally and return the colour palette.

    Sets matplotlib's PDF/PS font type, installs a white seaborn theme with a
    serif font stack, and makes the 5-colour "colorblind" palette the active
    seaborn palette.

    Args:
        preferred_font: Optional font family to put at the front of the serif
            stack. A case-insensitive duplicate of it in the built-in
            fallback list is removed.

    Returns:
        The seaborn "colorblind" palette with 5 colours.
    """
    # Paper-ready vector outputs
    for fonttype_key in ("pdf.fonttype", "ps.fonttype"):
        mpl.rcParams[fonttype_key] = 42

    fallback_serifs = [
        "TeX Gyre Termes", "STIX Two Text", "Times New Roman", "Times",
        "Nimbus Roman", "Liberation Serif", "DejaVu Serif",
    ]
    if preferred_font:
        wanted = preferred_font.lower()
        serif_stack = [preferred_font]
        serif_stack.extend(name for name in fallback_serifs if name.lower() != wanted)
    else:
        serif_stack = fallback_serifs

    theme_rc = {
        "axes.facecolor": "#ffffff",
        "figure.facecolor": "#ffffff",
        "grid.linestyle": "",
        "axes.grid": False,
        "font.family": "serif",
        "font.serif": serif_stack,
        "axes.titleweight": "bold",
    }
    sns.set_theme(style="white", font="serif", font_scale=2.2, rc=theme_rc)

    # Use seaborn's colorblind-friendly palette (first 5 colors; leave index 2 for TNPD-AR elsewhere)
    palette = sns.color_palette("colorblind", n_colors=5)
    sns.set_palette(palette)
    return palette


def _style_for_label(label: str, idx: int):
    # Match plot_results.py style mapping and extend for our labels
    style_by_prefix = {
        'LL M1': '-', 'M1 ': '-', 'TNPD-Independent': '-', 'TNPD-Ind': '-', 'TNPD': '-', 'TNP-D-Ind': '-',
        'LL M2': '--', 'M2 ': '--', 'TNPD-AR': '--', 'TNP-D-AR': '--',
        'LL M3': '-.', 'M3 ': '-.', 'TR M3': '-.', 'Ours': '-.', 'OURS': '-.',
        'LL M4': ':',  'M4 ': ':',  'TNPA': ':',   'TNP-A': ':',
        'LL M5': (0, (5, 2)), 'M5 ': (0, (5, 2)), 'TNP-ND': (0, (5, 2)),
    }
    for k, v in style_by_prefix.items():
        if label.startswith(k):
            return v
    fallback_styles = ['-', '--', '-.', ':', (0, (5, 2))]
    return fallback_styles[idx % len(fallback_styles)]


def plot_fwd_bwd(json_path: str, out_dir: str,
                 grid_shape: Tuple[int, int] = (1, 4), figsize: Tuple[float, float] = (15.0, 11.0),
                 yscale: str = "log", preferred_font: str | None = None,
                 select_B: List[int] | None = None, select_Nt: List[int] | None = None,
                 color_indices: List[int] | None = None,
                 suptitle: str | None = None) -> List[str]:
    """Render one timing figure per batch size, with one subplot per ``Nt``.

    Every subplot shows, for each method in the JSON file, ``mean_time``
    against ``Nc`` with ``std_time`` error bars (rows with a negative mean are
    dropped). Series whose label starts with ``Ours``/``OURS`` are shifted
    left by the metadata value ``ours_k`` so they line up with the base ``Nc``
    ticks. Each figure is saved as ``fwd_bwd_B_<B>.png`` in ``out_dir``; when
    the normalized ``out_dir`` contains ``paper_plots`` a PDF copy is written
    as well on a best-effort basis.

    Args:
        json_path: Path to the results JSON (``methods`` / ``metadata`` keys).
        out_dir: Directory for the output files; created if missing.
        grid_shape: ``(rows, cols)`` of the subplot grid. Only the first
            ``rows * cols`` ``Nt`` values are plotted.
        figsize: Figure size in inches.
        yscale: ``"log"`` (base 10), ``"log2"`` (base 2); anything else gives
            a linear y-axis.
        preferred_font: Optional serif font to prefer.
        select_B: Batch sizes to plot; all values found in the data if empty.
        select_Nt: ``Nt`` values to plot; all values found in the data if empty.
        color_indices: Currently unused; kept so existing callers keep working.
        suptitle: Optional text prepended to each figure title.

    Returns:
        Paths of the PNG files written, one per batch size.

    Raises:
        ValueError: If no ``B``, ``Nt`` or ``Nc`` values are available.
    """
    import warnings
    from matplotlib.lines import Line2D

    methods, meta = _load_methods(json_path)
    # Tolerate a loader that hands back None for a missing/null section.
    methods = methods or {}
    meta = meta or {}
    cb_palette = _prep_theme(preferred_font)

    # Fixed color mapping per method label (prefix match) to align with other
    # figures. Values index into the palette returned by _prep_theme.
    COLOR_IDX_BY_PREFIX = {
        # Ours → blue (0)
        "Ours": 0, "OURS": 0,
        # TNP-D-Ind → orange (1)
        "TNP-D-Ind": 1, "TNPD": 1, "TNPD-Independent": 1, "TNPD-Ind": 1,
        # TNP-D-AR → green (2)
        "TNP-D-AR": 2, "TNPD-AR": 2,
        # TNP-ND → red (3)
        "TNP-ND": 3, "TNP ND": 3,
        # TNP-A → purple (4)
        "TNP-A": 4, "TNPA": 4,
    }

    def pick_color(label: str, fallback_idx: int):
        for prefix, ci in COLOR_IDX_BY_PREFIX.items():
            if label.startswith(prefix):
                return cb_palette[ci % len(cb_palette)]
        return cb_palette[fallback_idx % len(cb_palette)]

    def canon_label(raw: str) -> str:
        """Map a raw method key to its display label (unknown keys pass through)."""
        if raw.startswith(("OURS", "Ours")):
            return "Ours"
        # Must be tested before the shorter "TNPD" prefix below.
        if raw.startswith("TNPD-AR"):
            return "TNP-D-AR"
        if raw.startswith("TNPD"):
            return "TNP-D-Ind"
        if raw.startswith(("TNPA", "TNP-A")):
            return "TNP-A"
        if raw.startswith("TNP-ND"):
            return "TNP-ND"
        return raw

    # Collect uniques (Nc ticks exclude "Ours" so x-ticks reflect base Nc only)
    all_B = set()
    all_Nt = set()
    all_Nc = set()
    for label, d in methods.items():
        all_B.update(int(x) for x in d.get("B", []))
        all_Nt.update(int(x) for x in d.get("Nt", []))
        # Skip Ours for Nc ticks to avoid Nc+K clutter
        if not label.startswith(("Ours", "OURS")):
            all_Nc.update(int(x) for x in d.get("Nc", []))
    Bs = sorted(select_B if select_B else list(all_B))
    Nts = sorted(select_Nt if select_Nt else list(all_Nt))
    Ncs = sorted(all_Nc)
    if not Bs or not Nts or not Ncs:
        raise ValueError("No data found in JSON; ensure methods contain B, Nc, Nt arrays")

    os.makedirs(out_dir, exist_ok=True)
    out_paths: List[str] = []

    nrows, ncols = grid_shape
    markers = ['o', 's', '^', 'D', 'v', 'p']
    x_ticks = Ncs

    # Consistent legend order (omit TNP-D-AR from legend, still plotted)
    label_order = ["Ours", "TNP-D-Ind", "TNP-A", "TNP-ND"]

    def order_idx(lbl: str) -> int:
        for i, p in enumerate(label_order):
            if lbl.startswith(p):
                return i
        return len(label_order)

    # Sorting and array conversion do not depend on B or Nt, so do them once.
    # The position in this list (midx below) drives marker and fallback
    # color/style selection, so skipped series must still keep their slot.
    series = []
    for raw_label, d in sorted(methods.items(), key=lambda kv: order_idx(canon_label(kv[0]))):
        arr_mean = np.asarray(d.get("mean_time", []), dtype=float)
        arr_std = np.asarray(d.get("std_time", []), dtype=float)
        if arr_std.shape != arr_mean.shape:
            # Missing/mismatched std cannot be masked alongside the means;
            # plot this series without error bars instead of crashing.
            arr_std = None
        series.append((
            canon_label(raw_label),
            np.asarray(d.get("B", []), dtype=int),
            np.asarray(d.get("Nc", []), dtype=int),
            np.asarray(d.get("Nt", []), dtype=int),
            arr_mean,
            arr_std,
        ))

    # Choose Nt values to fill subplots
    sel_Nt = Nts[: (nrows * ncols)]

    for B in Bs:
        # Create fig for this batch size
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True, squeeze=False)
        try:
            axes = np.array(axes).flatten().tolist()
            # Match flex_vs_baselines format: no bold, explicit title
            title = f"Training step time (forward + backward, $B={B}$)"
            if suptitle:
                title = f"{suptitle}: {title}"
            fig.suptitle(title, fontsize=26.0, y=0.97)

            for ax_idx, Nt in enumerate(sel_Nt):
                ax = axes[ax_idx]
                for midx, (label, arr_B, arr_Nc, arr_Nt, arr_mean, arr_std) in enumerate(series):
                    # Rows for this B and Nt; negative times are dropped
                    m = (arr_B == B) & (arr_Nt == Nt) & (arr_mean >= 0)
                    if not np.any(m):
                        continue
                    # Sort by Nc
                    idx = np.argsort(arr_Nc[m])
                    x = arr_Nc[m][idx]
                    y = arr_mean[m][idx]
                    yerr = arr_std[m][idx] if arr_std is not None else None
                    # Align Ours to original Nc by subtracting K (avoid clutter from Nc+K on x-axis)
                    if label.startswith("Ours"):
                        ours_k = int(meta.get("ours_k", 0) or 0)
                        if ours_k:
                            x = x - ours_k
                            # Drop any non-positive x after shift
                            keep = x > 0
                            x, y = x[keep], y[keep]
                            if yerr is not None:
                                yerr = yerr[keep]
                            if x.size == 0:
                                continue
                    ax.errorbar(
                        x, y, yerr=yerr,
                        marker=markers[midx % len(markers)], markersize=12, linewidth=7.0, capsize=6,
                        label=label, elinewidth=4.0, markeredgecolor='white', markeredgewidth=1.5,
                        alpha=0.8, color=pick_color(label, midx),
                        linestyle=_style_for_label(label, midx)
                    )

                # Axis scales
                if yscale == "log":
                    ax.set_yscale('log', base=10)
                    ax.yaxis.set_major_locator(mticker.LogLocator(base=10.0, numticks=5))
                    ax.yaxis.set_major_formatter(mticker.LogFormatterSciNotation(base=10.0))
                elif yscale == "log2":
                    ax.set_yscale('log', base=2)
                    ax.yaxis.set_major_locator(mticker.LogLocator(base=2.0, numticks=6))
                    ax.yaxis.set_major_formatter(mticker.LogFormatter(base=2.0))
                else:
                    ax.set_yscale('linear')
                    ax.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5))

                ax.set_xscale('log', base=2)
                ax.set_xlabel('$N$')
                if ax_idx % ncols == 0:
                    ax.set_ylabel('Time (s)')
                ax.set_title(f'$M = {Nt}$', fontweight='bold', pad=18)
                ax.set_xticks(x_ticks)
                ax.set_xticklabels([str(v) for v in x_ticks])
                # Clean spines
                for spine in ['left', 'bottom']:
                    ax.spines[spine].set_linewidth(2)
                    ax.spines[spine].set_color('black')
                for spine in ['right', 'top']:
                    ax.spines[spine].set_visible(False)

            # Hide unused axes
            for j in range(len(sel_Nt), nrows * ncols):
                axes[j].axis('off')

            # Legend at bottom with canonical order; include proxies for missing labels
            handles, labels = axes[0].get_legend_handles_labels()
            if labels:
                hmap = {lab: h for h, lab in zip(handles, labels)}
                color_idx = {"Ours": 0, "TNP-D-Ind": 1, "TNP-D-AR": 2, "TNP-ND": 3, "TNP-A": 4}
                style_map = {"TNP-D-Ind": '-', "TNP-D-AR": '--', "TNP-A": ':', "TNP-ND": (0, (5, 2)), "Ours": '-.'}
                ordered_h, ordered_l = [], []
                for lab in label_order:
                    if lab in hmap:
                        ordered_h.append(hmap[lab])
                    else:
                        # Method absent from the first subplot: draw a stand-in
                        # entry. The palette from the initial _prep_theme call is
                        # reused rather than re-applying the theme per figure.
                        ordered_h.append(Line2D(
                            [0], [0], color=cb_palette[color_idx[lab]], linestyle=style_map[lab],
                            linewidth=7.0, marker=markers[len(ordered_h) % len(markers)], markersize=12,
                            markeredgecolor='white', markeredgewidth=1.5,
                        ))
                    ordered_l.append(lab)
                legend_ncol = min(5, len(ordered_l))
                fig.legend(ordered_h, ordered_l, loc='lower center', bbox_to_anchor=(0.5, 0.02),
                           ncol=legend_ncol, frameon=False)

            fig.tight_layout(rect=[0, 0.12, 1, 0.98])
            out_path = os.path.join(out_dir, f"fwd_bwd_B_{B}.png")
            fig.savefig(out_path, bbox_inches="tight")
            # Also write PDF for paper-ready assets if targeting paper_plots
            if "paper_plots" in os.path.normpath(out_dir):
                out_pdf = os.path.splitext(out_path)[0] + ".pdf"
                try:
                    fig.savefig(out_pdf, bbox_inches="tight", format="pdf")
                except Exception as exc:
                    # The PDF is a best-effort extra: keep the PNG result, but
                    # make the failure visible instead of hiding it.
                    warnings.warn(f"Could not write PDF copy {out_pdf!r}: {exc}",
                                  RuntimeWarning, stacklevel=2)
            out_paths.append(out_path)
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

    return out_paths


def main(argv: List[str] | None = None) -> None:
    """Command-line entry point: parse options, render the figures, list outputs.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Malformed ``--grid``, ``--figsize``, ``--B``, ``--Nt`` or ``--colors``
    values are reported through the parser's usage error (exit status 2).
    """
    ap = argparse.ArgumentParser(description="Plot forward+backward JSON results (one figure per B; subplots by Nt)")
    ap.add_argument("json", type=str, help="Path to fwd_bwd_gpu.json (or similar)")
    ap.add_argument("--grid", type=str, default="1x4", help="Grid e.g. 1x4, 2x2")
    ap.add_argument("--figsize", type=str, default="15x11", help="Figure size WxH inches")
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"], help="Y-axis scale")
    ap.add_argument("--font", type=str, default="", help="Preferred serif font family name")
    ap.add_argument("--B", type=str, default="", help="Comma list of B to include (default: all)")
    ap.add_argument("--Nt", type=str, default="", help="Comma list of Nt to include (default: all)")
    ap.add_argument("--colors", type=str, default="", help="Comma list of palette indices (0-11)")
    ap.add_argument("--out_dir", type=str, default="outputs/fast_times/new_plots", help="Output directory for PNGs")
    ap.add_argument("--title", type=str, default="", help="Big title prefix for figures")
    args = ap.parse_args(argv)

    def parse_pair(text: str, cast, flag: str):
        """Parse 'AxB' into two positive numbers, or exit with a usage error."""
        try:
            # Unpacking raises ValueError for a wrong number of parts as well
            # as for parts that are not numbers.
            first, second = (cast(part) for part in text.lower().split("x"))
        except ValueError:
            ap.error(f"{flag} must look like AxB (e.g. 1x4), got {text!r}")
        if first <= 0 or second <= 0:
            ap.error(f"{flag} values must be positive, got {text!r}")
        return first, second

    def parse_int_list(text: str, flag: str):
        """Parse a comma list of ints; blank tokens are skipped, empty gives None."""
        try:
            values = [int(tok) for tok in text.split(",") if tok.strip()]
        except ValueError:
            ap.error(f"{flag} must be a comma-separated list of integers, got {text!r}")
        return values or None

    g = parse_pair(args.grid, int, "--grid")
    w, h = parse_pair(args.figsize, float, "--figsize")
    Bs = parse_int_list(args.B, "--B")
    Nts = parse_int_list(args.Nt, "--Nt")
    color_indices = parse_int_list(args.colors, "--colors")
    suptitle = args.title or None

    out_paths = plot_fwd_bwd(
        args.json, args.out_dir, grid_shape=g, figsize=(w, h), yscale=args.yscale,
        preferred_font=(args.font or None), select_B=Bs, select_Nt=Nts, color_indices=color_indices,
        suptitle=suptitle
    )
    print("Saved:")
    for p in out_paths:
        print(" -", p)


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
