from __future__ import annotations

import argparse
import json
import os
import warnings
from typing import Dict, Any, List, Tuple, Set

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mticker
from matplotlib.lines import Line2D


def _load_json(path: str) -> Dict[str, Any]:
    with open(path, "r") as f:
        return json.load(f)


def _to_triples(nc: List[int], nt: List[int], b: List[int], mean: List[float]) -> Set[Tuple[int, int, int]]:
    triples: Set[Tuple[int, int, int]] = set()
    for Nc, Nt, Bv, m in zip(nc, nt, b, mean):
        if m is None:
            continue
        try:
            if float(m) >= 0:
                triples.add((int(Nc), int(Nt), int(Bv)))
        except Exception:
            continue
    return triples


def _prep_theme():
    """Configure the global plotting theme and return the colour palette.

    This mutates process-wide Matplotlib/seaborn state:
      * ``pdf.fonttype`` and ``ps.fonttype`` are set to 42 so PDF/PS outputs
        embed TrueType fonts;
      * seaborn's "white" style is applied with a serif font at
        ``font_scale=2.2``, white figure/axes backgrounds, no grid and bold
        axes titles;
      * the first five colours of seaborn's "colorblind" palette become the
        active colour cycle.

    Returns:
        The five-colour palette that was installed.
    """
    for fonttype_key in ("pdf.fonttype", "ps.fonttype"):
        mpl.rcParams[fonttype_key] = 42

    colorblind = sns.color_palette("colorblind", n_colors=5)

    overrides = {
        "axes.facecolor": "#ffffff",
        "figure.facecolor": "#ffffff",
        "grid.linestyle": "",
        "axes.grid": False,
        "axes.titleweight": "bold",
    }
    sns.set_theme(style="white", font="serif", font_scale=2.2, rc=overrides)
    sns.set_palette(colorblind)
    return colorblind


def _style_for_label(label: str, idx: int):
    style_by_prefix = {
        'TNP-D-Ind': '-', 'TNPD': '-', 'TNPD-Ind': '-', 'TNPD-Independent': '-',
        'TNP-D-AR': '--', 'TNPD-AR': '--',
        'TNPA': ':', 'TNP-A': ':',
        'TNP-ND': (0, (5, 2)),
        'Ours': '-.', 'ACE': '-.',
    }
    for k, v in style_by_prefix.items():
        if label.startswith(k):
            return v
    return ['-', '--', '-.', ':', (0, (5, 2))][idx % 5]


def _color_for_label(label: str, palette, fallback_idx: int):
    # Explicit mapping (seaborn colorblind):
    # Ours=0 (blue), TNP-D-Ind=1 (orange), TNP-D-AR=2 (green), TNP-ND=3 (red), TNP-A=4 (purple)
    mapping = {
        'Ours': 0, 'ACE': 0,
        'TNP-D-Ind': 1, 'TNPD': 1, 'TNPD-Ind': 1, 'TNPD-Independent': 1,
        'TNP-D-AR': 2, 'TNPD-AR': 2,
        'TNP-ND': 3,
        'TNP-A': 4, 'TNPA': 4,
    }
    for prefix, idx in mapping.items():
        if label.startswith(prefix):
            return palette[idx % len(palette)]
    return palette[fallback_idx % len(palette)]


def _canon_label(lbl: str) -> str:
    if lbl.startswith('Ours') or lbl.startswith('ACE'):
        return 'Ours'
    if lbl.startswith('TNPD-AR') or lbl.startswith('TNP-D-AR'):
        return 'TNP-D-AR'
    if lbl.startswith('TNPD') or lbl.startswith('TNPD-Ind') or lbl.startswith('TNPD-Independent') or lbl.startswith('TNP-D-Ind'):
        return 'TNP-D-Ind'
    if lbl.startswith('TNPA') or lbl.startswith('TNP-A'):
        return 'TNP-A'
    if lbl.startswith('TNP-ND'):
        return 'TNP-ND'
    return lbl


# Sort / legend order of canonical method labels. TNP-D-AR is deliberately
# left out of the legend (per request); it is still plotted when present.
_LEGEND_ORDER = ("Ours", "TNP-D-Ind", "TNP-A", "TNP-ND")
# Marker cycle; a series' marker is chosen by its position in the sorted series list.
_MARKERS = ("o", "s", "^", "D", "v", "p")


