from __future__ import annotations
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import FixedLocator
from matplotlib.lines import Line2D
import matplotlib.patheffects as pe
import numpy as np
import matplotlib.pyplot as plt

import argparse
import os
from dataclasses import dataclass
from typing import Callable
from typing import Optional

import matplotlib

# Select the non-interactive Agg backend so figures are rendered to files
# without needing a display.
# NOTE(review): pyplot has already been imported above, so this switches the
# backend after the fact — confirm the import ordering is intended.
matplotlib.use("Agg")


# Embed fonts as Type 42 (TrueType) in PDF/PS output rather than Type 3, and
# do not restrict PDF output to the 14 core PDF fonts.
plt.rcParams["pdf.fonttype"] = 42
plt.rcParams["ps.fonttype"] = 42
plt.rcParams["pdf.use14corefonts"] = False


# Base palette; several colors below are picked from it by index.
RUN_COLORS = [
    "#00468B",
    "#9B59B6",
    "#AE1029",
    "#8C564B",
]

# Figure / axes chrome.
FIG_BG_COLOR = "#FFFFFF"
BG_COLOR = "#FBFCFE"
GRID_COLOR = "#D9DEE8"
SPINE_COLOR = "#2F3441"
TEXT_COLOR = "#141821"

# Data markers and the dashed baseline rule.
POINT_COLOR = RUN_COLORS[1]
BASELINE_COLOR = "#98A2B3"
BASELINE_DASH_STYLE = (0, (4, 3))

# Short x-axis tick labels, one per method (baseline first).
ALL_METHOD_TICK_LABELS = [
    "GSPO",
    "0",
    "100",
    "500",
    "1k",
    "8k",
]

# Full method labels: the baseline keeps its name, every other method is
# rendered as a mathtext "Delta L = <tick label>" string.
ALL_METHOD_LABELS = [
    "GSPO",
    *(rf"$\Delta L={tick}$" for tick in ALL_METHOD_TICK_LABELS[1:]),
]

# Per-method colors, index-aligned with the two label lists above.
ALL_METHOD_COLORS = [
    RUN_COLORS[2],
    RUN_COLORS[0],
    "#CDB4E6",
    "#B889D9",
    RUN_COLORS[1],
    "#7D3E98",
]

# Column order of the numeric part of each table row.
_METRIC_KEYS = (
    "ngram_count",
    "distinct_10gram_ratio",
    "equation_count",
    "length",
    "valid_acc_16k",
)

# One tuple of metric values per method, in the same order as the labels.
_METRIC_TABLE = (
    (1286, 84.3, 54.2, 3637, 55.7),
    (897, 96.1, 43.6, 2015, 56.9),
    (1912, 96.4, 80.2, 3788, 61.4),
    (2354, 96.7, 80.9, 4588, 61.6),
    (2468, 95.8, 86.2, 4852, 61.6),
    (2681, 94.6, 86.9, 5241, 63.4),
)

# Final table: presentation fields first, then the metric columns.
TABLE_ROWS = [
    {
        "method": method,
        "tick_label": tick_label,
        "color": color,
        **dict(zip(_METRIC_KEYS, metric_values)),
    }
    for method, tick_label, color, metric_values in zip(
        ALL_METHOD_LABELS,
        ALL_METHOD_TICK_LABELS,
        ALL_METHOD_COLORS,
        _METRIC_TABLE,
    )
]


@dataclass(frozen=True)
class MetricSpec:
    """Immutable description of how one metric is drawn in its own panel."""

    # Key used to look the metric value up in each table row.
    key: str
    # Panel title text.
    title: str
    # Y-axis tick formatter: either a ``(x, pos) -> str`` callable, or the
    # ``_float_formatter_factory`` function itself, which is called with the
    # panel's y-limits to build the actual tick callable.
    formatter: Callable
    # Formats a single data value for the text label drawn above its marker.
    label_formatter: Callable[[float], str]
    # Optional hard lower / upper bounds applied to the computed y-range.
    ymin: Optional[float] = None
    ymax: Optional[float] = None


