from __future__ import annotations
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import FuncFormatter
import numpy as np
import matplotlib.pyplot as plt

import argparse
from dataclasses import dataclass
import json
import os
from typing import Callable
from typing import Sequence
from typing import Optional

import matplotlib

# Select the non-interactive "Agg" backend; this script only writes figures to
# files via savefig and never opens a window.
# NOTE(review): matplotlib.pyplot is already imported at the top of the file,
# so this call has to switch an already-chosen backend. The conventional order
# is to call matplotlib.use() before importing pyplot — confirm the intended
# backend is actually active on the matplotlib version in use.
matplotlib.use("Agg")


# SciPy is optional: when scipy.signal.savgol_filter cannot be imported,
# _smooth_data falls back to a plain moving average.
try:
    from scipy.signal import savgol_filter
except Exception:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True


# Embed fonts as Type 42 (TrueType) in PDF/PS output and do not substitute the
# 14 PDF core fonts, so text in exported figures stays editable/selectable.
plt.rcParams.update(
    {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pdf.use14corefonts": False,
    }
)


# Primary palette; runs are colored by their position in the plotted order.
RUN_COLORS = [
    "#00468B",  # deep blue (Diversity)
    "#9B59B6",  # elegant purple (Length)
    "#AE1029",  # crimson
    "#8c564b",  # brown
]

# Named aliases for the first three palette entries.
COLOR_BASELINE, COLOR_GSPO_LENGTH, COLOR_GSPO = RUN_COLORS[:3]

# Figure / panel chrome colors.
FIG_BG_COLOR = "#FFFFFF"
BG_COLOR = "#FBFCFE"
GRID_COLOR = "#D9DEE8"
SPINE_COLOR = "#2F3441"
TEXT_COLOR = "#141821"

# Extra colors for runs beyond len(RUN_COLORS).
# NOTE: the first entry repeats RUN_COLORS[3].
FALLBACK_COLORS = [RUN_COLORS[3]] + [
    "#6C91BF",
    "#C9A227",
    "#5C4B8A",
    "#2C7A7B",
]

# Default legend label / color per normalized run key (lowercase, alnum only).
DEFAULT_RUN_STYLE_MAP = {
    run_name: {"label": run_label, "color": run_color}
    for run_name, run_label, run_color in (
        ("baseline", "Baseline", COLOR_BASELINE),
        ("gspolength", r"GSPO + $R_{\mathrm{len}}(\tau)$", COLOR_GSPO_LENGTH),
        ("gspo", "GSPO", COLOR_GSPO),
        ("gspolie", "GSPO + LIE", COLOR_GSPO_LENGTH),
    )
}


@dataclass(frozen=True)
class MetricSpec:
    """Immutable description of one metric panel in the comparison figure."""

    # Logical metric name; looked up in the alias table to find the concrete
    # record key that holds the metric's values.
    metric_name: str
    # Panel title (may contain matplotlib mathtext).
    title: str
    # Y tick formatter ``f(x, pos) -> str``; for the known factory functions
    # it is instead called with the final y-limits to build the formatter
    # (see _build_axis_formatter).
    formatter: Callable
    # Floor for the y-limits and lower clip for the shaded band; None = no floor.
    ymin: Optional[float]
    # Ceiling for the y-limits and upper clip for the shaded band; None = no ceiling.
    ymax: Optional[float]
    # Multiplier applied to the variability band drawn around each curve.
    band_scale: float
    # If True, a metric missing from the records raises KeyError; otherwise
    # the panel is skipped with an info message.
    required: bool = True
    # NOTE(review): not read anywhere in the visible code — factories are
    # detected by identity in _build_axis_formatter. Confirm whether this
    # flag is still needed.
    formatter_is_factory: bool = False
    # Optional step stride forwarded to _load_series (keeps one residue class
    # of steps); None keeps every step.
    step_stride: Optional[int] = None


@dataclass(frozen=True)
class RunSpec:
    """Immutable, resolved display settings for one plotted run."""

    # Top-level key of the run in the merged results dict.
    run_key: str
    # Legend label (may contain matplotlib mathtext).
    label: str
    # Matplotlib color used for the run's curve and shaded band.
    color: str


def _thousands_formatter(x, _):
    return f"{x / 1000:.1f}k"