def _method_frame(d: Dict[str, Any]) -> pd.DataFrame:
    """Build a timing DataFrame (Nc, Nt, B, mean_time[, std_time]) from one method's JSON entry.

    ``std_time`` is only added when the entry provides a non-empty list.
    A method without standard deviations therefore no longer makes
    ``pd.DataFrame`` fail on mismatched column lengths; it is simply plotted
    without error bars. Timing columns are coerced to numbers, and
    unparsable/``None`` values become NaN, which every ``>= 0`` filter rejects.
    """
    cols: Dict[str, Any] = {key: d.get(key, []) for key in ("Nc", "Nt", "B", "mean_time")}
    std_time = d.get("std_time")
    if std_time:
        cols["std_time"] = std_time
    df = pd.DataFrame(cols)
    for col in ("mean_time", "std_time"):
        if col in df:
            df[col] = pd.to_numeric(df[col], errors="coerce")
    return df


def _select_baseline_labels(base_methods: Dict[str, Any]) -> List[str]:
    """Pick the baseline method keys to compare against.

    Exact names ("TNPD", "TNPD-AR", "TNPA", "TNP-ND") are preferred. Only when
    none of them is present are keys matched by prefix instead.

    Raises:
        ValueError: if no baseline method can be identified.
    """
    keep = [lbl for lbl in ("TNPD", "TNPD-AR", "TNPA", "TNP-ND") if lbl in base_methods]
    if not keep:
        prefixes = ("TNPD-AR", "TNPD", "TNPA", "TNP-ND")
        keep = [lbl for lbl in base_methods if lbl.startswith(prefixes)]
    if not keep:
        raise ValueError("Could not find TNPD/TNPA/TNP-ND in baseline JSON")
    return keep


def _order_index(lbl: str) -> int:
    """Sort key: position of the label's canonical name in ``_LEGEND_ORDER`` (unknown labels sort last)."""
    canon = _canon_label(lbl)
    return next((i for i, p in enumerate(_LEGEND_ORDER) if canon.startswith(p)), len(_LEGEND_ORDER))


def _style_axis(ax, *, ax_idx: int, Nt: int, ncols: int, yscale: str, x_ticks: List[int]) -> None:
    """Apply scales, ticks, labels, title and spine styling to one subplot."""
    if yscale == "log":
        ax.set_yscale('log', base=10)
        ax.yaxis.set_major_locator(mticker.LogLocator(base=10.0, numticks=5))
        ax.yaxis.set_major_formatter(mticker.LogFormatterSciNotation(base=10.0))
    elif yscale == "log2":
        ax.set_yscale('log', base=2)
        ax.yaxis.set_major_locator(mticker.LogLocator(base=2.0, numticks=6))
        ax.yaxis.set_major_formatter(mticker.LogFormatter(base=2.0))
    else:
        ax.set_yscale('linear')
        ax.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5))
    ax.set_xscale('log', base=2)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([str(v) for v in x_ticks])
    ax.set_xlabel('$N$')
    # Only the first column carries the y label (y axes are shared).
    if ax_idx % ncols == 0:
        ax.set_ylabel('Time (s)')
    ax.set_title(f'$M = {Nt}$', fontweight='bold', pad=18)
    for spine in ('left', 'bottom'):
        ax.spines[spine].set_linewidth(2)
        ax.spines[spine].set_color('black')
    for spine in ('right', 'top'):
        ax.spines[spine].set_visible(False)


def _add_legend(fig, axes, palette, marker_by_label: Dict[str, str]) -> None:
    """Add one figure-level legend listing ``_LEGEND_ORDER``.

    Handles are collected from *all* subplots (the first occurrence of each
    canonical label wins), so a method missing from the first subplot still
    shows its real handle. Labels not plotted in this figure get a proxy line
    drawn with the same colour/style/marker rules as the plotted series, so
    the legend is complete. Nothing is added when no subplot has a labelled
    artist.
    """
    hmap: Dict[str, Any] = {}
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        for h, lab in zip(handles, labels):
            hmap.setdefault(_canon_label(lab), h)
    if not hmap:
        return
    ordered_h, ordered_l = [], []
    for lab in _LEGEND_ORDER:
        handle = hmap.get(lab)
        if handle is None:
            handle = Line2D([0], [0], color=_color_for_label(lab, palette, 0), linestyle=_style_for_label(lab, 0),
                            linewidth=7.0, marker=marker_by_label.get(lab, _MARKERS[len(ordered_h) % len(_MARKERS)]),
                            markersize=12, markeredgecolor='white', markeredgewidth=1.5, alpha=0.8)
        ordered_h.append(handle)
        ordered_l.append(lab)
    fig.legend(ordered_h, ordered_l, loc='lower center', bbox_to_anchor=(0.5, 0.02),
               ncol=min(5, len(ordered_l)), frameon=False)