def _compact_count_formatter(x, _):
    x = float(x)
    abs_x = abs(x)
    if abs_x >= 1_000_000:
        return f"{x / 1_000_000:.2f}M"
    if abs_x >= 1_000:
        return f"{x / 1_000:.1f}k"
    return f"{int(round(x))}"


def _float_formatter_factory(ylim):
    try:
        y0, y1 = float(ylim[0]), float(ylim[1])
        span = abs(y1 - y0)
        max_abs = max(abs(y0), abs(y1))
    except Exception:
        span, max_abs = 0.0, 0.0

    if max_abs < 0.1 or span < 0.05:
        fmt = "{:.3f}"
    elif max_abs < 1.0 or span < 0.5:
        fmt = "{:.2f}"
    else:
        fmt = "{:.1f}"

    def _fmt(x, pos):
        try:
            return fmt.format(float(x))
        except Exception:
            return str(x)

    return _fmt


def _add_top_headroom(ylim, frac: float = 0.10, cap_upper=None):
    y0, y1 = float(ylim[0]), float(ylim[1])
    span = max(y1 - y0, 1e-9)
    y1_new = y1 + span * frac
    if cap_upper is not None:
        y1_new = min(float(cap_upper), y1_new)
    return y0, y1_new


def _compute_ylim(values, floor: Optional[float] = None, ceil: Optional[float] = None):
    """Compute padded y-limits for ``values``.

    Non-finite entries are ignored. The data range is padded by 12% below and
    18% above (or by a fraction of the magnitude when all values coincide),
    clamped to ``floor`` / ``ceil`` when given, and finally extended with 10%
    top headroom that is itself capped at ``ceil``. With no finite values the
    base range is ``(0, 1)``.
    """
    data = np.asarray(values, dtype=float)
    usable = data[np.isfinite(data)]

    if usable.size > 0:
        lo = float(np.min(usable))
        hi = float(np.max(usable))
        span = hi - lo
        if span < 1e-12:
            # Flat data: pad relative to the value's magnitude instead.
            base = max(abs(hi), 1.0)
            pad_low, pad_high = base * 0.10, base * 0.14
        else:
            pad_low = max(span * 0.12, 1e-9)
            pad_high = max(span * 0.18, 1e-9)
        lo -= pad_low
        hi += pad_high
    else:
        lo, hi = 0.0, 1.0

    if floor is not None:
        lo = max(float(floor), lo)
    if ceil is not None:
        hi = min(float(ceil), hi)
    # Clamping may have collapsed or inverted the range; keep it valid.
    if hi <= lo:
        hi = lo + 1e-6

    return _add_top_headroom((lo, hi), frac=0.10, cap_upper=ceil)


def _build_axis_formatter(formatter: Callable, ylim):
    """Wrap ``formatter`` in a matplotlib ``FuncFormatter``.

    When ``formatter`` is the ``_float_formatter_factory`` function itself, it
    is first called with ``ylim`` to obtain the actual tick callable; any
    other formatter is used directly.
    """
    tick_callable = (
        formatter(ylim) if formatter is _float_formatter_factory else formatter
    )
    return FuncFormatter(tick_callable)


def _set_style():
    """Apply the global rcParams used by this figure (font, line width, backgrounds)."""
    overrides = {
        "font.family": "sans-serif",
        "axes.linewidth": 1.6,
        "figure.facecolor": FIG_BG_COLOR,
        "savefig.facecolor": FIG_BG_COLOR,
    }
    for param_name, param_value in overrides.items():
        plt.rcParams[param_name] = param_value