def _percent_formatter(x, _):
    return f"{x * 100:.0f}"


def _compact_count_formatter(x, _):
    x = float(x)
    abs_x = abs(x)
    if abs_x >= 1_000_000:
        return f"{x / 1_000_000:.2f}M"
    if abs_x >= 1_000:
        return f"{x / 1_000:.1f}k"
    return f"{int(round(x))}"


def _ratio_percent_formatter_factory(ylim):
    try:
        y0, y1 = float(ylim[0]), float(ylim[1])
        span_pct = abs((y1 - y0) * 100.0)
        max_pct = max(abs(y0), abs(y1)) * 100.0
    except Exception:
        span_pct, max_pct = 0.0, 0.0

    if max_pct < 1.0 or span_pct < 0.5:
        fmt = "{:.2f}"
    elif max_pct < 10.0 or span_pct < 3.0:
        fmt = "{:.1f}"
    else:
        fmt = "{:.0f}"

    def _fmt(x, pos):
        try:
            return fmt.format(float(x) * 100.0)
        except Exception:
            return str(x)

    return _fmt


def _float_formatter_factory(ylim):
    try:
        y0, y1 = float(ylim[0]), float(ylim[1])
        span = abs(y1 - y0)
        max_abs = max(abs(y0), abs(y1))
    except Exception:
        span, max_abs = 0.0, 0.0

    if max_abs < 0.1 or span < 0.05:
        fmt = "{:.3f}"
    elif max_abs < 1.0 or span < 0.5:
        fmt = "{:.2f}"
    else:
        fmt = "{:.1f}"

    def _fmt(x, pos):
        try:
            return fmt.format(float(x))
        except Exception:
            return str(x)

    return _fmt


def _smooth_data(values, window_size: int = 5):
    data = np.asarray(values, dtype=float)
    if len(data) < 3:
        return data

    window_size = max(3, int(window_size))
    if window_size % 2 == 0:
        window_size += 1
    if window_size > len(data):
        window_size = len(data) if len(data) % 2 == 1 else len(data) - 1
    if window_size < 3:
        return data

    if HAS_SCIPY:
        try:
            poly_order = min(3, window_size - 1)
            return savgol_filter(data, window_size, poly_order)
        except Exception:
            pass

    kernel = np.ones(window_size) / window_size
    smoothed = np.convolve(data, kernel, mode="same")
    half = window_size // 2
    smoothed[:half] = data[:half]
    smoothed[-half:] = data[-half:]
    return smoothed


def _set_style():
    """Set the global rcParams shared by every figure this script draws."""
    plt.rcParams.update(
        {
            "font.family": "sans-serif",
            "axes.linewidth": 2,
            "figure.facecolor": FIG_BG_COLOR,
            "savefig.facecolor": FIG_BG_COLOR,
        }
    )


def _style_ax(ax, font_tick: int):
    """Apply the shared panel look to ``ax``.

    Sets the panel background, draws all four spines, styles the ticks,
    adds a grid behind the data and installs integer x / free-form y locators.
    """
    ax.set_facecolor(BG_COLOR)
    for side in ("left", "bottom", "top", "right"):
        spine = ax.spines[side]
        spine.set_linewidth(2)
        spine.set_color(SPINE_COLOR)
        spine.set_visible(True)

    tick_style = dict(
        axis="both",
        labelcolor=TEXT_COLOR,
        labelsize=font_tick,
        length=5.2,
        width=2,
        color=SPINE_COLOR,
        pad=4.5,
    )
    ax.tick_params(**tick_style)

    grid_style = dict(
        axis="both",
        alpha=0.65,
        color=GRID_COLOR,
        linewidth=2,
        linestyle="-",
        zorder=0,
    )
    ax.grid(True, **grid_style)
    ax.set_axisbelow(True)

    ax.xaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=False))


def _add_top_headroom(ylim, frac: float = 0.08, cap_upper=None):
    y0, y1 = float(ylim[0]), float(ylim[1])
    span = max(y1 - y0, 1e-9)
    y1_new = y1 + span * frac
    if cap_upper is not None:
        y1_new = min(float(cap_upper), y1_new)
    return y0, y1_new


