#!/usr/bin/env python3
from __future__ import annotations

"""Composite 2x3 figure: the first two columns show smoothed sequence predictions
(left: green, titled 'Autoregressive'; middle: red, titled 'Ours'), and the third column
repeats the sampling-time plot (TNPD-AR vs Ours, B=256 by default) on both rows.

Example:
  uv run python scripts/fast_times/plot_composite_2x3.py \
    --json outputs/fast_times/results/baseline_sampling.json \
           outputs/fast_times/results/compiled_sampling.json \
           outputs/fast_times/results/triton_sampling.json \
    --smooth-top outputs/tabular_model_smoothed_ar_sequences.pt \
    --smooth-bottom outputs/tabular_model_smoothed_arbuffer_sequences.pt \
    --fn-index 0 \
    --out outputs/fast_times/plots/composite_2x3.png
"""

import argparse
import json
import os
import glob
from pathlib import Path
from typing import Dict, Any, List, Tuple

import numpy as np
import pandas as pd
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mticker


# ------------------------------- Style & Palette -------------------------------
# Serif fallback chain, in order of preference; passed to matplotlib as "font.serif".
_SERIF_STACK = [
    "TeX Gyre Termes",
    "STIX Two Text",
    "Times New Roman",
    "Times",
    "Nimbus Roman",
    "Liberation Serif",
    "DejaVu Serif",
]

# Twelve 0-255 RGB triples arranged as light/dark pairs
# (index 2/3 = light/dark green, index 4/5 = light/dark red).
_PAIRED_RGB = [
    (166, 206, 227),
    (31, 120, 180),
    (178, 223, 138),
    (51, 160, 44),
    (251, 154, 153),
    (227, 26, 28),
    (253, 191, 111),
    (255, 127, 0),
    (202, 178, 214),
    (106, 61, 154),
    (255, 255, 153),
    (177, 89, 40),
]

# Same colours rescaled to the 0-1 floats matplotlib expects.
_PALETTE = [tuple(channel / 255.0 for channel in rgb) for rgb in _PAIRED_RGB]


def _apply_theme() -> None:
    """Install the global matplotlib/seaborn styling used by every panel.

    Mutates process-wide rcParams: font type 42 for PDF/PS output, then a
    white, grid-free seaborn theme with a serif font stack and bold titles.
    """
    # Font type 42 keeps text as embedded TrueType in vector outputs.
    for backend_key in ("pdf.fonttype", "ps.fonttype"):
        mpl.rcParams[backend_key] = 42

    rc_overrides = {
        "axes.facecolor": "#ffffff",
        "figure.facecolor": "#ffffff",
        "grid.linestyle": "",
        "axes.grid": False,
        "font.family": "serif",
        "font.serif": _SERIF_STACK,
        "axes.titleweight": "bold",
    }
    sns.set_theme(style="white", font="serif", font_scale=2.4, rc=rc_overrides)


# ---------------------------- Load & Select Methods ----------------------------
def _load_methods_from_json(paths: List[str]) -> Dict[str, Dict[str, Any]]:
    """Merge the per-method timing records of several result files.

    Each file is expected to be a JSON object; its optional ``"methods"``
    mapping (label -> record) is read and only the four timing fields are
    kept, each defaulting to an empty list when absent.

    Args:
        paths: JSON result files, processed in the given order.

    Returns:
        Mapping from method label to a dict with keys ``Nc``,
        ``num_samples``, ``mean_time`` and ``std_time``. When the same label
        occurs in more than one file, the last file wins.
    """
    fields = ("Nc", "num_samples", "mean_time", "std_time")
    out: Dict[str, Dict[str, Any]] = {}
    for p in paths:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(p, "r", encoding="utf-8") as f:
            obj = json.load(f)
        for label, d in obj.get("methods", {}).items():
            out[label] = {name: d.get(name, []) for name in fields}
    return out