def _style_ax(ax, font_tick: int):
    """Apply the shared panel look to ``ax``.

    Sets the background color, draws all four spines, styles the ticks with
    ``font_tick`` as label size, enables a grid drawn below the data, and
    limits the y-axis to roughly four major ticks.
    """
    ax.set_facecolor(BG_COLOR)

    for side in ("left", "bottom", "top", "right"):
        spine = ax.spines[side]
        spine.set_linewidth(1.6)
        spine.set_color(SPINE_COLOR)
        spine.set_visible(True)

    tick_style = {
        "axis": "both",
        "labelcolor": TEXT_COLOR,
        "labelsize": font_tick,
        "length": 4.8,
        "width": 1.6,
        "color": SPINE_COLOR,
        "pad": 4.5,
    }
    ax.tick_params(**tick_style)

    grid_style = {
        "axis": "both",
        "alpha": 0.65,
        "color": GRID_COLOR,
        "linewidth": 1.4,
        "linestyle": "-",
        "zorder": 0,
    }
    ax.grid(True, **grid_style)
    ax.set_axisbelow(True)

    ax.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=False))


def _get_baseline_row() -> dict:
    """Return the first table row whose method is ``"GSPO"``.

    Raises:
        KeyError: if the table contains no such row.
    """
    baseline = next((row for row in TABLE_ROWS if row["method"] == "GSPO"), None)
    if baseline is None:
        raise KeyError("Cannot find baseline row 'GSPO'.")
    return baseline


def _build_display_rows() -> list[dict]:
    """Return every table row except the ``"GSPO"`` baseline, in table order."""
    return [row for row in TABLE_ROWS if row["method"] != "GSPO"]


def _make_legend_handles(display_rows: list[dict], baseline_row: dict):
    """Build legend proxy artists: a line for the baseline, then one marker per row.

    The baseline handle uses the baseline row's color and method label; each
    display row contributes a white-edged circular marker filled with that
    row's color and labelled with its method.
    """
    baseline_handle = Line2D(
        [0, 1],
        [0, 0],
        color=baseline_row["color"],
        linewidth=3.4,
        solid_capstyle="round",
        label=baseline_row["method"],
    )
    marker_handles = [
        Line2D(
            [0],
            [0],
            marker="o",
            linestyle="None",
            markersize=12,
            markerfacecolor=row["color"],
            markeredgecolor="white",
            markeredgewidth=2.2,
            label=row["method"],
        )
        for row in display_rows
    ]
    return [baseline_handle, *marker_handles]


def _plot_metric_panel(
    ax,
    display_rows: list[dict],
    x_positions: np.ndarray,
    values: np.ndarray,
    baseline_value: float,
    spec: MetricSpec,
    font_tick: int,
    font_title: int,
):
    """Draw one metric panel on ``ax``.

    The panel shows a dashed horizontal rule at ``baseline_value``, a line
    with a white halo connecting ``values`` at ``x_positions``, a marker with
    a translucent glow at each point, and a text label above each marker
    produced by ``spec.label_formatter``. X tick labels come from each row's
    ``"tick_label"``; the title and y-axis tick formatter come from ``spec``.
    """
    _style_ax(ax, font_tick=font_tick)

    # The y-range accounts for the baseline value as well as the plotted points.
    range_input = np.concatenate(
        [values, np.asarray([baseline_value], dtype=float)]
    )
    ylim = _compute_ylim(range_input, floor=spec.ymin, ceil=spec.ymax)
    y_bottom, y_top = ylim[0], ylim[1]
    # Vertical gap between a marker and its value label.
    label_offset = max((y_top - y_bottom) * 0.035, 1e-9)

    # --- Axes limits, ticks and title -----------------------------------
    ax.set_ylim(ylim)
    ax.set_xlim(-0.45, len(x_positions) - 0.55)
    ax.xaxis.set_major_locator(FixedLocator(x_positions))
    tick_labels = [row["tick_label"] for row in display_rows]
    ax.set_xticklabels(
        tick_labels,
        rotation=25,
        ha="right",
        rotation_mode="anchor",
    )
    ax.yaxis.set_major_formatter(_build_axis_formatter(spec.formatter, ylim))
    ax.set_title(
        spec.title,
        fontsize=font_title,
        fontweight="bold",
        pad=12,
        color=TEXT_COLOR,
    )

    # --- Baseline rule ---------------------------------------------------
    ax.axhline(
        baseline_value,
        color=BASELINE_COLOR,
        linewidth=2.6,
        zorder=1.6,
        alpha=0.95,
        solid_capstyle="round",
        linestyle=BASELINE_DASH_STYLE,
    )

    # --- Connecting line: wide white halo underneath, dark line on top ---
    rounded = {"solid_capstyle": "round", "solid_joinstyle": "round"}
    ax.plot(
        x_positions,
        values,
        color="white",
        linewidth=5.8,
        zorder=2,
        alpha=0.90,
        **rounded,
    )
    ax.plot(
        x_positions,
        values,
        color=SPINE_COLOR,
        linewidth=2.8,
        zorder=3,
        alpha=0.92,
        **rounded,
    )

    # --- Markers: large translucent glow, then the solid point -----------
    ax.scatter(
        x_positions,
        values,
        s=430,
        color=POINT_COLOR,
        alpha=0.16,
        linewidths=0,
        zorder=3.2,
    )
    ax.scatter(
        x_positions,
        values,
        s=215,
        color=POINT_COLOR,
        edgecolors="white",
        linewidths=2.2,
        zorder=4,
    )

    # --- Value labels above each marker ----------------------------------
    label_size = max(12, font_tick - 2)
    for x_pos, value, _row in zip(x_positions, values, display_rows):
        label_text = spec.label_formatter(float(value))
        ax.text(
            x_pos,
            value + label_offset,
            label_text,
            ha="center",
            va="bottom",
            fontsize=label_size,
            fontweight="bold",
            color=POINT_COLOR,
            zorder=5,
        )


