from __future__ import annotations

import json
import argparse
from typing import List, Dict, Any, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker


def _resolve_colors(full_palette, color_indices, n_methods):
    """Pick the colors to use for the methods, in method order.

    Args:
        full_palette: Sequence of colors to choose from.
        color_indices: Optional explicit palette indices, one per method.
        n_methods: Number of methods that will be plotted.

    Returns:
        A list of colors. It may be shorter than ``n_methods`` (more methods
        than palette entries, or fewer indices than methods), so callers must
        index it cyclically.
    """
    if color_indices:
        try:
            return [full_palette[i] for i in color_indices]
        except IndexError:
            print(
                f"Error: A color index is out of range for the palette "
                f"(0-{len(full_palette) - 1}). Using default colors."
            )
    return list(full_palette[:n_methods])


def _configure_y_axis(ax, yscale: str) -> None:
    """Set the y scale of *ax* and a matching major-tick locator/formatter.

    ``"log"`` gives a base-10 log axis with scientific-notation power-of-ten
    labels. ``"log2"`` gives a base-2 log axis. Any other value gives a linear
    axis with a ``MaxNLocator(nbins=5)``.
    """
    if yscale == "log":
        ax.set_yscale('log', base=10)
        ax.yaxis.set_major_locator(mticker.LogLocator(base=10.0, numticks=5))
        ax.yaxis.set_major_formatter(mticker.LogFormatterSciNotation(base=10.0))
    elif yscale == "log2":
        ax.set_yscale('log', base=2)
        ax.yaxis.set_major_locator(mticker.LogLocator(base=2.0, numticks=6))
        ax.yaxis.set_major_formatter(mticker.LogFormatter(base=2.0))
    else:
        ax.set_yscale('linear')
        # Let Matplotlib choose sensible linear ticks.
        ax.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5))