def _compute_ylim(
    values1,
    values2,
    floor: Optional[float] = None,
    ceil: Optional[float] = None,
):
    all_vals = np.concatenate(
        [np.asarray(values1, dtype=float), np.asarray(values2, dtype=float)]
    )
    finite_vals = all_vals[np.isfinite(all_vals)]
    if finite_vals.size == 0:
        lo, hi = 0.0, 1.0
    else:
        lo = float(np.min(finite_vals))
        hi = float(np.max(finite_vals))
        span = hi - lo
        pad_low = max(span * 0.08, 1e-9)
        pad_high = max(span * 0.12, 1e-9)
        if span < 1e-12:
            base = max(abs(hi), 1.0)
            pad_low = base * 0.08
            pad_high = base * 0.12
        lo = lo - pad_low
        hi = hi + pad_high

    if floor is not None:
        lo = max(floor, lo)
    if ceil is not None:
        hi = min(ceil, hi)
    if hi <= lo:
        hi = lo + 1e-6

    return _add_top_headroom((lo, hi), frac=0.08, cap_upper=ceil)


def _compute_ylim_multi(
    value_groups: Sequence[Sequence[float]],
    floor: Optional[float] = None,
    ceil: Optional[float] = None,
):
    arrays = [np.asarray(values, dtype=float)
              for values in value_groups if len(values) > 0]
    if not arrays:
        return _add_top_headroom((0.0, 1.0), frac=0.08, cap_upper=ceil)

    all_vals = np.concatenate(arrays)
    finite_vals = all_vals[np.isfinite(all_vals)]
    if finite_vals.size == 0:
        lo, hi = 0.0, 1.0
    else:
        lo = float(np.min(finite_vals))
        hi = float(np.max(finite_vals))
        span = hi - lo
        pad_low = max(span * 0.08, 1e-9)
        pad_high = max(span * 0.12, 1e-9)
        if span < 1e-12:
            base = max(abs(hi), 1.0)
            pad_low = base * 0.08
            pad_high = base * 0.12
        lo = lo - pad_low
        hi = hi + pad_high

    if floor is not None:
        lo = max(floor, lo)
    if ceil is not None:
        hi = min(ceil, hi)
    if hi <= lo:
        hi = lo + 1e-6

    return _add_top_headroom((lo, hi), frac=0.08, cap_upper=ceil)


def _compute_xlim(all_steps):
    if not all_steps:
        return None

    xmin = min(all_steps)
    xmax = max(all_steps)
    span = max(xmax - xmin, 1)
    pad = max(int(round(span * 0.04)), 12)
    return xmin - pad, xmax + pad