def _add_baseline_legend(ax):
    """Add a single-entry legend for the dashed ``GSPO`` baseline to ``ax``.

    The legend sits in the upper-right corner with a white, outlined frame and
    a faint drop shadow; the line sample gets a dark outline stroke and the
    label text is bold in the baseline color.
    """
    baseline_proxy = Line2D(
        [0, 1],
        [0, 0],
        color=BASELINE_COLOR,
        linewidth=3.2,
        linestyle=BASELINE_DASH_STYLE,
        solid_capstyle="round",
        alpha=0.95,
        label="GSPO",
    )
    legend = ax.legend(
        handles=[baseline_proxy],
        loc="upper right",
        frameon=True,
        framealpha=0.98,
        fontsize=16,
        handlelength=2.8,
        handletextpad=0.6,
        borderaxespad=0.25,
        fancybox=True,
    )

    # Frame: white fill, dark outline, subtle shadow.
    frame = legend.get_frame()
    frame.set_facecolor("#FFFFFF")
    frame.set_edgecolor(SPINE_COLOR)
    frame.set_linewidth(1.8)
    frame.set_path_effects(
        [
            pe.SimplePatchShadow(offset=(1.0, -1.0), alpha=0.12),
            pe.Normal(),
        ]
    )

    # Outline the line sample so the light baseline color stays legible.
    for entry in legend.legend_handles:
        if not isinstance(entry, Line2D):
            continue
        entry.set_path_effects(
            [
                pe.Stroke(linewidth=4.0, foreground=SPINE_COLOR),
                pe.Normal(),
            ]
        )

    for label in legend.get_texts():
        label.set_color(BASELINE_COLOR)
        label.set_fontweight("bold")


