from matplotlib.lines import Line2D
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import os

import matplotlib

# Select the non-interactive Agg backend: this script only writes figures to
# files via savefig and never opens a window.
matplotlib.use("Agg")


# Datasources whose Accuracy values are averaged into each plotted point.
TARGET_DATASOURCES = ["aime", "aime25", "olympiad_bench", "math", "amc"]
# Truncation lengths (token budgets) kept from the CSV and shown on the x-axis,
# in plotting order.
TARGET_LENGTHS = [4096, 8192, 12288, 16384, 24576, 32768]


# Display metadata per algorithm series; keys match the series keys returned by
# _classify_model. A "stage2" entry (title "Stage 2") is currently disabled.
SERIES = {
    "grpo": {"title": "GRPO"},
    "gspo": {"title": "GSPO"},
    "grpo_high_clip": {"title": "GRPO w/Clip-higher"},
}


def _set_plot_style():
    """
    Apply the "paper style" matplotlib settings used by this script.

    Matches the paper-style output of plot_length_ngram.py. This mutates the
    global ``plt.rcParams``, so the settings also apply to every figure
    created afterwards in the same process.
    """
    plt.rcParams.update({
        # Embed fonts as TrueType (Type 42) instead of Type 3 in PDF/PS output,
        # and do not fall back to the 14 built-in PDF core fonts.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pdf.use14corefonts": False,
        # CJK-capable fonts come first as fallbacks so Chinese titles/annotations
        # do not render as empty boxes (PingFang SC is preferred, for macOS).
        "font.sans-serif": [
            "PingFang SC",
            "Heiti SC",
            "Songti SC",
            "SimHei",
            "Arial",
            "DejaVu Sans",
        ],
        "font.family": "sans-serif",
        # Draw minus signs as ASCII hyphens; the fonts above may lack U+2212.
        "axes.unicode_minus": False,
        "figure.facecolor": "white",
        "axes.facecolor": "#FAFAFA",
        "savefig.facecolor": "white",
        "axes.linewidth": 1.8,
    })


def _format_k(x: int) -> str:
    return f"{int(x/1024)}k"


def _classify_model(model: str) -> tuple[str, str] | None:
    """
    把 CSV 的 Model 字符串归类到 (series_key, line_key)：
    - series_key: grpo / gspo / grpo_high_clip / stage2
    - line_key: baseline / ours
    设计目标：兼容带 step 的名字（如 "GRPO (Step 430)"）或不带 step（如 "GRPO"）。
    """
    if not isinstance(model, str):
        return None
    m = model.strip()
    if not m:
        return None

    # variant
    line_key = "ours" if "+ ours" in m else "baseline"

    # series
    if "GRPO" in m:
        if "High clip ratio" in m:
            return ("grpo_high_clip", line_key)
        return ("grpo", line_key)
    if "GSPO" in m:
        # 检查是否是 Stage 2
        if "Stage 2" in m or "stage2" in m.lower():
            return ("stage2", line_key)
        return ("gspo", line_key)
    return None


def load_and_aggregate(csv_path: str) -> pd.DataFrame:
    """
    Load the per-datasource results CSV and aggregate accuracy per curve point.

    Returns a DataFrame with columns
    ``[series, variant, Truncation_Length, mean_acc]``. ``mean_acc`` is the
    unweighted mean, over the TARGET_DATASOURCES present, of each
    datasource's mean Accuracy. Only rows whose Truncation_Length is in
    TARGET_LENGTHS are kept. Rows whose Model cannot be classified by
    ``_classify_model`` are dropped, as are rows with a missing or
    non-numeric Truncation_Length / Accuracy.

    Raises:
        ValueError: if the CSV lacks any of the required columns.
    """
    df = pd.read_csv(csv_path)
    required_cols = {"Model", "Truncation_Length", "Datasource", "Accuracy"}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"CSV 缺少列: {sorted(missing)}；实际列: {list(df.columns)}")

    # Coerce numeric columns; unparsable values become NaN and are dropped.
    df = df.assign(
        Truncation_Length=pd.to_numeric(df["Truncation_Length"], errors="coerce"),
        Accuracy=pd.to_numeric(df["Accuracy"], errors="coerce"),
    )
    df = df.dropna(subset=["Truncation_Length", "Accuracy", "Model", "Datasource"])
    df = df[
        df["Datasource"].isin(TARGET_DATASOURCES)
        & df["Truncation_Length"].isin(TARGET_LENGTHS)
    ]

    # Classify free-form model names into series/variant so step numbers or
    # minor naming differences do not break matching. `assign` builds a new
    # frame, avoiding writes into a filtered slice (SettingWithCopyWarning).
    cls = df["Model"].map(_classify_model)
    df = df.assign(
        series=cls.map(lambda x: x[0] if x else None),
        variant=cls.map(lambda x: x[1] if x else None),
    )
    df = df.dropna(subset=["series", "variant"])

    keys = ["series", "variant", "Truncation_Length"]
    # Average within each datasource first, then across datasources, so a
    # datasource contributing more rows (e.g. several checkpoints mapped to the
    # same series) does not dominate the mean. With one row per datasource this
    # equals the plain row mean.
    per_source = df.groupby(keys + ["Datasource"], as_index=False)["Accuracy"].mean()
    agg = (
        per_source.groupby(keys, as_index=False)["Accuracy"]
        .mean()
        .rename(columns={"Accuracy": "mean_acc"})
    )
    return agg