def _pick_two(df_map: Dict[str, Dict[str, Any]]) -> Tuple[Dict[str, Any], Dict[str, Any], str, str]:
    """Select the autoregressive-baseline record and the "Ours" record.

    Each method is looked up under its canonical label first and then under
    the accepted legacy variants, in order.

    Returns:
        ``(ar_record, ours_record, ar_label, ours_label)`` where the labels
        are the keys actually found in ``df_map``.

    Raises:
        KeyError: if either method has no matching label; the message lists
            the canonical names that are missing.
    """
    def first_present(candidates):
        # First candidate label that exists in df_map, or None.
        return next((name for name in candidates if name in df_map), None)

    ar_key = first_present(("TNPD-AR", "M2 TNP-D AR"))
    ours_key = first_present(("Ours", "M3 Ours AR Buffer", "TR M3 Ours AR Buffer"))

    missing = [
        canonical
        for canonical, found in (("TNPD-AR", ar_key), ("Ours", ours_key))
        if found is None
    ]
    if missing:
        raise KeyError(f"Missing required methods: {missing}")
    return df_map[ar_key], df_map[ours_key], ar_key, ours_key


def _resolve_pt(path_str: str) -> Path:
    """Resolve a CLI argument to a concrete ``.pt`` file.

    Args:
        path_str: Either a file path (returned as-is) or a directory, in
            which case the most recently modified ``*.pt`` file directly
            inside it is chosen (ties go to the alphabetically first name).

    Returns:
        Path of the selected file.

    Raises:
        FileNotFoundError: if the path does not exist, or it is a directory
            that contains no ``*.pt`` file.
    """
    p = Path(path_str)
    if p.is_file():
        return p
    if p.is_dir():
        # Escape the directory part so names containing glob metacharacters
        # (e.g. "run[1]") are matched literally instead of as a pattern.
        pattern = os.path.join(glob.escape(str(p)), '*.pt')
        cands = sorted(glob.glob(pattern))
        if not cands:
            raise FileNotFoundError(f"No .pt files found in {p}")
        # max() returns the first maximal element, so the alphabetical order
        # established above breaks mtime ties.
        return Path(max(cands, key=os.path.getmtime))
    raise FileNotFoundError(f"Path not found: {p}")


# ------------------------------ Render: Sampling -------------------------------
def _render_sampling_panel(ax: plt.Axes, df_ar: Dict[str, Any], df_ours: Dict[str, Any], B: int, *, yscale: str = "log") -> None:
    """Plot mean sampling time against N for both methods at one batch size.

    Args:
        ax: Axes to draw on.
        df_ar: Timing record of the autoregressive baseline (keys ``Nc``,
            ``num_samples``, ``mean_time``, ``std_time``).
        df_ours: Timing record of "Ours", same keys.
        B: Only entries whose ``num_samples`` equals this value are drawn.
        yscale: ``"log"`` (base 10), ``"log2"``, or anything else for linear.

    A method with no non-negative entry at batch size ``B`` is skipped.
    """
    # (record, legend label, colour, marker, line style) for each method:
    # baseline in green/dashed, ours in red/solid.
    series = (
        (df_ar, "TNP-D-AR", _PALETTE[3], 'o', '--'),
        (df_ours, "Ours", _PALETTE[5], 's', '-'),
    )
    for record, legend_label, colour, marker, line_style in series:
        n_ctx = np.array(record["Nc"]).astype(float)
        batch = np.array(record["num_samples"]).astype(int)
        t_mean = np.array(record["mean_time"]).astype(float)
        t_std = np.array(record["std_time"]).astype(float)
        # Keep rows for this batch size; negative means are treated as invalid.
        keep = (batch == B) & (t_mean >= 0)
        if keep.any():
            by_n = np.argsort(n_ctx[keep])
            ax.errorbar(
                n_ctx[keep][by_n], t_mean[keep][by_n], yerr=t_std[keep][by_n],
                marker=marker, markersize=10, linewidth=4.5, capsize=5,
                label=legend_label, color=colour, elinewidth=3.0,
                markeredgecolor='white', markeredgewidth=1.2, alpha=0.9,
                linestyle=line_style,
            )

    # Y axis: scale first, then the matching tick locator/formatter.
    y_axis = ax.yaxis
    if yscale == 'log2':
        ax.set_yscale('log', base=2)
        y_axis.set_major_locator(mticker.LogLocator(base=2.0, numticks=6))
        y_axis.set_major_formatter(mticker.LogFormatter(base=2.0))
    elif yscale == 'log':
        ax.set_yscale('log', base=10)
        y_axis.set_major_locator(mticker.LogLocator(base=10.0, numticks=5))
        y_axis.set_major_formatter(mticker.LogFormatterSciNotation(base=10.0))
    else:
        ax.set_yscale('linear')
        y_axis.set_major_locator(mticker.MaxNLocator(nbins=5))

    # X axis is log2 with fixed, plainly-labelled ticks.
    ax.set_xscale('log', base=2)
    ax.minorticks_on()
    ax.set_xlabel('$N$', fontsize=18, labelpad=8)
    ax.set_ylabel('Time', fontsize=18, labelpad=12)
    ax.tick_params(axis='both', which='both', labelsize=18)
    ax.set_title(f'$B = {B}$', fontweight='bold', pad=6, fontsize=18)
    x_ticks = [32, 64, 128, 256, 512, 1024]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticks)

    # Open frame: heavy black left/bottom spines, no right/top spines.
    for side in ('left', 'bottom'):
        edge = ax.spines[side]
        edge.set_linewidth(2)
        edge.set_color('black')
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)


