#!/usr/bin/env python3
from __future__ import annotations

"""2x3 grid of sampling-time plots for a single batch size B (num_samples;
256 by default, selectable with --B), showing only TNPD-AR and Ours. The same
plot is repeated in all 6 panels (two rows, three columns) for a
publication-ready layout.

Usage example:
  uv run python scripts/fast_times/plot_grid_b256_two_methods.py \
    --json outputs/fast_times/results/baseline_sampling.json \
           outputs/fast_times/results/compiled_sampling.json \
           outputs/fast_times/results/triton_sampling.json \
    --out outputs/fast_times/plots/b256_two_methods_grid.png

This script searches the provided JSONs for the TNPD-AR and Ours method labels
(if the same label appears in several files, the last file wins) and uses the
plot_all_methods() helper defined in this module for consistent styling. When
the --out path contains "paper_plots", a PDF copy is saved as well.
"""

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns

# Paired palette (12 colours, RGB 0-255). Entries alternate light/dark shades of
# the same hue; main() relies on index 3 being dark green and 5 dark red.
_PAIRED_PALETTE_RGB = [
    (166, 206, 227), (31, 120, 180), (178, 223, 138), (51, 160, 44), (251, 154, 153),
    (227, 26, 28), (253, 191, 111), (255, 127, 0), (202, 178, 214), (106, 61, 154),
    (255, 255, 153), (177, 89, 40),
]

# Preferred serif fonts, most preferred first (handed to matplotlib's font.serif).
_SERIF_STACK = [
    "TeX Gyre Termes", "STIX Two Text", "Times New Roman", "Times",
    "Nimbus Roman", "Liberation Serif", "DejaVu Serif",
]

# Marker per method position (cycled when there are more methods than markers).
_MARKERS = ['o', 's', '^', 'D', 'v', 'p']

# Line style keyed by label prefix so a method keeps the same dash pattern
# across naming schemes. The first prefix (in dict order) the label starts
# with wins.
_LINESTYLE_BY_PREFIX = {
    'LL M1': '-', 'M1 ': '-', 'TNPD-Independent': '-', 'TNPD-Ind': '-',
    'LL M2': '--', 'M2 ': '--', 'TNPD-AR': '--',
    'LL M3': '-.', 'M3 ': '-.', 'TR M3': '-.', 'Ours': '-.',
    'LL M4': ':', 'M4 ': ':', 'TNPA': ':', 'TNP-A': ':',
    'LL M5': (0, (5, 2)), 'M5 ': (0, (5, 2)), 'TNP-ND': (0, (5, 2)),
}

# Used, cycled by method position, when no prefix above matches.
_FALLBACK_LINESTYLES = ['-', '--', '-.', ':', (0, (5, 2))]

# Fixed N_c tick positions (powers of two) labelled on every panel's x axis.
_NC_XTICKS = [32, 64, 128, 256, 512, 1024]


def _resolve_colors(n_methods: int, color_indices) -> list:
    """Return the RGB (0-1 floats) colours for ``n_methods`` series.

    ``color_indices`` selects palette entries explicitly. If it is empty/None,
    or contains an index outside the palette, the first ``n_methods`` palette
    entries are used instead. Callers cycle through the result, so it may be
    shorter than ``n_methods``.
    """
    palette = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in _PAIRED_PALETTE_RGB]
    if color_indices:
        try:
            return [palette[i] for i in color_indices]
        except IndexError:
            # Deliberate fallback: an invalid index reverts to the default order.
            pass
    return palette[:n_methods]


def _apply_paper_theme(preferred_font: str | None) -> None:
    """Set the global seaborn/matplotlib theme (white, serif, large fonts)."""
    serif_stack = list(_SERIF_STACK)
    if preferred_font:
        # Put the requested font first without listing it twice.
        serif_stack = [preferred_font] + [
            f for f in serif_stack if f.lower() != preferred_font.lower()
        ]
    sns.set_theme(
        style="white",
        font="serif", font_scale=2.2,
        rc={
            "axes.facecolor": "#ffffff",
            "figure.facecolor": "#ffffff",
            "grid.linestyle": "",
            "axes.grid": False,
            "font.family": "serif",
            "font.serif": serif_stack,
            "axes.titleweight": "bold",
        },
    )