def plot_all_methods(
    method_data_list,
    method_labels=None,
    Nt=16,
    figsize=(15, 11),
    color_indices=None,
    grid_shape=(2, 2),
    alpha: float = 0.8,
    method_alphas: List[float] | None = None,
    title_pad: int = 18,
    suptitle: str | None = None,
    suptitle_size: float = 26.0,
    suptitle_y: float = 0.97,
    preferred_font: str | None = None,
    fig=None,
    axes=None,
    num_samples_values: list[int] | None = None,
    legend: str = "bottom",
    finalize_layout: bool = True,
    yscale: str = "log",
):
    """Plot runtime vs. ``Nc`` for several methods, one panel per ``num_samples``.

    Each panel corresponds to one ``num_samples`` value (titled ``B = ...``).
    It draws ``mean_time`` +/- ``std_time`` against ``Nc`` on a log2 x-axis
    for every method. Rows with a negative ``mean_time`` are skipped; they are
    treated as sentinels, e.g. for out-of-memory runs.

    Note: this calls ``sns.set_theme``, which changes Matplotlib's global
    rcParams (white background, no grid, serif font stack) for the process.

    Args:
        method_data_list: One entry per method. Each is a DataFrame, or
            anything ``pd.DataFrame`` accepts, with columns ``Nc``,
            ``num_samples``, ``mean_time`` and ``std_time``.
        method_labels: Legend labels, one per method (default ``"Method i"``).
        Nt: Unused; kept for call-site compatibility.
        figsize: Figure size in inches, used only when a new figure is created.
        color_indices: Indices into the 10-color seaborn "colorblind" palette.
            Colors are reused cyclically when there are more methods than
            colors.
        grid_shape: ``(nrows, ncols)`` of the panel grid.
        alpha: Default opacity for every method.
        method_alphas: Optional per-method opacity overrides; ``None`` entries,
            or a list shorter than the number of methods, fall back to
            *alpha*.
        title_pad: Padding between each panel and its title.
        suptitle, suptitle_size, suptitle_y: Figure-level title settings. The
            title is only drawn when this function creates the figure.
        preferred_font: Font family tried before the built-in serif stack.
        fig, axes: Existing figure/axes to draw into. If *axes* is given
            without *fig*, the figure is taken from the first axis.
        num_samples_values: Explicit ``num_samples`` values to plot. The
            default is the smallest values found in the data. At most one
            value is used per available panel.
        legend: ``"bottom"`` draws one shared legend below the panels. Any
            other value draws no legend.
        finalize_layout: If true, add the legend and call ``tight_layout``.
        yscale: ``"log"`` (base 10), ``"log2"``, or anything else for linear.

    Returns:
        The Matplotlib figure that was drawn into.
    """
    # Colorblind-friendly seaborn palette, for consistency with training plots.
    full_palette = sns.color_palette("colorblind", n_colors=10)
    colors = _resolve_colors(full_palette, color_indices, len(method_data_list))

    # Prefer serif fonts with subtle ligatures, typical for ICML-style figures.
    _serif_stack = [
        "TeX Gyre Termes",      # Times-like with ligatures
        "STIX Two Text",
        "Times New Roman",
        "Times",
        "Nimbus Roman",
        "Liberation Serif",
        "DejaVu Serif",
    ]
    if preferred_font:
        _serif_stack = [preferred_font] + [f for f in _serif_stack if f.lower() != preferred_font.lower()]

    # Global theme: white background, no grids, serif font stack.
    sns.set_theme(
        style="white",
        font="serif", font_scale=2.2,
        rc={
            "axes.facecolor": "#ffffff",
            "figure.facecolor": "#ffffff",
            "grid.linestyle": "",
            "axes.grid": False,
            "font.family": "serif",
            "font.serif": _serif_stack,
            "axes.titleweight": "bold",
        },
    )

    markers = ['o', 's', '^', 'D', 'v', 'p']
    # Stable line styles keyed by label prefix; unknown labels cycle through fallbacks.
    style_by_prefix = {
        'LL M1': '-', 'M1 ': '-', 'TNPD-Independent': '-', 'TNPD-Ind': '-', 'TNP-D-Ind': '-',
        'LL M2': '--', 'M2 ': '--', 'TNPD-AR': '--', 'TNP-D-AR': '--',
        'LL M3': '-.', 'M3 ': '-.', 'TR M3': '-.', 'Ours': '-.', 'ACE': '-.',
        'LL M4': ':',  'M4 ': ':',  'TNPA': ':',   'TNP-A': ':',
        'LL M5': (0, (5, 2)), 'M5 ': (0, (5, 2)), 'TNP-ND': (0, (5, 2)),
    }
    fallback_styles = ['-', '--', '-.', ':', (0, (5, 2))]

    def pick_style(label: str, idx: int):
        """Return the line style for the first matching label prefix, else a fallback."""
        for k, v in style_by_prefix.items():
            if label.startswith(k):
                return v
        return fallback_styles[idx % len(fallback_styles)]

    dfs = [d if isinstance(d, pd.DataFrame) else pd.DataFrame(d) for d in method_data_list]

    nrows, ncols = grid_shape
    n_plots = nrows * ncols

    if num_samples_values is None:
        all_num_samples = set()
        for df in dfs:
            all_num_samples.update(df['num_samples'].unique())
        sel_num = sorted(all_num_samples)[:n_plots]
    else:
        sel_num = list(num_samples_values)[:n_plots]

    if method_labels is None:
        method_labels = [f'Method {i + 1}' for i in range(len(dfs))]

    created_fig = False
    if axes is None:
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True, squeeze=False)
        created_fig = True
    axes = np.array(axes).flatten().tolist()
    if fig is None:
        # Caller supplied axes but no figure; the legend and layout need the parent figure.
        fig = axes[0].figure
    # Big (figure-level) title only when we own the figure.
    if suptitle and created_fig:
        fig.suptitle(suptitle, fontsize=suptitle_size, fontweight="bold", y=suptitle_y)

    # Caller-supplied axes may be fewer than the grid cells; never index past them.
    sel_num = sel_num[:len(axes)]

    # method index -> first errorbar container drawn for it (from any panel).
    legend_handles = {}
    for idx, num_samples in enumerate(sel_num):
        ax = axes[idx]

        for method_idx, (df, label) in enumerate(zip(dfs, method_labels)):
            # Drop negative times (e.g., sentinel for OOM).
            data = df[(df['num_samples'] == num_samples) & (df['mean_time'] >= 0)].sort_values('Nc')
            if data.empty:
                continue

            method_alpha = alpha
            if (method_alphas is not None and method_idx < len(method_alphas)
                    and method_alphas[method_idx] is not None):
                method_alpha = float(method_alphas[method_idx])
            container = ax.errorbar(
                data['Nc'], data['mean_time'], yerr=data['std_time'],
                marker=markers[method_idx % len(markers)],
                markersize=12, linewidth=7.0, capsize=6,
                label=label,
                # Cycle colors so more methods than colors never raises IndexError.
                color=colors[method_idx % len(colors)],
                elinewidth=4.0, markeredgecolor='white', markeredgewidth=1.5,
                alpha=method_alpha, linestyle=pick_style(label, method_idx),
            )
            legend_handles.setdefault(method_idx, container)

        _configure_y_axis(ax, yscale)
        ax.set_xscale('log', base=2)
        ax.minorticks_on()

        ax.set_xlabel('$N_c$')
        if idx % ncols == 0:
            ax.set_ylabel('Time (s)')

        ax.set_title(f'$B = {num_samples}$', fontweight='bold', pad=title_pad)

        x_ticks = [32, 64, 128, 256, 512, 1024]
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_ticks)

        # Clean spines: keep left/bottom, hide right/top (classic publication style).
        for spine in ['left', 'bottom']:
            ax.spines[spine].set_linewidth(2)
            ax.spines[spine].set_color('black')
        for spine in ['right', 'top']:
            ax.spines[spine].set_visible(False)

    # Hide any unused subplots (slicing tolerates caller-supplied axes shorter than the grid).
    for unused_ax in axes[len(sel_num):n_plots]:
        unused_ax.axis('off')

    # --- DYNAMIC LEGEND AND SPACING ---
    if finalize_layout:
        # Collect entries from every panel, ordered by method, so a method that
        # is absent from the first panel still appears. Labels starting with
        # "_" are skipped, as Matplotlib does for legend entries.
        ordered = sorted(legend_handles)
        handles = [legend_handles[i] for i in ordered if not str(method_labels[i]).startswith('_')]
        labels = [method_labels[i] for i in ordered if not str(method_labels[i]).startswith('_')]

        if legend == 'bottom' and labels:
            # Wrap the legend onto several rows when there are more than 5 items.
            num_labels = len(labels)
            legend_ncol = min(num_labels, 5)

            fig.legend(
                handles,
                labels,
                loc='lower center',
                bbox_to_anchor=(0.5, 0.02),
                ncol=legend_ncol,
                frameon=False,
            )

            # Reserve bottom space proportional to the number of legend rows.
            num_legend_rows = int(np.ceil(num_labels / legend_ncol))
            bottom_margin = 0.05 + (num_legend_rows * 0.07)
        else:
            bottom_margin = 0.08

        # Reserve top space based on suptitle_y when a suptitle was drawn, so titles don't overlap.
        if suptitle and created_fig:
            # Leave only ~0.005 headroom below the suptitle line so the title sits closer.
            top_margin = max(0.80, min(0.98, suptitle_y - 0.005))
        else:
            top_margin = 1.0
        fig.tight_layout(rect=[0, bottom_margin, 1, top_margin])

    return fig