def plot_single_panel(agg: pd.DataFrame, out_path: str, title: str | None = None):
    """
    Plot every (algorithm, variant) accuracy curve on one axes and save it.

    Colour encodes the variant (baseline vs. "+ ours"); marker and dash
    pattern encode the algorithm. Two manually built legends below the axes
    explain the two encodings, and one rounded frame is drawn around both.

    Args:
        agg: Aggregated results with columns ``series``, ``variant``,
            ``Truncation_Length`` and ``mean_acc`` (as produced by
            ``load_and_aggregate``). Only the grpo / gspo / grpo_high_clip
            series and the baseline / ours variants are drawn; other rows are
            ignored, including for the y-axis range.
        out_path: Output file. A ``.pdf`` suffix (case-insensitive) is saved
            as PDF at 600 dpi; anything else is saved at 300 dpi. Missing
            parent directories are created.
        title: Optional figure-level title.
    """
    _set_plot_style()

    # 1. Canvas
    fig, ax = plt.subplots(1, 1, figsize=(8.8, 8.0))
    ax.set_facecolor("#FAFAFA")

    # 2. Encodings: colour per variant; marker / line style / dashes per algorithm.
    colors = {"baseline": "#00468B", "ours": "#AE1029"}
    series_order = ["grpo", "gspo", "grpo_high_clip"]
    series_markers = {"grpo": "o", "gspo": "s", "grpo_high_clip": "^"}
    series_linestyles = {"grpo": "-", "gspo": "--", "grpo_high_clip": "--"}
    # Explicit (on, off) dash lengths override the default "--" pattern;
    # None leaves the line as drawn.
    series_dashes = {"grpo": None, "gspo": (10, 3), "grpo_high_clip": (10, 3)}

    # Lengths sit at evenly spaced categorical positions, not to scale.
    x_pos = list(range(len(TARGET_LENGTHS)))
    x_labels = [_format_k(x) for x in TARGET_LENGTHS]

    # 3. Background band / grid / ticks / spines
    if len(x_pos) >= 2:
        # Light band starting at the second budget; its right end lies past the
        # right x-limit, so it fills to the edge.
        ax.axvspan(1.0, float(len(x_pos) - 0.5),
                   facecolor="#D9D9D9", alpha=0.18, zorder=0)

    ax.grid(True, axis="both", alpha=0.3, color="#CCCCCC",
            linewidth=0.8, linestyle="-", zorder=0)
    ax.set_axisbelow(True)

    ax.tick_params(axis="both", labelcolor="#333333",
                   labelsize=16, length=6, width=1.5)

    for side in ["left", "bottom", "top", "right"]:
        ax.spines[side].set_linewidth(1.8)
        ax.spines[side].set_color("#666666")
        ax.spines[side].set_visible(True)

    # 4. Curves. No labels are set here: both legends are built by hand below.
    for series_key in series_order:
        for variant in ("baseline", "ours"):
            sub = agg[(agg["series"] == series_key) &
                      (agg["variant"] == variant)].set_index("Truncation_Length")
            # NaN for a missing length leaves a gap instead of shifting points.
            y = [sub.loc[L, "mean_acc"] if L in sub.index else float("nan")
                 for L in TARGET_LENGTHS]
            if all(pd.isna(v) for v in y):
                continue

            (line,) = ax.plot(
                x_pos, y,
                color=colors[variant],
                marker=series_markers[series_key],
                linestyle=series_linestyles[series_key],
                linewidth=3.0,
                markersize=8,
                markerfacecolor="white",
                markeredgecolor=colors[variant],
                markeredgewidth=1.8,
                zorder=4,
            )
            if series_dashes[series_key]:
                line.set_dashes(series_dashes[series_key])

    # 5. Axes
    ax.set_xticks(x_pos, x_labels)
    ax.set_xlabel("Budget", fontsize=18, fontweight="bold", labelpad=10)
    ax.set_ylabel("Accuracy", fontsize=18, fontweight="bold", labelpad=10)
    ax.set_xlim(-0.15, float(len(x_pos) - 1) + 0.25)

    # The y-range is derived only from the curves that are actually drawn.
    # The top defaults to 0.55 but is raised when the data exceed it, so no
    # curve is clipped.
    y_top_default = 0.55
    plotted = agg[agg["series"].isin(series_order) &
                  agg["variant"].isin(list(colors))]
    vals = pd.to_numeric(plotted["mean_acc"], errors="coerce").dropna()
    if vals.empty:
        y_lo, y_hi = 0.0, y_top_default
    else:
        mn, mx = vals.min(), vals.max()
        span = max(mx - mn, 0.04)
        pad = max(0.005, span * 0.03)
        y_lo, y_hi = max(0.0, mn - pad), min(1.0, mx + pad)
    ax.set_ylim(y_lo, max(y_hi, y_top_default))

    # =====================================================
    # Legend row 1: algorithms (marker + line style), in the baseline colour.
    # =====================================================
    alg_handles = []
    for series_key in series_order:
        handle = Line2D([0], [0], color=colors["baseline"], linewidth=3.0,
                        marker=series_markers[series_key],
                        linestyle=series_linestyles[series_key], markersize=9,
                        markerfacecolor="white",
                        markeredgecolor=colors["baseline"], markeredgewidth=1.8,
                        label=SERIES[series_key]["title"])
        if series_dashes[series_key]:
            handle.set_dashes(series_dashes[series_key])
        alg_handles.append(handle)

    leg1 = ax.legend(
        handles=alg_handles,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.35),
        ncol=3,
        frameon=False,
        framealpha=0.95,
        edgecolor="#888888",
        fancybox=True,
        prop={"weight": "bold", "size": 18},
        columnspacing=1.2,
        handletextpad=0.6,
    )
    # A second ax.legend() call replaces the first, so keep leg1 as an artist.
    ax.add_artist(leg1)

    # =====================================================
    # Legend row 2: colour meaning (Baseline / + LINE).
    # =====================================================
    style_handles = [
        Line2D([0], [0], color=colors["baseline"],
               linewidth=3.0, label="Baseline"),
        Line2D([0], [0], color=colors["ours"], linewidth=3.0, label="+ LINE"),
    ]

    leg2 = ax.legend(
        handles=style_handles,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.45),
        ncol=2,
        frameon=False,
        framealpha=0.95,
        edgecolor="#888888",
        fancybox=True,
        prop={"weight": "bold", "size": 18},
        columnspacing=1.5,
        handletextpad=0.8,
    )

    # 6. Title & layout
    if title:
        fig.suptitle(title, fontsize=28, fontweight="bold", y=0.96)

    plt.subplots_adjust(top=0.82, bottom=0.30, left=0.12, right=0.98)

    # =====================================================
    # One frame around both legends (measured after the layout is final).
    # =====================================================
    from matplotlib.patches import FancyBboxPatch

    # Draw once so the legends have a realised layout before measuring them.
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()

    # Legend extents come back in display (pixel) coordinates; convert them to
    # figure-fraction coordinates (0-1).
    to_fig = fig.transFigure.inverted()
    bbox1_fig = leg1.get_window_extent(renderer).transformed(to_fig)
    bbox2_fig = leg2.get_window_extent(renderer).transformed(to_fig)

    # Negative padding shrinks the union box slightly; the boxstyle's own pad
    # then adds margin around it, giving a snug frame.
    padding = -0.005
    frame_x0 = min(bbox1_fig.x0, bbox2_fig.x0) - padding
    frame_y0 = min(bbox1_fig.y0, bbox2_fig.y0) - padding
    frame_x1 = max(bbox1_fig.x1, bbox2_fig.x1) + padding
    frame_y1 = max(bbox1_fig.y1, bbox2_fig.y1) + padding

    # Outline only (no fill) so the legend contents are not covered.
    unified_frame = FancyBboxPatch(
        (frame_x0, frame_y0), frame_x1 - frame_x0, frame_y1 - frame_y0,
        boxstyle="round,pad=0.01",
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="#888888",
        linewidth=1.5,
        alpha=1.0,
        zorder=0,
        clip_on=False,
    )
    # add_artist registers the patch with the figure (sets its figure, marks
    # the figure stale) instead of mutating fig.patches directly.
    fig.add_artist(unified_frame)

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    save_kwargs = dict(bbox_inches="tight", metadata={"Creator": "matplotlib"})
    if out_path.lower().endswith(".pdf"):
        fig.savefig(out_path, format="pdf", dpi=600, **save_kwargs)
    else:
        fig.savefig(out_path, dpi=300, **save_kwargs)

    plt.close(fig)


def main():
    """
    Command-line entry point: aggregate the results CSV and save the plot.

    Reads ``--csv`` with load_and_aggregate, draws the single-panel figure
    with plot_single_panel into ``--out`` (optionally titled with
    ``--title``), and prints the output path.
    """
    parser = argparse.ArgumentParser(
        description="画 TTS/Truncation 的单图 6 条线（baseline/ours 两色；算法用同一 marker 区分）"
    )
    parser.add_argument(
        "--csv",
        default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/datasource_breakdown_results_optimized.csv",
        help="输入 CSV 路径（需包含 Model/Truncation_Length/Datasource/Accuracy 列）",
    )
    # The figure is a single panel; a .pdf suffix is saved as PDF, other
    # suffixes as raster images.
    parser.add_argument("--out", default="plots/tts.pdf",
                        help="单图的输出路径（.pdf 或其他图片格式）")
    parser.add_argument("--title", default=None, help="整张图的总标题（可选）")
    args = parser.parse_args()

    agg = load_and_aggregate(args.csv)
    plot_single_panel(agg, args.out, title=args.title)
    print(f"[OK] saved to: {args.out}")


if __name__ == "__main__":
    main()