def _pick_linestyle(label: str, idx: int):
    """Return the line style for ``label`` (prefix table, else cycled fallback)."""
    for prefix, style in _LINESTYLE_BY_PREFIX.items():
        if label.startswith(prefix):
            return style
    return _FALLBACK_LINESTYLES[idx % len(_FALLBACK_LINESTYLES)]


def _style_timing_axis(ax, num_samples, *, show_ylabel: bool, title_pad: int) -> None:
    """Apply log-log scales, ticks, labels, title and spines to one panel."""
    ax.set_yscale('log', base=10)
    ax.set_xscale('log', base=2)
    # Major y ticks at each power of ten from 10**floor(log10(ymin)) up to, but
    # excluding, 10**ceil(log10(ymax)). set_yticks may widen the view limits
    # so that the lowest decade tick is visible.
    ymin, ymax = ax.get_ylim()
    log_min = np.floor(np.log10(ymin))
    log_max = np.ceil(np.log10(ymax))
    ax.set_yticks([10 ** i for i in np.arange(log_min, log_max)])
    ax.yaxis.set_major_formatter(mticker.LogFormatterSciNotation(base=10))
    ax.minorticks_on()
    ax.set_xlabel('$N_c$')
    if show_ylabel:
        ax.set_ylabel('Time (s)')
    ax.set_title(f'$B = {num_samples}$', fontweight='bold', pad=title_pad)
    ax.set_xticks(_NC_XTICKS)
    ax.set_xticklabels(_NC_XTICKS)
    for spine in ('left', 'bottom'):
        ax.spines[spine].set_linewidth(2)
        ax.spines[spine].set_color('black')
    for spine in ('right', 'top'):
        ax.spines[spine].set_visible(False)


def _finalize_figure_layout(fig, first_ax, *, legend: str, suptitle_drawn: bool,
                            suptitle_y: float) -> None:
    """Add the optional bottom legend and run tight_layout with fitting margins."""
    handles, labels = first_ax.get_legend_handles_labels()
    if legend == 'bottom' and labels:
        legend_ncol = min(len(labels), 5)
        fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.02),
                   ncol=legend_ncol, frameon=False)
        # Reserve roughly 0.07 of the figure height per legend row.
        num_legend_rows = np.ceil(len(labels) / legend_ncol)
        bottom_margin = 0.05 + (num_legend_rows * 0.07)
    else:
        bottom_margin = 0.08
    if suptitle_drawn:
        # Keep the panels just below the suptitle, within [0.80, 0.98].
        top_margin = max(0.80, min(0.98, suptitle_y - 0.005))
    else:
        top_margin = 1.0
    fig.tight_layout(rect=[0, bottom_margin, 1, top_margin])