def _save_pdf_copy(fig, png_path: str) -> None:
    """Best-effort: save ``fig`` as a PDF next to ``png_path``; warn (don't raise) on failure."""
    out_pdf = os.path.splitext(png_path)[0] + ".pdf"
    try:
        fig.savefig(out_pdf, bbox_inches="tight", format="pdf")
    except Exception as err:  # deliberate best-effort: the PNG is the primary output
        warnings.warn(f"Could not save PDF copy {out_pdf}: {err}")


def plot_flex_vs_baselines(flex_json: str, baseline_json: str, *, grid: Tuple[int, int], figsize: Tuple[float, float], out_dir: str,
                           yscale: str = "log") -> List[str]:
    """Plot forward+backward step time of "Ours" against baseline methods.

    The first method in ``flex_json`` is labelled "Ours". Baselines from
    ``baseline_json`` are restricted to the ``(Nc, Nt, B)`` configurations
    that have a non-negative timing in the flex data.

    One figure is written per batch size ``B``. Each subplot shows one ``Nt``
    (titled ``M``) with ``Nc`` (``N``) on a log2 x-axis. At most
    ``grid[0] * grid[1]`` distinct ``Nt`` values (smallest first) are shown.

    Args:
        flex_json: Path to the flex forward/backward timing JSON.
        baseline_json: Path to the baseline forward/backward timing JSON.
        grid: Subplot layout ``(nrows, ncols)``.
        figsize: Figure size ``(width, height)`` in inches.
        out_dir: Output directory (created if missing). A PDF copy is also
            written when the path contains ``"paper_plots"``.
        yscale: ``"log"`` (base 10) or ``"log2"``; any other value gives a
            linear y-axis.

    Returns:
        Paths of the written PNG files, one per batch size.

    Raises:
        ValueError: if the flex JSON has no usable methods/timings, no
            baseline method is found, or no baseline data overlaps the flex
            grid.
    """
    flex_obj = _load_json(flex_json)
    base_obj = _load_json(baseline_json)

    flex_methods = flex_obj.get("methods", {})
    if not flex_methods:
        raise ValueError("No methods in flex_forward_backward JSON")
    # Only the first flex method is used; it is always labelled "Ours".
    _, flex_d = next(iter(flex_methods.items()))
    flex_df = _method_frame(flex_d)
    flex_df = flex_df[flex_df['mean_time'] >= 0]
    triples = _to_triples(flex_df['Nc'].tolist(), flex_df['Nt'].tolist(), flex_df['B'].tolist(), flex_df['mean_time'].tolist())
    if not triples:
        raise ValueError("No valid (Nc,Nt,B) triples in flex JSON")

    base_methods = base_obj.get("methods", {})
    keep_labels = _select_baseline_labels(base_methods)

    series: List[Tuple[str, pd.DataFrame]] = [("Ours", flex_df.copy())]
    for lbl in keep_labels:
        df = _method_frame(base_methods[lbl])
        # Keep only configurations also measured for Ours, with non-negative times.
        mask = np.asarray(
            [((int(r.Nc), int(r.Nt), int(r.B)) in triples) and (float(r.mean_time) >= 0) for r in df.itertuples(index=False)],
            dtype=bool,
        )
        df_f = df[mask]
        if not df_f.empty:
            series.append((lbl, df_f))
    # Stable sort: labels with the same rank keep their JSON order.
    series.sort(key=lambda item: _order_index(item[0]))

    if len(series) <= 1:
        raise ValueError("No overlapping baseline data found for the flex grid")

    palette = _prep_theme()
    all_B = sorted({int(b) for b in flex_df['B'].unique()})
    all_Nt = sorted({int(t) for t in flex_df['Nt'].unique()})
    x_ticks = sorted({int(n) for n in flex_df['Nc'].unique()})
    nrows, ncols = grid
    sel_nts = all_Nt[: nrows * ncols]
    # Marker per canonical label, identical to the one its series is drawn with,
    # so legend proxies match the plotted lines.
    marker_by_label = {_canon_label(lbl): _MARKERS[idx % len(_MARKERS)] for idx, (lbl, _) in enumerate(series)}

    os.makedirs(out_dir, exist_ok=True)
    out_paths: List[str] = []
    for Bv in all_B:
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True, squeeze=False)
        try:
            axes = axes.ravel().tolist()
            fig.suptitle(f"Training step time (forward + backward, $B={Bv}$)", fontsize=26.0, y=0.97)

            for ax_idx, Nt in enumerate(sel_nts):
                ax = axes[ax_idx]
                for idx, (lbl, df) in enumerate(series):
                    data = df[(df['B'].astype(int) == Bv) & (df['Nt'].astype(int) == Nt) & (df['mean_time'] >= 0)]
                    if data.empty:
                        continue
                    data = data.sort_values('Nc')
                    yerr = data['std_time'].astype(float).values if 'std_time' in data else None
                    ax.errorbar(
                        data['Nc'].astype(int).values, data['mean_time'].astype(float).values, yerr=yerr,
                        marker=_MARKERS[idx % len(_MARKERS)], markersize=12, linewidth=7.0, capsize=6,
                        label=_canon_label(lbl), elinewidth=4.0, markeredgecolor='white', markeredgewidth=1.5,
                        alpha=0.8, color=_color_for_label(lbl, palette, idx),
                        linestyle=_style_for_label(lbl, idx),
                    )
                _style_axis(ax, ax_idx=ax_idx, Nt=Nt, ncols=ncols, yscale=yscale, x_ticks=x_ticks)

            # Hide grid cells without an Nt value.
            for ax in axes[len(sel_nts):]:
                ax.axis('off')

            _add_legend(fig, axes, palette, marker_by_label)
            fig.tight_layout(rect=[0, 0.12, 1, 0.98])

            out_path = os.path.join(out_dir, f"flex_vs_baselines_B_{Bv}.png")
            fig.savefig(out_path, bbox_inches="tight")
            # Also save a vector PDF when targeting the paper_plots directory.
            if "paper_plots" in os.path.normpath(out_dir):
                _save_pdf_copy(fig, out_path)
            out_paths.append(out_path)
        finally:
            # Always release the figure, even if plotting/saving failed.
            plt.close(fig)

    return out_paths