def _compute_band(values, scale: float = 1.0):
    arr = np.asarray(values, dtype=float)
    if len(arr) < 2:
        return np.zeros_like(arr)

    window = min(5, max(1, len(arr) // 10))
    if window <= 1:
        return np.zeros_like(arr)

    diffs = np.abs(np.diff(np.concatenate([[arr[0]], arr])))
    band = np.convolve(diffs, np.ones(window) / window, mode="same")
    return band * scale


def _choose_grid(num_metrics: int) -> tuple[int, int]:
    if num_metrics <= 0:
        return 1, 1
    if num_metrics <= 4:
        ncols = num_metrics
    elif num_metrics <= 6:
        ncols = 3
    elif num_metrics <= 8:
        ncols = 4
    else:
        ncols = 5
    nrows = int(np.ceil(num_metrics / ncols))
    return nrows, ncols


def _normalize_key(name: str) -> str:
    return "".join(ch.lower() for ch in str(name) if ch.isalnum())


def _infer_run_style(run_key: str, series_idx: int):
    """Return the default ``{"label": ..., "color": ...}`` for a plotted run.

    The label comes from DEFAULT_RUN_STYLE_MAP when the normalized run key
    has an entry, otherwise it is the raw ``run_key``. The color is positional
    and deliberately ignores the map's colors, which repeat
    (e.g. "gspolength" and "gspolie" share one). The first len(RUN_COLORS)
    runs take RUN_COLORS in order. Later runs cycle through FALLBACK_COLORS,
    skipping entries already in RUN_COLORS so two plotted runs never share
    a color by accident.

    Args:
        run_key: Top-level key of the run in the results.
        series_idx: Position of the run in the plotted order.
    """
    entry = DEFAULT_RUN_STYLE_MAP.get(_normalize_key(run_key))
    label = entry["label"] if entry is not None else run_key

    if series_idx < len(RUN_COLORS):
        color = RUN_COLORS[series_idx % len(RUN_COLORS)]
    else:
        # FALLBACK_COLORS may repeat primary colors; drop those so the run
        # after the last primary color does not reuse it.
        extra = [c for c in FALLBACK_COLORS if c not in RUN_COLORS]
        palette = extra or FALLBACK_COLORS or RUN_COLORS
        color = palette[(series_idx - len(RUN_COLORS)) % len(palette)]

    return {"label": label, "color": color}


def _load_and_merge_results(input_paths: Sequence[str]):
    merged_results = {}
    key_to_source = {}

    for input_path in input_paths:
        with open(input_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)

        if not isinstance(payload, dict):
            raise ValueError(
                f"Input JSON top-level must be dict/object: {input_path}"
            )

        for run_key, rows in payload.items():
            if run_key in merged_results:
                raise KeyError(
                    f"Duplicate run key '{run_key}' found in both "
                    f"'{key_to_source[run_key]}' and '{input_path}'."
                )
            merged_results[run_key] = rows
            key_to_source[run_key] = input_path

    return merged_results, key_to_source


def _resolve_run_specs(
    available_run_keys: Sequence[str],
    run_keys: Sequence[str],
    run_labels: Optional[Sequence[str]] = None,
):
    available = set(available_run_keys)
    resolved_labels = list(run_labels) if run_labels else []
    if resolved_labels and len(resolved_labels) != len(run_keys):
        raise ValueError(
            f"run_labels length ({len(resolved_labels)}) must match run_keys length ({len(run_keys)})."
        )

    run_specs = []
    for idx, run_key in enumerate(run_keys):
        if run_key not in available:
            raise KeyError(
                f"Missing run key: {run_key}. Available keys: {sorted(available)}"
            )

        style = _infer_run_style(run_key, idx)
        label = resolved_labels[idx] if resolved_labels else style.get(
            "label", run_key)
        color = style["color"]
        run_specs.append(RunSpec(run_key=run_key, label=label, color=color))

    return run_specs


def _pick_metric(records: list[dict], candidates: list[str]) -> Optional[str]:
    for record in records:
        for key in candidates:
            if key in record:
                return key
    return None


def _load_series(
    results: dict,
    run_key: str,
    metric_key: str,
    max_step: Optional[int],
    step_stride: Optional[int] = None,
):
    if run_key not in results:
        raise KeyError(f"Missing run key: {run_key}")

    rows = sorted(results[run_key], key=lambda x: x.get("step", 0))
    points = []
    for row in rows:
        try:
            step = int(row.get("step", 0))
        except (TypeError, ValueError):
            continue

        if max_step is not None and step > max_step:
            continue

        value = row.get(metric_key, np.nan)
        try:
            value = float(value)
        except (TypeError, ValueError):
            continue
        if not np.isfinite(value):
            continue

        points.append((step, value))

    if step_stride and step_stride > 1 and points:
        residue_counts = {}
        for step, _ in points:
            residue = step % step_stride
            residue_counts[residue] = residue_counts.get(residue, 0) + 1

        max_count = max(residue_counts.values())
        candidate_residues = [
            residue
            for residue, count in residue_counts.items()
            if count == max_count
        ]
        chosen_residue = 0 if 0 in candidate_residues else min(
            candidate_residues)
        points = [
            (step, value)
            for step, value in points
            if step % step_stride == chosen_residue
        ]

    steps = [step for step, _ in points]
    values = [value for _, value in points]
    return steps, values


def _plot_metric_series(
    ax,
    steps,
    values,
    color: str,
    label: str,
    smooth_window: int,
    band_scale: float,
    lower_clip: Optional[float] = None,
    upper_clip: Optional[float] = None,
):
    """Draw one run's smoothed curve with a shaded variability band on ``ax``.

    The curve is _smooth_data(values) and the band is
    ``curve +/- _compute_band(values) * band_scale``. The band is optionally
    clipped to ``[lower_clip, upper_clip]``. A white, slightly wider line is
    drawn under the colored curve as a halo.

    Returns:
        The colored Line2D, labeled with ``label``, for use in legends.
    """
    raw = np.asarray(values, dtype=float)
    center = _smooth_data(raw, smooth_window)
    spread = _compute_band(raw, scale=band_scale)

    band_lo = center - spread
    band_hi = center + spread
    if lower_clip is not None:
        band_lo = np.maximum(band_lo, lower_clip)
    if upper_clip is not None:
        band_hi = np.minimum(band_hi, upper_clip)

    ax.fill_between(steps, band_lo, band_hi, color=color, alpha=0.18, zorder=1)

    rounded = dict(solid_capstyle="round", solid_joinstyle="round")
    # White halo underneath so overlapping curves stay distinguishable.
    ax.plot(steps, center, color="white", linewidth=5, zorder=3, alpha=0.88, **rounded)
    (line,) = ax.plot(
        steps, center, color=color, linewidth=4, zorder=4, label=label, **rounded
    )
    return line


def _build_axis_formatter(formatter: Callable, ylim):
    """Wrap ``formatter`` in a FuncFormatter.

    The two known factory functions are first called with ``ylim`` so the
    tick precision can follow the axis range. Any other callable is used
    directly as the tick function.
    """
    known_factories = (_ratio_percent_formatter_factory, _float_formatter_factory)
    tick_fn = formatter(ylim) if formatter in known_factories else formatter
    return FuncFormatter(tick_fn)


def _default_metric_candidates():
    """Return, for each logical metric name, the record keys that may hold it.

    _pick_metric tries each metric's aliases in the listed order against
    every record.
    """
    return {
        "reward_acc": ["acc_reward", "reward_acc", "reward_accuracy"],
        "length": ["avg_token_length", "length", "avg_length"],
        "accuracy": ["accuracy", "acc"],
        "entropy": ["entropy", "actor/entropy"],
        "distinct_10gram_count": ["distinct_10gram_count", "10gram_count"],
        "distinct_10gram_ratio": ["distinct_10gram_ratio", "10gram_ratio"],
        "internal_textual_diversity": [
            "internal_textual_diversity",
            "internal_text_diversity",
            "intra_td",
        ],
        "formula_unique_count": [
            "formula_unique_count",
            "intra_edc",
        ],
        "reasoning_behavior_type_count": [
            "reasoning_behavior_type_count",
            "behavior_type_count",
        ],
        "reasoning_behavior_count": [
            "reasoning_behavior_count",
            "behavior_count",
        ],
    }


def _default_metric_specs(reasoning_stride: Optional[int]):
    """Return the metric panels to draw, in display order.

    Commented-out entries are alternative panels kept for quick toggling.
    ``reasoning_stride`` is only consumed by the (currently disabled)
    reasoning-behavior panels.
    """
    return [
        # MetricSpec(
        #     "entropy",
        #     "Entropy",
        #     _float_formatter_factory,
        #     0.0,
        #     None,
        #     0.45,
        #     required=False,
        #     formatter_is_factory=True,
        # ),
        # MetricSpec("reward_acc", "Reward",
        #            _percent_formatter, 0.0, 1.0, 0.40),
        # MetricSpec("length", r"$\boldsymbol{L}$",
        #            _thousands_formatter, 0.0, None, 1.20),
        # MetricSpec("accuracy", "Valid Acc",
        #            _percent_formatter, 0.0, 0.63, 0.40),

        MetricSpec("length", r"$\boldsymbol{L}$",
                   _thousands_formatter, 0.0, None, 1.20),
        MetricSpec("accuracy", "Valid Acc",
                   _percent_formatter, 0.0, 0.63, 0.40),
        MetricSpec(
            "distinct_10gram_count",
            r"DNC",
            _compact_count_formatter,
            0.0,
            None,
            1.20,
        ),
        # MetricSpec(
        #     "internal_textual_diversity",
        #     "TD",
        #     _ratio_percent_formatter_factory,
        #     0.0,
        #     1.0,
        #     0.45,
        #     required=False,
        #     formatter_is_factory=True,
        # ),
        # MetricSpec(
        #     "formula_unique_count",
        #     "DEC",
        #     _compact_count_formatter,
        #     0.0,
        #     None,
        #     0.90,
        #     required=False,
        # ),

        MetricSpec(
            "distinct_10gram_ratio",
            "Distinct N-gram Ratio (DNR)",
            _ratio_percent_formatter_factory,
            0.0,
            1.0,
            0.50,
            formatter_is_factory=True,
        ),

        # MetricSpec(
        #     "reasoning_behavior_type_count",
        #     "RBC",
        #     _compact_count_formatter,
        #     0.0,
        #     None,
        #     0.90,
        #     required=False,
        #     step_stride=reasoning_stride,
        # ),
        # MetricSpec(
        #     "reasoning_behavior_count",
        #     "RB Count",
        #     _compact_count_formatter,
        #     0.0,
        #     None,
        #     0.90,
        #     required=False,
        #     step_stride=reasoning_stride,
        # ),
    ]


def _select_active_metrics(metric_specs, metric_candidates, records):
    """Resolve each metric spec to a concrete record key.

    Returns ``(metric_keys, active_specs)``, where ``metric_keys`` maps
    ``spec.metric_name`` to the detected record key.

    Raises:
        KeyError: a required metric has no matching key in ``records``.
            Optional metrics are skipped with an info message instead.
    """
    metric_keys = {}
    active_specs = []
    for spec in metric_specs:
        key = _pick_metric(records, metric_candidates[spec.metric_name])
        if key is None:
            if spec.required:
                raise KeyError(
                    f"Cannot find metric '{spec.metric_name}' in input records."
                )
            print(
                f"[info] Optional metric '{spec.metric_name}' not found in input records; skip."
            )
            continue
        metric_keys[spec.metric_name] = key
        active_specs.append(spec)
    return metric_keys, active_specs


def _apply_shared_xlim(axes, all_x, num_metrics: int, ncols: int):
    """Give all metric panels one x-range; label x ticks only where needed.

    A panel shows x tick labels unless a visible metric panel sits directly
    below it in the grid. plt.subplots(sharex=True) hides the labels on all
    non-bottom rows. Without this, the panels above the empty slots of an
    incomplete last row would have no x tick labels at all.
    """
    if not all_x:
        return
    xmin, xmax = _compute_xlim(all_x)
    for idx, ax in enumerate(axes[:num_metrics]):
        ax.set_xlim(xmin, xmax)
        ax.margins(x=0.01)
        below = idx + ncols
        has_panel_below = below < num_metrics and axes[below].get_visible()
        ax.tick_params(axis="x", which="both", labelbottom=not has_panel_below)


def _add_figure_legend(fig, legend_handles, font_legend: int):
    """Draw one shared legend, centered along the bottom of ``fig``."""
    legend = fig.legend(
        legend_handles,
        [handle.get_label() for handle in legend_handles],
        loc="lower center",
        # bbox_to_anchor=(0.5, -0.06),
        bbox_to_anchor=(0.5, 0.0),
        ncol=max(1, len(legend_handles)),
        frameon=True,
        framealpha=0.98,
        edgecolor="#D5DAE3",
        fancybox=True,
        shadow=False,
        prop={"weight": "bold", "size": font_legend},
        borderpad=0.60,
        columnspacing=1.9,
        handlelength=2.4,
        handletextpad=0.7,
    )
    legend.get_frame().set_linewidth(2.0)
    legend.get_frame().set_facecolor("#FFFFFF")
    return legend


def plot_baseline_vs_gspo_length(
    input_path: str,
    output_path: str,
    baseline_key: str = "baseline",
    gspo_length_key: str = "gspo_length",
    baseline_label: str = "Baseline",
    gspo_label: str = r"GSPO + $R_{\mathrm{len}}(\tau)$",
    smooth_window: int = 5,
    max_step: Optional[int] = None,
    reasoning_step_stride: int = 100,
    dpi: int = 600,
    input_paths: Optional[Sequence[str]] = None,
    run_keys: Optional[Sequence[str]] = None,
    run_labels: Optional[Sequence[str]] = None,
):
    """Plot one panel per metric comparing several runs, and save the figure.

    Inputs are JSON objects mapping run keys to per-step record dicts. They
    are loaded from ``input_paths``, or from ``[input_path]`` when that is
    None or empty. With ``run_keys=None`` the two runs ``baseline_key`` and
    ``gspo_length_key`` are plotted with ``baseline_label`` / ``gspo_label``.
    Otherwise ``run_keys`` (plus optional ``run_labels``) select the runs.

    Args:
        input_path: Fallback single input JSON path.
        output_path: Destination file; a ``.pdf`` suffix forces PDF output.
        smooth_window: Window size passed to _smooth_data.
        max_step: Steps above this value are dropped; None keeps all.
        reasoning_step_stride: Step stride for reasoning-behavior panels;
            values <= 1 disable the filter.
        dpi: Resolution passed to savefig.

    Raises:
        KeyError: a run key or required metric is missing, or two inputs
            define the same run key.
        ValueError: an input's top level is not an object, label and key
            counts differ, or no series could be plotted.
    """
    if input_paths is None:
        input_paths = [input_path]
    else:
        input_paths = list(input_paths) or [input_path]

    results, key_to_source = _load_and_merge_results(input_paths)

    if run_keys is None:
        # Legacy two-run mode driven by the baseline/gspo keyword arguments.
        run_specs = _resolve_run_specs(
            available_run_keys=list(results.keys()),
            run_keys=[baseline_key, gspo_length_key],
            run_labels=[baseline_label, gspo_label],
        )
    else:
        run_specs = _resolve_run_specs(
            available_run_keys=list(results.keys()),
            run_keys=list(run_keys),
            run_labels=run_labels,
        )

    for run_spec in run_specs:
        print(
            f"[info] run='{run_spec.run_key}' from '{key_to_source[run_spec.run_key]}'")

    # Metric keys are detected from the records of every selected run.
    records_for_detection = []
    for run_spec in run_specs:
        records_for_detection.extend(results.get(run_spec.run_key, []))

    reasoning_stride = reasoning_step_stride if reasoning_step_stride > 1 else None
    metric_keys, active_metric_specs = _select_active_metrics(
        _default_metric_specs(reasoning_stride),
        _default_metric_candidates(),
        records_for_detection,
    )

    _set_style()
    font_tick = 20
    font_xlabel = 20  # only used by the disabled fig.supxlabel call below
    font_title = 30
    font_legend = 28
    # Compact alternative: font_tick = 18, font_xlabel = 18,
    # font_title = 28, font_legend = 25

    num_metrics = len(active_metric_specs)
    nrows, ncols = _choose_grid(num_metrics)
    panel_width = 7
    panel_height = 6
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(panel_width * ncols, panel_height * nrows + 0.95),
        squeeze=False,
        sharex=True,
    )
    try:
        fig.patch.set_facecolor(FIG_BG_COLOR)
        axes = axes.flatten()

        # First handle drawn per run; used for the shared figure legend.
        legend_handle_map = {}
        all_x = []
        for ax, spec in zip(axes, active_metric_specs):
            key = metric_keys[spec.metric_name]
            _style_ax(ax, font_tick=font_tick)
            metric_value_groups = []
            metric_handles = []

            for run_spec in run_specs:
                steps, values = _load_series(
                    results,
                    run_spec.run_key,
                    key,
                    max_step,
                    step_stride=spec.step_stride,
                )
                if not steps or not values:
                    print(
                        f"[warn] No valid points for run='{run_spec.run_key}' metric='{key}', skip this series."
                    )
                    continue

                handle = _plot_metric_series(
                    ax=ax,
                    steps=steps,
                    values=values,
                    color=run_spec.color,
                    label=run_spec.label,
                    smooth_window=smooth_window,
                    band_scale=spec.band_scale,
                    lower_clip=spec.ymin,
                    upper_clip=spec.ymax,
                )
                metric_handles.append(handle)
                legend_handle_map.setdefault(run_spec.run_key, handle)
                metric_value_groups.append(values)
                all_x.extend(steps)

            if not metric_handles:
                ax.set_visible(False)
                continue

            ylim = _compute_ylim_multi(
                metric_value_groups,
                floor=spec.ymin,
                ceil=spec.ymax,
            )
            # Accuracy panels always extend exactly to the configured ymax.
            if spec.metric_name == "accuracy" and spec.ymax is not None:
                ylim = (ylim[0], float(spec.ymax))
            ax.set_ylim(ylim)
            ax.set_title(
                spec.title,
                fontsize=font_title,
                fontweight="bold",
                pad=12,
                color=TEXT_COLOR,
            )
            # Factory formatters choose their precision from the final limits.
            ax.yaxis.set_major_formatter(
                _build_axis_formatter(spec.formatter, ylim))

        # Drop grid slots that have no metric assigned.
        for ax in axes[num_metrics:]:
            ax.remove()

        _apply_shared_xlim(axes, all_x, num_metrics, ncols)

        # fig.supxlabel(
        #     "Training Step",
        #     fontsize=font_xlabel,
        #     fontweight="bold",
        #     y=0.066,
        #     color=TEXT_COLOR,
        # )

        legend_handles = [
            legend_handle_map[run_spec.run_key]
            for run_spec in run_specs
            if run_spec.run_key in legend_handle_map
        ]
        if not legend_handles:
            raise ValueError("No valid run series were plotted.")
        _add_figure_legend(fig, legend_handles, font_legend)

        fig.subplots_adjust(
            left=0.060,
            right=0.985,
            top=0.89 if nrows > 1 else 0.82,
            bottom=0.155 if nrows > 1 else 0.235,
            wspace=0.2,
            hspace=0.30,
        )
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        save_format = "pdf" if output_path.lower().endswith(".pdf") else None
        fig.savefig(
            output_path,
            format=save_format,
            dpi=dpi,
            bbox_inches="tight",
            pad_inches=0.18,
            facecolor=fig.get_facecolor(),
            metadata={"Creator": "matplotlib", "Producer": "matplotlib"},
        )
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved figure: {output_path}")


def build_default_paths():
    """Return ``(default_input, default_output)`` as absolute paths.

    Both live under ``../analysis_mar/results`` relative to this file's
    directory. The output goes into a ``plots`` subfolder.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    results_dir = os.path.join(script_dir, "..", "analysis_mar", "results")
    input_file = os.path.join(results_dir, "baseline-length-analysis_results.json")
    output_file = os.path.join(
        results_dir, "plots", "baseline_vs_gspo_length_metrics.pdf"
    )
    return os.path.abspath(input_file), os.path.abspath(output_file)


if __name__ == "__main__":
    default_input, default_output = build_default_paths()
    parser = argparse.ArgumentParser(
        description=(
            "Plot selected runs from one or more analysis JSON files "
            "with plot_length_ngram-like styling."
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        action="append",
        default=None,
        help=(
            "Path to analysis JSON. Can be provided multiple times. "
            "If omitted, use the default single input."
        ),
    )
    parser.add_argument("--output", "-o", type=str,
                        default=default_output, help="Output path (.pdf/.png).")
    parser.add_argument(
        "--run-key",
        type=str,
        action="append",
        default=None,
        help=(
            "Top-level dict key of a run to plot. Can be provided multiple times. "
            "Keys are resolved after merging all input JSONs."
        ),
    )
    parser.add_argument(
        "--run-label",
        type=str,
        action="append",
        default=None,
        help=(
            "Legend label for the corresponding --run-key. "
            "If omitted, use the default style map label or the run key itself."
        ),
    )
    parser.add_argument("--baseline-key", type=str,
                        default="baseline", help="Run key for baseline.")
    parser.add_argument("--gspo-key", type=str,
                        default="gspo_length", help="Run key for gspo length.")
    parser.add_argument("--baseline-label", type=str,
                        default="Baseline", help="Legend label for baseline.")
    parser.add_argument("--gspo-label", type=str,
                        default=r"GSPO + $R_{\mathrm{len}}$", help="Legend label for gspo length.")
    parser.add_argument("--smooth-window", type=int,
                        default=5, help="Smoothing window size.")
    parser.add_argument("--max-step", type=int, default=700,
                        help="Only show steps <= max_step.")
    parser.add_argument(
        "--reasoning-step-stride",
        type=int,
        default=100,
        help=(
            "For reasoning_behavior_* metrics, keep one step every this many steps. "
            "Use 1 to disable the filter."
        ),
    )
    parser.add_argument("--dpi", type=int, default=600, help="Figure DPI.")
    args = parser.parse_args()

    input_paths = args.input if args.input else [default_input]

    plot_baseline_vs_gspo_length(
        input_path=input_paths[0],
        output_path=args.output,
        baseline_key=args.baseline_key,
        gspo_length_key=args.gspo_key,
        baseline_label=args.baseline_label,
        gspo_label=args.gspo_label,
        smooth_window=args.smooth_window,
        max_step=args.max_step,
        reasoning_step_stride=args.reasoning_step_stride,
        dpi=args.dpi,
        input_paths=input_paths,
        run_keys=args.run_key,
        run_labels=args.run_label,
    )