def plot_all_methods(
    method_data_list,
    method_labels=None,
    Nt: int = 16,
    figsize=(15, 11),
    color_indices=None,
    grid_shape=(2, 2),
    alpha: float = 0.8,
    title_pad: int = 18,
    suptitle: str | None = None,
    suptitle_size: float = 26.0,
    suptitle_y: float = 0.97,
    preferred_font: str | None = None,
    fig=None,
    axes=None,
    num_samples_values: list[int] | None = None,
    legend: str = "bottom",
    finalize_layout: bool = True,
):
    """Plot mean sampling time vs. N_c (log-log, with std error bars).

    One panel is drawn per selected ``num_samples`` value, with one line per method.

    Side effect: calls ``sns.set_theme`` on every invocation, which updates
    the global matplotlib rcParams.

    Args:
        method_data_list: One entry per method. Each entry is a DataFrame, or
            anything ``pd.DataFrame`` accepts, with columns ``Nc``,
            ``num_samples``, ``mean_time`` and ``std_time``. Rows with a
            negative ``mean_time`` are skipped; presumably they mark failed
            runs (verify against the producer).
        method_labels: Legend labels, parallel to ``method_data_list``.
            Defaults to ``"Method 1"``, ``"Method 2"``, and so on. The labels
            also choose the line style by prefix.
        Nt: Unused. Accepted for backward compatibility with existing callers.
        figsize: Figure size, used only when ``axes`` is None.
        color_indices: Indices into the 12-colour Paired palette, one per
            method. An out-of-range index falls back to the default order.
            Colours are cycled when there are more methods than colours.
        grid_shape: ``(nrows, ncols)``. Caps the number of panels, and panels in
            the first column receive the y-axis label.
        alpha: Line/marker opacity.
        title_pad: Padding between each panel and its ``B = ...`` title.
        suptitle, suptitle_size, suptitle_y: Figure title. It is drawn only when
            this call creates the figure.
        preferred_font: Serif font to try before the built-in stack.
        fig: Figure that owns ``axes``. If omitted while ``axes`` is given, the
            first axis' figure is used.
        axes: Axes to draw into (any nesting, flattened row-major). A new
            ``nrows x ncols`` grid with shared y is created when None.
        num_samples_values: ``num_samples`` values to plot, in panel order.
            Defaults to the sorted unique values found across all methods.
            In either case the list is truncated to the number of panels.
        legend: ``"bottom"`` draws a shared legend under the panels when
            layout is finalized. Any other value draws no legend.
        finalize_layout: Whether to add the legend and call ``tight_layout``.

    Returns:
        The figure drawn into. It is None only if no figure was passed and
        ``axes`` was an empty sequence.
    """
    colors = _resolve_colors(len(method_data_list), color_indices)
    _apply_paper_theme(preferred_font)

    dfs = [d if isinstance(d, pd.DataFrame) else pd.DataFrame(d) for d in method_data_list]
    nrows, ncols = grid_shape
    n_plots = nrows * ncols
    if num_samples_values is None:
        all_num = set()
        for df in dfs:
            all_num.update(df['num_samples'].unique())
        sel_num = sorted(all_num)[:n_plots]
    else:
        sel_num = list(num_samples_values)[:n_plots]
    if method_labels is None:
        method_labels = [f'Method {i+1}' for i in range(len(dfs))]

    created_fig = axes is None
    if created_fig:
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True, squeeze=False)
    suptitle_drawn = bool(suptitle) and created_fig
    if suptitle_drawn:
        fig.suptitle(suptitle, fontsize=suptitle_size, fontweight="bold", y=suptitle_y)
    axes = np.array(axes).flatten().tolist()
    if fig is None and axes:
        # Caller supplied axes but no figure: use the one the axes live in.
        fig = axes[0].figure

    # Never index past the axes that actually exist.
    sel_num = sel_num[:len(axes)]
    for idx, num_samples in enumerate(sel_num):
        ax = axes[idx]
        for method_idx, (df, label) in enumerate(zip(dfs, method_labels)):
            data = df[(df['num_samples'] == num_samples) & (df['mean_time'] >= 0)].sort_values('Nc')
            if data.empty:
                continue
            ax.errorbar(
                data['Nc'], data['mean_time'], yerr=data['std_time'],
                marker=_MARKERS[method_idx % len(_MARKERS)], markersize=12, linewidth=8.0, capsize=6,
                # Cycle colours so a short color_indices list cannot raise IndexError.
                label=label, color=colors[method_idx % len(colors)], elinewidth=4.0,
                markeredgecolor='white', markeredgewidth=1.5, alpha=alpha,
                linestyle=_pick_linestyle(label, method_idx),
            )
        _style_timing_axis(ax, num_samples, show_ylabel=(idx % ncols == 0), title_pad=title_pad)

    # Hide panels that received no data.
    for ax in axes[len(sel_num):n_plots]:
        ax.axis('off')

    if finalize_layout and fig is not None and axes:
        _finalize_figure_layout(
            fig, axes[0], legend=legend, suptitle_drawn=suptitle_drawn, suptitle_y=suptitle_y,
        )
    return fig


# Method labels this script recognizes, grouped by method and listed in the
# order they are preferred. Not referenced elsewhere in this script;
# _pick_two() keeps its own copy of these lists, so keep the two in sync.
# TNPD-AR labels seen in baseline/compiled files.
_TNPD_AR_LABELS = ["TNPD-AR", "M2 TNP-D AR"]
# Labels for our method in compiled/triton sampling files.
_OURS_LABELS = ["Ours", "M3 Ours AR Buffer", "TR M3 Ours AR Buffer"]
WANTED_LABELS = _TNPD_AR_LABELS + _OURS_LABELS


