#!/usr/bin/env python3
from plot_tts import TARGET_LENGTHS, _format_k, _set_plot_style, load_and_aggregate
from plot_capped_response_length import (
    DEFAULT_DISTRIBUTION_BIN_WIDTH,
    DEFAULT_INPUTS,
    DEFAULT_TOKENIZER_PATH,
    FileStats,
    build_binned_distribution,
    compute_file_stats,
    load_tokenizer,
    parse_input_spec,
    smooth_histogram_line,
)
from matplotlib.ticker import FuncFormatter, MaxNLocator
from matplotlib.lines import Line2D
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
from typing import Sequence

import matplotlib

# Select the Agg backend (a non-interactive, file-rendering backend per the
# matplotlib documentation); this script only writes the figure to disk.
matplotlib.use("Agg")


# Default path of the TTS result CSV; used as the default of the --csv option.
CSV_DEFAULT = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/"
    "project_tts_extrapolation/results_dec/"
    "datasource_breakdown_results_optimized.csv"
)

# Shared chrome colors: figure background, axes background, grid lines,
# spines/tick marks, and tick-label text respectively.
FIG_BG_COLOR = "white"
AX_BG_COLOR = "#FAFAFA"
GRID_COLOR = "#D0D0D0"
SPINE_COLOR = "#666666"
TEXT_COLOR = "#333333"

# Data colors: "baseline" variant, GSPO-labelled distribution curves,
# "ours" variant ("+ LINE" in the legend), and the neutral gray used for the
# algorithm (marker/linestyle) legend entries.
COLOR_BASELINE = "#00468B"
COLOR_GSPO_DIST = "#2E8B57"
COLOR_OURS = "#AE1029"
COLOR_NEUTRAL = "#555555"

# Per-series line appearance in the TTS panel, keyed by the value of the
# "series" column. A dash pattern of None leaves the line style untouched;
# otherwise it is applied on top of the named linestyle via set_dashes().
SERIES_MARKERS = {"grpo": "o", "gspo": "s", "grpo_high_clip": "^"}
SERIES_LINESTYLES = {"grpo": "-", "gspo": "--", "grpo_high_clip": "--"}
SERIES_DASHES = {"grpo": None, "gspo": (7, 2.2), "grpo_high_clip": (7, 2.2)}

# Colors cycled (by input index) for distribution labels that match none of
# the keyword rules in _distribution_style.
FALLBACK_DIST_COLORS = ["#6C91BF", "#5A9374", "#7A6E9E"]


def _configure_common_axis(ax, tick_fontsize: float = 14.0) -> None:
    """Apply the look shared by both panels to ``ax`` in place.

    Sets the axes background, enables a light grid drawn below the data,
    configures tick marks/labels, and shows all four spines.

    Args:
        ax: Matplotlib axes to style.
        tick_fontsize: Font size applied to the tick labels of both axes.
    """
    grid_style = {
        "axis": "both",
        "alpha": 0.35,
        "color": GRID_COLOR,
        "linewidth": 0.9,
        "linestyle": "-",
        "zorder": 0,
    }
    tick_style = {
        "axis": "both",
        "labelcolor": TEXT_COLOR,
        "labelsize": tick_fontsize,
        "length": 4.5,
        "width": 1.3,
        "color": SPINE_COLOR,
        "pad": 3.0,
    }

    ax.set_facecolor(AX_BG_COLOR)
    ax.grid(True, **grid_style)
    # Keep grid lines and ticks underneath the plotted artists.
    ax.set_axisbelow(True)
    ax.tick_params(**tick_style)

    # Draw a complete box around the plot area.
    for side in ("left", "bottom", "top", "right"):
        spine = ax.spines[side]
        spine.set_linewidth(1.4)
        spine.set_color(SPINE_COLOR)
        spine.set_visible(True)