def plot_gspo_length_table_metrics(
    output_prefix: str,
    dpi: int = 600,
):
    """Render the summary table as a row of metric panels and save it.

    One panel is drawn per metric; each panel plots the non-baseline rows and
    marks the baseline row's value with a dashed horizontal rule. The figure
    is written to ``<output_prefix>.pdf`` and ``<output_prefix>.png``; the
    parent directory of ``output_prefix`` is created if needed.

    Args:
        output_prefix: Output path without extension.
        dpi: Resolution passed to ``savefig`` for both outputs.
    """
    _set_style()
    baseline_row = _get_baseline_row()
    display_rows = _build_display_rows()

    # Point-label formatters shared by several metrics.
    def _as_integer(x):
        return f"{int(round(x))}"

    def _one_decimal(x):
        return f"{x:.1f}"

    metrics = [
        MetricSpec(
            key="ngram_count",
            title="DNC",
            formatter=_compact_count_formatter,
            label_formatter=_as_integer,
        ),
        MetricSpec(
            key="equation_count",
            title="DEC",
            formatter=_float_formatter_factory,
            label_formatter=_one_decimal,
        ),
        MetricSpec(
            key="length",
            title=r"$\boldsymbol{L}$",
            formatter=_compact_count_formatter,
            label_formatter=_as_integer,
        ),
        MetricSpec(
            key="distinct_10gram_ratio",
            title="DNR",
            formatter=_float_formatter_factory,
            label_formatter=_one_decimal,
        ),
        MetricSpec(
            key="valid_acc_16k",
            title="16k Valid Acc",
            formatter=_float_formatter_factory,
            label_formatter=_one_decimal,
        ),
    ]

    font_tick = 18
    font_xlabel = 35
    font_title = 30
    x_positions = np.arange(len(display_rows), dtype=float)
    fig, axes = plt.subplots(
        1,
        len(metrics),
        figsize=(28.6, 6.0),
        squeeze=False,
        gridspec_kw={"wspace": 0.19},
    )

    # Everything after figure creation runs under try/finally so the figure
    # is released even when plotting or saving fails.
    try:
        fig.patch.set_facecolor(FIG_BG_COLOR)
        axes = axes.flatten()

        for ax, spec in zip(axes, metrics):
            values = np.asarray(
                [row[spec.key] for row in display_rows], dtype=float
            )
            baseline_value = float(baseline_row[spec.key])
            _plot_metric_panel(
                ax=ax,
                display_rows=display_rows,
                x_positions=x_positions,
                values=values,
                baseline_value=baseline_value,
                spec=spec,
                font_tick=font_tick,
                font_title=font_title,
            )

        fig.supxlabel(
            r"$\boldsymbol{\Delta L}$",
            fontsize=font_xlabel,
            fontweight="bold",
            y=-0.05,
            color=TEXT_COLOR,
        )
        # The baseline legend is shown once, on the last panel.
        _add_baseline_legend(axes[-1])

        # Operate on this figure explicitly rather than on pyplot's implicit
        # "current figure".
        fig.subplots_adjust(
            left=0.045,
            right=0.992,
            top=0.85,
            bottom=0.19,
            wspace=0.19,
        )

        output_dir = os.path.dirname(output_prefix)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        for extension in ("pdf", "png"):
            output_path = f"{output_prefix}.{extension}"
            fig.savefig(
                output_path,
                dpi=dpi,
                bbox_inches="tight",
                pad_inches=0.18,
                facecolor=fig.get_facecolor(),
                metadata={"Creator": "matplotlib", "Producer": "matplotlib"},
            )
            print(f"Saved figure: {output_path}")
    finally:
        plt.close(fig)


def build_default_output_prefix():
    """Return the absolute default output prefix, located next to this script.

    The prefix is ``results/plots/gspo_length_table_metrics`` under the
    directory containing this file; it carries no file extension.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    relative_prefix = "results/plots/gspo_length_table_metrics"
    return os.path.abspath(os.path.join(script_dir, relative_prefix))


if __name__ == "__main__":
    # Command-line entry point: parse the two options and render the figure.
    cli = argparse.ArgumentParser(
        description="Plot the provided GSPO/Delta-L summary table with the analysis_mar house style."
    )
    cli.add_argument(
        "--output-prefix",
        default=build_default_output_prefix(),
        type=str,
        help="Output path prefix without extension. The script writes both .pdf and .png.",
    )
    cli.add_argument(
        "--dpi",
        default=600,
        type=int,
        help="Figure DPI.",
    )

    # The parsed namespace holds exactly ``output_prefix`` and ``dpi``, which
    # are the keyword arguments the plotting function takes.
    plot_gspo_length_table_metrics(**vars(cli.parse_args()))