# ------------------------------- Render: Smooth -------------------------------
def _load_smoothed_one(pth: Path, fn_index: int):
    """Load the arrays of one function from a saved smoothed-sequence file.

    The file is expected to hold a dict of tensors whose first axis indexes
    functions. ``xc``, ``yc``, ``xt``, ``yt`` and ``y_mean`` are required;
    ``ci_lower``, ``ci_upper`` and ``y_samples`` are optional.

    Args:
        pth: Path of the ``.pt`` file.
        fn_index: Function to extract; clamped to ``[0, n - 1]`` where ``n``
            is the first-axis length of ``xc``.

    Returns:
        ``(xc, yc, xt, yt, y_mean, ci_lower, ci_upper, y_samples)`` as NumPy
        arrays for the selected function. The last three are ``None`` when
        absent (``y_samples`` also when it is empty).

    Raises:
        KeyError: if any required entry is missing from the file.
    """
    # NOTE(security): depending on the installed PyTorch version, torch.load
    # may unpickle arbitrary objects -- only point this at trusted files.
    obj = torch.load(pth, map_location='cpu')

    required = ('xc', 'yc', 'xt', 'yt', 'y_mean')
    missing = [key for key in required if obj.get(key, None) is None]
    if missing:
        # Fail with a readable message instead of an AttributeError on None below.
        raise KeyError(f"{pth} is missing required entries: {missing}")

    def to_np(key):
        t = obj.get(key, None)
        if t is None:
            return None
        # detach() first: .numpy() refuses tensors saved with requires_grad=True.
        return t.detach().numpy()

    xc, yc, xt, yt, y_mean = (to_np(key) for key in required)
    lo = to_np('ci_lower')
    hi = to_np('ci_upper')
    y_samps = to_np('y_samples')

    n = xc.shape[0]
    i = max(0, min(fn_index, n - 1))
    lo_i = lo[i] if lo is not None else None
    hi_i = hi[i] if hi is not None else None
    samp_i = y_samps[i] if (y_samps is not None and y_samps.size > 0) else None
    return xc[i], yc[i], xt[i], yt[i], y_mean[i], lo_i, hi_i, samp_i