def _build_tts_legend_handles() -> list[Line2D]:
    """Create the proxy artists shown in the TTS legend.

    Three neutral-colored handles encode the algorithm (marker + line style),
    two colored handles encode the variant ("Baseline" / "+ LINE"), and one
    invisible handle with an empty label acts as a spacer.

    Returns:
        Six handles ordered as: GRPO, Baseline, GSPO, + LINE,
        GRPO w/Clip-higher, spacer — i.e. algorithm entries interleaved with
        the color entries.
    """

    def _algorithm_handle(label, marker, linestyle, dashes=None):
        # Neutral line with a hollow marker; only marker/line style vary.
        handle = Line2D(
            [0],
            [0],
            label=label,
            marker=marker,
            linestyle=linestyle,
            color=COLOR_NEUTRAL,
            linewidth=2.2,
            markersize=7.2,
            markerfacecolor="white",
            markeredgecolor=COLOR_NEUTRAL,
            markeredgewidth=1.4,
        )
        if dashes is not None:
            handle.set_dashes(dashes)
        return handle

    def _color_handle(label, color):
        # Plain colored line without a marker.
        return Line2D([0], [0], color=color, linewidth=2.3, label=label)

    grpo = _algorithm_handle("GRPO", "o", "-")
    gspo = _algorithm_handle("GSPO", "s", "--", SERIES_DASHES["gspo"])
    clip_higher = _algorithm_handle(
        "GRPO w/Clip-higher", "^", "--", SERIES_DASHES["grpo_high_clip"]
    )
    baseline = _color_handle("Baseline", COLOR_BASELINE)
    ours = _color_handle("+ LINE", COLOR_OURS)
    spacer = Line2D([0], [0], color="none", linewidth=0.0, label="")

    return [grpo, baseline, gspo, ours, clip_higher, spacer]


def draw_tts_legend(ax, legend_anchor=(0.45, -0.20)) -> None:
    """Render the TTS legend inside an otherwise empty axes.

    Args:
        ax: Axes that hosts the legend; its own axis decorations are turned
            off.
        legend_anchor: ``bbox_to_anchor`` passed to ``ax.legend`` (used with
            ``loc="lower center"``).
    """
    ax.axis("off")

    handles = _build_tts_legend_handles()
    frame_options = {
        "frameon": True,
        "fancybox": True,
        "framealpha": 0.96,
        "edgecolor": "#D4D4D4",
        "facecolor": "white",
    }
    spacing_options = {
        "columnspacing": 0.9,
        "handlelength": 1.9,
        "handletextpad": 0.45,
        "borderpad": 0.32,
        "labelspacing": 0.35,
    }
    ax.legend(
        handles=handles,
        loc="lower center",
        bbox_to_anchor=legend_anchor,
        ncol=3,
        prop={"weight": "bold", "size": 14.0},
        **frame_options,
        **spacing_options,
    )


def draw_tts_panel(ax, agg) -> None:
    """Plot accuracy against budget for every series/variant combination.

    One line is drawn per (series, variant) pair at evenly spaced x positions,
    one position per entry of ``TARGET_LENGTHS``. Pairs with no data at any
    target length are skipped.

    Args:
        ax: Axes to draw on.
        agg: Table indexed by column name with the columns ``series``,
            ``variant``, ``Truncation_Length`` and ``mean_acc``
            (presumably a pandas DataFrame — confirm against
            ``load_and_aggregate``).
    """
    _configure_common_axis(ax)

    n_budgets = len(TARGET_LENGTHS)
    positions = list(range(n_budgets))
    tick_labels = [_format_k(length) for length in TARGET_LENGTHS]

    # Faint gray band from x=1.0 to half a step past the last position.
    if n_budgets >= 2:
        ax.axvspan(
            1.0,
            float(n_budgets - 0.5),
            facecolor="#D9D9D9",
            alpha=0.10,
            zorder=0,
        )

    variant_colors = (("baseline", COLOR_BASELINE), ("ours", COLOR_OURS))
    for series_key in ("grpo", "gspo", "grpo_high_clip"):
        for variant, color in variant_colors:
            mask = (agg["series"] == series_key) & (agg["variant"] == variant)
            by_length = agg[mask].set_index("Truncation_Length")

            # Missing target lengths become NaN so the line shows a gap there.
            accuracies = []
            for length in TARGET_LENGTHS:
                if length in by_length.index:
                    accuracies.append(by_length.loc[length, "mean_acc"])
                else:
                    accuracies.append(float("nan"))
            if all(np.isnan(accuracies)):
                continue

            (curve,) = ax.plot(
                positions,
                accuracies,
                color=color,
                marker=SERIES_MARKERS[series_key],
                linestyle=SERIES_LINESTYLES[series_key],
                linewidth=2.4,
                markersize=6.2,
                markerfacecolor="white",
                markeredgecolor=color,
                markeredgewidth=1.2,
                zorder=4,
            )
            dash_pattern = SERIES_DASHES[series_key]
            if dash_pattern:
                curve.set_dashes(dash_pattern)

    ax.set_xticks(positions, tick_labels)
    ax.set_xlabel("Budget", fontsize=16, fontweight="bold", labelpad=8)
    ax.set_ylabel("Accuracy", fontsize=16, fontweight="bold", labelpad=8)
    ax.set_xlim(-0.15, float(n_budgets - 1) + 0.18)

    # Y range follows every non-NaN mean_acc in agg, with a small margin; the
    # upper bound never drops below 0.55 and the lower bound never below 0.
    observed = np.asarray(agg["mean_acc"], dtype=float)
    observed = observed[~np.isnan(observed)]
    if not observed.size:
        y_low, y_high = 0.0, 1.0
    else:
        y_low = max(0.0, float(observed.min()) - 0.01)
        y_high = max(0.55, float(observed.max()) + 0.008)
    ax.set_ylim(y_low, y_high)