def main():
    """Command-line entry point: parse arguments, render the plots, print the saved paths.

    ``--grid`` is ``ROWSxCOLS`` (positive integers) and ``--figsize`` is
    ``WxH`` in inches (positive numbers). Malformed values are reported via
    ``argparse`` (usage message, exit status 2) instead of a raw traceback.
    """
    ap = argparse.ArgumentParser(description="Compare Ours (flex fwd+bwd) vs baselines (fwd+bwd) on overlapping grid")
    ap.add_argument("flex_json", type=str, help="Path to flex_forward_backward JSON")
    ap.add_argument("baseline_json", type=str, help="Path to forward_backward_times JSON")
    ap.add_argument("--grid", type=str, default="1x4", help="Subplot grid, e.g., 1x4")
    ap.add_argument("--figsize", type=str, default="15x11", help="Figure size WxH inches")
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"])
    ap.add_argument("--out_dir", type=str, default="outputs/fast_times/new_plots")
    args = ap.parse_args()

    # ap.error() raises SystemExit, so the names below are always bound on success.
    try:
        nrows, ncols = (int(x) for x in args.grid.lower().split("x"))
    except ValueError:
        ap.error(f"--grid must look like ROWSxCOLS (e.g. 1x4), got {args.grid!r}")
    if nrows < 1 or ncols < 1:
        ap.error(f"--grid dimensions must be positive, got {args.grid!r}")

    try:
        w, h = (float(x) for x in args.figsize.lower().split("x"))
    except ValueError:
        ap.error(f"--figsize must look like WxH (e.g. 15x11), got {args.figsize!r}")
    # Written as ``not (> 0)`` so NaN sizes are rejected too.
    if not (w > 0 and h > 0):
        ap.error(f"--figsize dimensions must be positive, got {args.figsize!r}")

    paths = plot_flex_vs_baselines(args.flex_json, args.baseline_json, grid=(nrows, ncols), figsize=(w, h),
                                   out_dir=args.out_dir, yscale=args.yscale)
    print("Saved:")
    for p in paths:
        print(" -", p)


if __name__ == "__main__":
    main()