def _render_smoothed_panel(ax: plt.Axes, xc, yc, xt, yt, y_mean, ci_lower, ci_upper, y_samples,
                           mean_color, ci_fill_color, sample_color) -> None:
    """Draw one prediction panel: truth, interval band, samples, mean, context.

    Only column 0 of every array is plotted, and all target-side curves are
    drawn in increasing order of ``xt[:, 0]``.

    Args:
        ax: Axes to draw on.
        xc, yc: Context inputs/outputs, shown as black '+' markers.
        xt, yt: Target inputs and reference outputs (black dashed curve).
        y_mean: Predicted mean at the targets.
        ci_lower, ci_upper: Interval bounds; the band is drawn only when
            both are given.
        y_samples: Optional stack of sampled curves, iterated along its
            first axis; skipped when ``None`` or empty.
        mean_color: Colour of the mean curve and the dotted band edges.
        ci_fill_color: Fill colour of the band.
        sample_color: Colour of the individual sample curves.
    """
    idx = np.argsort(xt[:, 0])
    x_sorted = xt[idx, 0]

    # Reference curve
    ax.plot(x_sorted, yt[idx, 0], color='black', linewidth=3.0, alpha=1.0, linestyle='--')

    # Interval band with dotted edges
    if not (ci_lower is None or ci_upper is None):
        lower = ci_lower[idx, 0]
        upper = ci_upper[idx, 0]
        ax.fill_between(x_sorted, lower, upper, color=ci_fill_color, alpha=0.18, linewidth=0)
        for edge in (lower, upper):
            ax.plot(x_sorted, edge, color=mean_color, linewidth=2.6, linestyle=':')

    # Individual sampled curves, faint
    if y_samples is not None and y_samples.size > 0:
        for draw in y_samples:
            ax.plot(x_sorted, draw[idx, 0], color=sample_color, linewidth=1.4, alpha=0.35)

    # Predicted mean on top of the samples
    ax.plot(x_sorted, y_mean[idx, 0], color=mean_color, linewidth=3.0, alpha=0.95)

    # Context points above everything else
    ax.plot(xc[:, 0], yc[:, 0], linestyle='None', marker='+', markersize=16, markeredgewidth=3.5, color='black', zorder=6)

    # Strip labels, ticks and grid; keep only the left/bottom spines.
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.tick_params(axis='both', which='both', length=0)
    ax.grid(False)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)