def _distribution_style(label: str, idx: int) -> dict[str, object]:
    """Choose the line style for a distribution curve from its label.

    Matching is case-insensitive and checked in this order:
    labels containing "qwen" or "base" use the baseline color; labels
    containing "gspo" together with "line", "skip-right" or "ours" use the
    "ours" color; other labels containing "gspo" use the GSPO color; anything
    else cycles through ``FALLBACK_DIST_COLORS`` by ``idx``.

    Args:
        label: Display label of the input.
        idx: Position of the input, used only for the fallback color.

    Returns:
        Dict with the keys ``color``, ``linestyle`` and ``dashes``.
    """
    name = label.lower()

    if "qwen" in name or "base" in name:
        color = COLOR_BASELINE
    elif "gspo" in name:
        is_ours = any(token in name for token in ("line", "skip-right", "ours"))
        color = COLOR_OURS if is_ours else COLOR_GSPO_DIST
    else:
        color = FALLBACK_DIST_COLORS[idx % len(FALLBACK_DIST_COLORS)]

    return {"color": color, "linestyle": "-", "dashes": None}


def draw_distribution_panel(
    ax,
    results: Sequence[FileStats],
    distribution_budget: int,
    bin_width: int,
    plot_xmax: int,
    legend_anchor=(0.985, 0.985),
) -> None:
    """Draw the binned response-length distribution of every input.

    For each result, the length counts recorded under ``distribution_budget``
    are binned, shown as faint bars and overlaid with a smoothed line. Shares
    are expressed as a percentage of ``result.total_cases``.

    Args:
        ax: Axes to draw on.
        results: Per-input statistics; each item must provide
            ``length_frequencies``, ``total_cases`` and ``label``.
        distribution_budget: Budget key used to look up the length counts.
        bin_width: Histogram bin width passed to the binning helper and used
            for the bar width.
        plot_xmax: Right limit of the x axis; bins starting at or beyond it
            are dropped.
        legend_anchor: ``bbox_to_anchor`` for the legend (``loc="upper
            right"``).

    Raises:
        ValueError: If no input yields any bin to plot.
    """

    def _format_thousands(value, _pos):
        # Tick label: "0" at the origin, otherwise thousands such as "4k".
        if abs(value) < 1e-9:
            return "0"
        return f"{value / 1000.0:.0f}k"

    _configure_common_axis(ax)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=6))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.xaxis.set_major_formatter(FuncFormatter(_format_thousands))

    tallest = 0.0
    drew_something = False

    for idx, result in enumerate(results):
        length2count = result.length_frequencies.get(distribution_budget, {})
        if not length2count:
            continue

        bins = build_binned_distribution(
            length2count,
            budget=distribution_budget,
            bin_width=bin_width,
        )
        # Drop the final bin when its end reaches the budget.
        if bins and bins[-1][1] >= distribution_budget:
            bins = bins[:-1]
        visible = [(lo, hi, n) for lo, hi, n in bins if lo < plot_xmax]
        if not visible:
            continue

        drew_something = True
        style = _distribution_style(result.label, idx)

        denominator = max(result.total_cases, 1)  # guard against zero cases
        left_edges = np.asarray([lo for lo, _, _ in visible], dtype=float)
        shares = np.asarray(
            [n / denominator * 100.0 for _, _, n in visible],
            dtype=float,
        )
        centers = np.asarray(
            [(lo + hi) / 2.0 for lo, hi, _ in visible],
            dtype=float,
        )
        smooth_x, smooth_y = smooth_histogram_line(centers, shares)

        # Track the tallest bar or curve point for the y-axis limit.
        share_peak = float(np.max(shares)) if shares.size else 0.0
        smooth_peak = float(np.max(smooth_y)) if smooth_y.size else 0.0
        tallest = max(tallest, share_peak, smooth_peak)

        ax.bar(
            left_edges,
            shares,
            width=bin_width * 0.90,
            align="edge",
            color=style["color"],
            alpha=0.10,
            edgecolor="none",
            zorder=1,
        )
        (curve,) = ax.plot(
            smooth_x,
            smooth_y,
            color=style["color"],
            linestyle=style["linestyle"],
            linewidth=2.6,
            alpha=0.98,
            label=result.label,
            zorder=3,
            solid_joinstyle="round",
            solid_capstyle="round",
        )
        if style["dashes"]:
            curve.set_dashes(style["dashes"])

    if not drew_something:
        raise ValueError(
            "No valid distribution data was found for the requested budget.")

    ax.set_xlim(0, plot_xmax)
    # 12% headroom above the tallest element, with a minimum range of 1%.
    ax.set_ylim(0, max(1.0, tallest * 1.12))
    ax.set_xlabel("Length", fontsize=16, fontweight="bold", labelpad=8)
    ax.set_ylabel(
        "Percentage of Samples (%)",
        fontsize=16,
        fontweight="bold",
        labelpad=8,
    )

    frame_options = {
        "frameon": True,
        "fancybox": True,
        "framealpha": 0.96,
        "edgecolor": "#D4D4D4",
        "facecolor": "white",
    }
    spacing_options = {
        "columnspacing": 0.95,
        "handlelength": 1.9,
        "handletextpad": 0.45,
        "borderpad": 0.32,
        "labelspacing": 0.35,
    }
    ax.legend(
        loc="upper right",
        bbox_to_anchor=legend_anchor,
        ncol=1,
        prop={"weight": "bold", "size": 14.0},
        **frame_options,
        **spacing_options,
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the combined figure script.

    Returns:
        Parser exposing the data sources (``--csv``, repeatable ``--input``,
        ``--tokenizer``), distribution settings, figure size, optional titles
        and the output path (``--out``).
    """
    default_out = str(
        Path(__file__).resolve().parent / "plots" /
        "test_time_scaling_combined.pdf"
    )

    parser = argparse.ArgumentParser(
        description="Generate the combined test-time scaling figure in one script."
    )

    # Data sources.
    parser.add_argument("--csv", default=CSV_DEFAULT,
                        help="TTS result CSV path.")
    parser.add_argument(
        "--input",
        dest="inputs",
        action="append",
        help="Distribution input spec as PATH::LABEL. Can be passed multiple times.",
    )
    parser.add_argument(
        "--tokenizer",
        default=DEFAULT_TOKENIZER_PATH,
        help=f"Tokenizer path or model id. Default: {DEFAULT_TOKENIZER_PATH}",
    )

    # Integer-valued distribution and tokenization settings.
    integer_options = (
        (
            "--distribution-budget",
            32768,
            "Budget used for the capped-length distribution panel.",
        ),
        (
            "--distribution-bin-width",
            DEFAULT_DISTRIBUTION_BIN_WIDTH,
            f"Histogram bin width. Default: {DEFAULT_DISTRIBUTION_BIN_WIDTH}",
        ),
        (
            "--plot-xmax",
            10000,
            "Maximum x-axis value shown in the right panel.",
        ),
        (
            "--batch-size",
            128,
            "Batch size for tokenizer inference.",
        ),
    )
    for flag, default, text in integer_options:
        parser.add_argument(flag, type=int, default=default, help=text)

    # Figure size in inches.
    size_options = (
        ("--fig-width", 11.2, "Figure width in inches."),
        ("--fig-height", 4.7, "Figure height in inches."),
    )
    for flag, default, text in size_options:
        parser.add_argument(flag, type=float, default=default, help=text)

    # Optional titles; omitted from the figure when not given.
    title_options = (
        ("--title", "Optional figure title placed above both panels."),
        ("--left-title", "Optional title for the left panel."),
        ("--right-title", "Optional title for the right panel."),
    )
    for flag, text in title_options:
        parser.add_argument(flag, default=None, help=text)

    parser.add_argument(
        "--out",
        default=default_out,
        help="Output path for the combined figure.",
    )
    return parser


def _load_distribution_results(
    input_specs: Sequence[str],
    tokenizer_path: str,
    distribution_budget: int,
    batch_size: int,
) -> list[FileStats]:
    """Compute file statistics for every distribution input.

    The tokenizer is loaded once and reused for all inputs. Progress is
    printed for the tokenizer load and for each processed input.

    Args:
        input_specs: Input specs understood by ``parse_input_spec``
            (documented on the CLI as ``PATH::LABEL``).
        tokenizer_path: Value handed to ``load_tokenizer``.
        distribution_budget: The single budget passed to
            ``compute_file_stats``.
        batch_size: Batch size forwarded to ``compute_file_stats``.

    Returns:
        One ``FileStats`` per spec, in input order.
    """
    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = load_tokenizer(tokenizer_path)

    def _stats_for(spec):
        # Parse one spec, announce it, and compute its statistics.
        path, label = parse_input_spec(spec)
        print(f"Processing {label}: {path}")
        return compute_file_stats(
            path=path,
            label=label,
            tokenizer=tokenizer,
            budgets=[distribution_budget],
            batch_size=batch_size,
        )

    return [_stats_for(spec) for spec in input_specs]


def main() -> None:
    """Parse arguments, build the two-panel figure, and save it to disk.

    The left column holds the TTS accuracy panel with its legend in a
    separate axes below it; the right column holds the length-distribution
    panel. The output directory is created if needed.
    """
    args = build_parser().parse_args()
    _set_plot_style()

    # Fall back to the built-in inputs when no --input was given.
    input_specs = args.inputs
    if not input_specs:
        input_specs = [f"{path}::{label}" for path, label in DEFAULT_INPUTS]

    print(f"Loading TTS aggregate CSV: {args.csv}")
    agg = load_and_aggregate(args.csv)
    distribution_results = _load_distribution_results(
        input_specs=input_specs,
        tokenizer_path=args.tokenizer,
        distribution_budget=args.distribution_budget,
        batch_size=args.batch_size,
    )

    fig = plt.figure(figsize=(args.fig_width, args.fig_height))
    fig.patch.set_facecolor(FIG_BG_COLOR)

    # Outer 1x2 layout; the left cell is split into plot / gap / legend rows.
    outer = fig.add_gridspec(1, 2, width_ratios=[1.02, 0.98], wspace=0.22)
    left_rows = outer[0, 0].subgridspec(
        3, 1, height_ratios=[1.0, 0.10, 0.32], hspace=0.0)

    ax_left = fig.add_subplot(left_rows[0, 0])
    ax_left_gap = fig.add_subplot(left_rows[1, 0])
    ax_left_legend = fig.add_subplot(left_rows[2, 0])
    ax_right = fig.add_subplot(outer[0, 1])

    draw_tts_panel(ax_left, agg)
    ax_left_gap.axis("off")
    draw_tts_legend(ax_left_legend)
    draw_distribution_panel(
        ax_right,
        distribution_results,
        distribution_budget=args.distribution_budget,
        bin_width=args.distribution_bin_width,
        # Never show x values beyond the budget itself.
        plot_xmax=min(args.plot_xmax, args.distribution_budget),
    )

    panel_title_style = {"fontsize": 16, "fontweight": "bold", "pad": 18}
    if args.left_title:
        ax_left.set_title(args.left_title, **panel_title_style)
    if args.right_title:
        ax_right.set_title(args.right_title, **panel_title_style)

    # Leave extra room at the top when a figure-level title is present.
    top = 0.96
    if args.title:
        fig.suptitle(args.title, fontsize=18, fontweight="bold", y=0.98)
        top = 0.93

    plt.subplots_adjust(left=0.07, right=0.99,
                        bottom=0.10, top=top, wspace=0.22)

    # Shrink the right panel to 88% of its height, keeping its top edge fixed.
    box = ax_right.get_position()
    shrunk_height = box.height * 0.88
    ax_right.set_position(
        [box.x0, box.y1 - shrunk_height, box.width, shrunk_height])

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    save_options = {
        "bbox_inches": "tight",
        "metadata": {"Creator": "matplotlib"},
    }
    if out_path.suffix.lower() == ".pdf":
        save_options.update(format="pdf", dpi=600)
    else:
        save_options["dpi"] = 300
    fig.savefig(out_path, **save_options)
    plt.close(fig)
    print(f"[OK] saved to: {out_path}")


if __name__ == "__main__":
    main()