def _load_methods_from_json(paths: List[str]) -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    for p in paths:
        with open(p, "r") as f:
            obj = json.load(f)
        for label, d in obj.get("methods", {}).items():
            out[label] = {
                "Nc": d.get("Nc", []),
                "num_samples": d.get("num_samples", []),
                "mean_time": d.get("mean_time", []),
                "std_time": d.get("std_time", []),
            }
    return out


def _pick_two(df_map: Dict[str, Dict[str, Any]]) -> Tuple[Dict[str, Any], Dict[str, Any], str, str]:
    # Find TNPD-AR
    ar_key = None
    for k in ["TNPD-AR", "M2 TNP-D AR"]:
        if k in df_map:
            ar_key = k
            break
    # Find Ours
    ours_key = None
    for k in ["Ours", "M3 Ours AR Buffer", "TR M3 Ours AR Buffer"]:
        if k in df_map:
            ours_key = k
            break
    if ar_key is None or ours_key is None:
        missing = []
        if ar_key is None:
            missing.append("TNPD-AR")
        if ours_key is None:
            missing.append("Ours")
        raise KeyError(f"Missing required methods: {missing}")
    return df_map[ar_key], df_map[ours_key], ar_key, ours_key


def main() -> None:
    """Command-line entry point: render the 2x3 TNPD-AR vs Ours grid for one B.

    The same single-B plot is drawn into all six panels. One shared legend is
    added at the bottom, and the figure is written to ``--out``. When the
    output path contains ``paper_plots``, a PDF copy is saved as well; this
    copy is best-effort, and a failure only prints a warning.
    """
    # Paper-ready vector outputs: embed TrueType (Type 42) fonts in PDF/PS.
    mpl.rcParams["pdf.fonttype"] = 42
    mpl.rcParams["ps.fonttype"] = 42
    ap = argparse.ArgumentParser(description="2x3 grid for B=256 showing TNPD-AR vs Ours")
    ap.add_argument("--json", nargs="+", required=True, help="One or more sampling JSON result files")
    ap.add_argument("--B", type=int, default=256, help="num_samples (B) to select")
    ap.add_argument("--out", type=str, required=True)
    args = ap.parse_args()

    df_map = _load_methods_from_json(args.json)
    df_ar, df_ours, lbl_ar, lbl_ours = _pick_two(df_map)

    # Filtering to B happens inside plot_all_methods via num_samples_values.
    # Method order here fixes legend order and colour assignment.
    methods = [df_ar, df_ours]
    labels = [lbl_ar, lbl_ours]

    # Create a 2x3 grid and paint each axis with the same plot.
    fig, axes = plt.subplots(2, 3, figsize=(14, 6), sharey=True, squeeze=False)
    axes_list = axes.flatten().tolist()

    # Paired palette indices: dark green=3 (TNPD-AR), dark red=5 (Ours).
    color_indices = [3, 5]
    for ax in axes_list:
        plot_all_methods(
            methods,
            method_labels=labels,
            Nt=args.B,
            figsize=(5, 4),
            color_indices=color_indices,
            grid_shape=(1, 1),
            fig=fig,
            axes=[ax],
            num_samples_values=[args.B],
            legend='none',
            finalize_layout=False,
        )

    # One combined legend at the bottom; every panel has identical handles.
    handles, leg_labels = axes_list[0].get_legend_handles_labels()
    if handles and leg_labels:
        fig.legend(handles, leg_labels, loc='lower center', bbox_to_anchor=(0.5, 0.03), ncol=2, frameon=False)

    fig.tight_layout(rect=[0, 0.08, 1, 1])
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(args.out, bbox_inches="tight")
    print(f"Saved: {args.out}")

    # Also save a PDF when targeting paper_plots. Skip it if the primary output
    # is already a PDF, since that would just rewrite the same file.
    if "paper_plots" in str(out_path) and out_path.suffix.lower() != ".pdf":
        try:
            out_pdf = out_path.with_suffix('.pdf')
            fig.savefig(out_pdf, bbox_inches='tight', format='pdf')
        except Exception as err:  # best-effort: the primary output is already saved
            print(f"Warning: could not save PDF copy of {args.out}: {err}", file=sys.stderr)
        else:
            print(f"Saved: {out_pdf}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