def _load_methods_from_json(paths: List[str]) -> Tuple[List[Dict[str, Any]], List[str]]:
    collect: List[Dict[str, Any]] = []
    labels: List[str] = []
    for p in paths:
        with open(p, "r") as f:
            obj = json.load(f)
        methods = obj.get("methods", {})
        for label, d in methods.items():
            labels.append(label)
            collect.append({
                "Nc": d.get("Nc", []),
                "num_samples": d.get("num_samples", []),
                "mean_time": d.get("mean_time", []),
                "std_time": d.get("std_time", []),
            })
    return collect, labels


def main() -> None:
    """Command-line entry point: plot one or more timing-result JSON files.

    ``--grid``, ``--figsize`` and ``--colors`` are validated before any file is
    read. Malformed values therefore produce an argparse usage error (exit
    code 2) instead of a traceback. The function then loads all methods,
    draws the figure with ``plot_all_methods``, and either saves it to
    ``--out`` (tight bounding box) or shows it interactively.
    """
    ap = argparse.ArgumentParser(description="Plot fast_times JSON results as a multi-panel figure")
    ap.add_argument("json", nargs="+", help="One or more JSON result files")
    ap.add_argument("--Nt", type=int, default=16)
    ap.add_argument("--grid", type=str, default="2x2", help="Grid e.g. 2x2, 1x3")
    ap.add_argument("--figsize", type=str, default="15x11", help="Figure size WxH inches")
    ap.add_argument("--colors", type=str, default="", help="Comma list of palette indices (0-9)")
    ap.add_argument("--out", type=str, default="", help="Optional output path for PNG/PDF")
    ap.add_argument("--title", type=str, default="", help="Figure suptitle (big title)")
    ap.add_argument("--title-size", type=float, default=28.0, help="Suptitle font size")
    ap.add_argument("--title-y", type=float, default=0.97, help="Suptitle vertical position [0-1]")
    ap.add_argument("--font", type=str, default="", help="Preferred serif font family name")
    ap.add_argument("--yscale", type=str, default="log", choices=["log", "log2", "linear"], help="Y-axis scale")
    args = ap.parse_args()

    def parse_pair(text: str, conv, option: str):
        """Parse ``"AxB"`` into two positive values via *conv*, or exit with a usage error."""
        try:
            first, second = (conv(part) for part in text.lower().split("x"))
        except ValueError:
            ap.error(f"{option} must look like 'AxB', got {text!r}")
        if first <= 0 or second <= 0:
            ap.error(f"{option} values must be positive, got {text!r}")
        return first, second

    grid = parse_pair(args.grid, int, "--grid")
    w, h = parse_pair(args.figsize, float, "--figsize")

    color_indices = None
    if args.colors:
        try:
            color_indices = [int(x) for x in args.colors.split(",")]
        except ValueError:
            ap.error(f"--colors must be a comma-separated list of integers, got {args.colors!r}")

    data, labels = _load_methods_from_json(args.json)

    fig = plot_all_methods(
        data,
        method_labels=labels,
        Nt=args.Nt,
        figsize=(w, h),
        color_indices=color_indices,
        grid_shape=grid,
        suptitle=(args.title or None),
        suptitle_size=args.title_size,
        suptitle_y=args.title_y,
        preferred_font=(args.font or None),
        yscale=args.yscale,
    )
    if args.out:
        fig.savefig(args.out, bbox_inches="tight")
    else:
        plt.show()


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