# ------------------------------------- Main ------------------------------------
def main() -> None:
    """Parse CLI arguments, build the composite 2x3 figure, and save it.

    Layout: columns 1-2 hold four smoothed-prediction panels (green on the
    left, red in the middle); column 3 repeats the sampling-time panel on
    both rows. Prediction inputs come either from ``--quad`` (four files) or
    from the legacy pair ``--smooth-top`` / ``--smooth-bottom`` (each file is
    used for both columns of its row).

    The figure is written to ``--out``; when that path contains
    ``paper_plots`` a PDF copy is also attempted on a best-effort basis.
    """
    import sys  # local import: only used for the best-effort PDF warning

    ap = argparse.ArgumentParser(description="Composite 2x3: sampling-time (B=256) + smoothed sequences column")
    ap.add_argument("--json", nargs="+", required=True, help="One or more sampling JSON result files")
    ap.add_argument("--B", type=int, default=256)
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"], help="Y-axis scale for timing panel")
    g = ap.add_mutually_exclusive_group(required=True)
    g.add_argument("--quad", nargs=4, metavar=('F1','F2','F3','F4'),
                   help='Four smoothed .pt files/dirs for the 2x2 predictions occupying the first two columns: (1,3) left/green; (2,4) right/red.')
    g.add_argument("--smooth-top", help="(Deprecated) .pt file or directory (smoothed AR) for top-right column")
    ap.add_argument("--smooth-bottom", help="(Deprecated) .pt file or directory (smoothed Buffer) for bottom-right column")
    ap.add_argument("--fn-index", type=int, default=None, help="Global function index (fallback)")
    ap.add_argument("--fn-index-top", type=int, default=None, help="Function index for top row (overrides --fn-index)")
    ap.add_argument("--fn-index-bottom", type=int, default=None, help="Function index for bottom row (overrides --fn-index)")
    ap.add_argument("--fn-index-left", type=int, default=None, help="Function index for left column (overrides row/global)")
    ap.add_argument("--fn-index-right", type=int, default=None, help="Function index for right column (overrides row/global)")
    ap.add_argument("--out", type=str, required=True)
    args = ap.parse_args()

    # --smooth-top sits in the required exclusive group but --smooth-bottom
    # does not, so legacy mode must check it here rather than crash later
    # when the missing value is turned into a Path.
    if not args.quad and args.smooth_bottom is None:
        ap.error("--smooth-bottom is required when --smooth-top is used")

    _apply_theme()

    # Load sampling results and select methods
    df_map = _load_methods_from_json(args.json)
    df_ar, df_ours, _, _ = _pick_two(df_map)

    # Resolve function indices with precedence: per-column > per-row > global > 0
    fn_global = args.fn_index if args.fn_index is not None else 0
    fn_top = args.fn_index_top if args.fn_index_top is not None else fn_global
    fn_bot = args.fn_index_bottom if args.fn_index_bottom is not None else fn_global
    fn_left = args.fn_index_left
    fn_right = args.fn_index_right
    # Panel order everywhere below: top-left, top-middle, bottom-left, bottom-middle
    indices = [
        fn_left if fn_left is not None else fn_top,
        fn_right if fn_right is not None else fn_top,
        fn_left if fn_left is not None else fn_bot,
        fn_right if fn_right is not None else fn_bot,
    ]

    if args.quad:
        paths = [_resolve_pt(p) for p in args.quad]
    else:
        # Back-compat: one file per row, reused for both prediction columns
        p_top = _resolve_pt(args.smooth_top)
        p_bot = _resolve_pt(args.smooth_bottom)
        paths = [p_top, p_top, p_bot, p_bot]
    panels = [_load_smoothed_one(p, i) for p, i in zip(paths, indices)]

    # Figure: 2 rows x 3 cols (y is not shared globally; cols 1-2 are synced per row below)
    fig, axes = plt.subplots(2, 3, figsize=(18, 6), squeeze=False)
    fig.patch.set_facecolor('white')

    # Columns 1-2: prediction panels, green on the left and red in the middle
    green = dict(mean_color=_PALETTE[3], ci_fill_color=_PALETTE[2], sample_color=_PALETTE[3])
    red = dict(mean_color=_PALETTE[5], ci_fill_color=_PALETTE[4], sample_color=_PALETTE[5])
    targets = [(axes[0, 0], green), (axes[0, 1], red), (axes[1, 0], green), (axes[1, 1], red)]
    for (ax, style), panel in zip(targets, panels):
        _render_smoothed_panel(ax, *panel, **style)

    # Column 3: sampling-time panel repeated on both rows
    _render_sampling_panel(axes[0, 2], df_ar, df_ours, args.B, yscale=args.yscale)
    _render_sampling_panel(axes[1, 2], df_ar, df_ours, args.B, yscale=args.yscale)

    # Remove the B=... titles on the right column and the x-label under the top-right plot
    axes[0, 2].set_title("")
    axes[1, 2].set_title("")
    axes[0, 2].set_xlabel("")
    # First two columns share the same y-range within each row; only the leftmost shows y ticks/label
    for r in range(2):
        y0_lo, y0_hi = axes[r, 0].get_ylim()
        y1_lo, y1_hi = axes[r, 1].get_ylim()
        y_lo = min(y0_lo, y1_lo)
        y_hi = max(y0_hi, y1_hi)
        axes[r, 0].set_ylim(y_lo, y_hi)
        axes[r, 1].set_ylim(y_lo, y_hi)
        # Leftmost (predictions): show y label and ticks
        axes[r, 0].set_ylabel(r'$y$', fontsize=18, labelpad=10)
        axes[r, 0].tick_params(axis='y', which='both', left=True, labelleft=True, labelsize=18, length=0)
        axes[r, 0].yaxis.set_major_locator(mticker.MaxNLocator(nbins=4))
        # Middle predictions: hide y ticks/label
        axes[r, 1].set_ylabel("")
        axes[r, 1].tick_params(axis='y', which='both', left=False, labelleft=False)
        # Right (sampling): keep y ticks/label as provided by the sampling renderer
        axes[r, 2].tick_params(axis='y', which='both', left=True, labelleft=True, labelsize=18)

    # Top row: hide x ticks/labels across all columns
    for c in range(3):
        axes[0, c].set_xlabel("")
        axes[0, c].tick_params(axis='x', which='both', bottom=False, labelbottom=False)
    # Bottom row: predictions are labelled x; the sampling panel keeps its N label
    for c in (0, 1):
        axes[1, c].set_xlabel(r'$x$', fontsize=18, labelpad=8)
        axes[1, c].tick_params(axis='x', which='both', bottom=True, labelbottom=True, labelsize=18, length=0)
        axes[1, c].xaxis.set_major_locator(mticker.MaxNLocator(nbins=4))
    axes[1, 2].tick_params(axis='x', which='both', bottom=True, labelbottom=True, labelsize=18, length=0)
    # Third column: fewer major ticks, hide tick marks but keep labels
    for r in (0, 1):
        # Only thin out ticks with a base-10 locator when the axis really is
        # base-10 log; for 'log2'/'linear' keep the renderer's own locator.
        if args.yscale == "log":
            axes[r, 2].yaxis.set_major_locator(mticker.LogLocator(base=10.0, numticks=4))
        axes[r, 2].tick_params(axis='y', which='both', length=0)
    # Tighten spacing between first two columns (and overall)
    fig.subplots_adjust(wspace=0.03)

    # Column titles over first two columns
    axes[0, 0].set_title('Autoregressive', fontweight='bold', fontsize=18, pad=3)
    axes[0, 1].set_title('Ours', fontweight='bold', fontsize=18, pad=3)

    # Layout before placing row labels so positions are accurate
    fig.tight_layout(rect=[0, 0.08, 1, 1])

    # Row labels on the far left of the figure (fixed x in figure coords)
    bbox_top = axes[0, 0].get_position(fig)
    bbox_bot = axes[1, 0].get_position(fig)
    y_top = 0.5 * (bbox_top.y0 + bbox_top.y1)
    y_bot = 0.5 * (bbox_bot.y0 + bbox_bot.y1)
    x_lab = 0.012
    fig.text(x_lab, y_top, 'GP prior', rotation=90,
             va='center', ha='center', fontsize=18, fontweight='bold')
    fig.text(x_lab, y_bot, 'SCM prior', rotation=90,
             va='center', ha='center', fontsize=18, fontweight='bold')

    # Bottom legend taken from the first sampling axis (column 3): the
    # prediction panels carry no labelled artists, so reading them would
    # always yield an empty legend.
    handles, leg_labels = axes[0, 2].get_legend_handles_labels()
    if handles and leg_labels:
        fig.legend(handles, leg_labels, loc='lower center', bbox_to_anchor=(0.5, 0.03), ncol=2, frameon=False)

    fig.tight_layout(rect=[0, 0.08, 1, 1])
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(args.out, bbox_inches='tight')
    # Also save a PDF copy when targeting paper_plots. This stays best-effort
    # (the primary output is already on disk) but failures are reported.
    try:
        if "paper_plots" in os.path.normpath(args.out):
            out_pdf = str(Path(args.out).with_suffix('.pdf'))
            fig.savefig(out_pdf, bbox_inches='tight', format='pdf')
    except Exception as exc:
        print(f"Warning: could not save PDF copy: {exc}", file=sys.stderr)
    print(f"Saved: {args.out}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
