
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import pandas as pd
import numpy as np


@dataclass(frozen=True)
class RunSpec:
    """Immutable description of one wandb run that main() downloads and plots."""

    group: str  # group id; main() draws all runs sharing a group in the same panels
    label: str  # legend label; also used to build the exported CSV file name
    run_path: str  # entity/project/run_id
    reward_key: str  # extra history column that main() fetches and exports for this run


# Runs that main() downloads. Edit reward_key here directly if needed
# (e.g. when the "ours" run does not correspond to add1k).
# The `group` strings are compared verbatim against the group ids that main()
# iterates over, so keep the two in sync.
DEFAULT_RUNS: List[RunSpec] = [
    # Group 1
    RunSpec(
        group="组1",
        label="GSPO",
        run_path="astrid_tuning_llm/verl-qwen3-4b-oct/s66rgswr",
        reward_key="critic/score/mean",
    ),
    RunSpec(
        group="组1",
        label="GSPO (Ours)",
        run_path="astrid_tuning_llm/verl-qwen3-4b-oct/5xobvm33",
        reward_key="critic/acc/mean",
    ),
    # Group 2
    RunSpec(
        group="组2",
        label="GSPO",
        run_path="astrid_tuning_llm/verl-qwen3-4b-oct/ljjuqwxz",
        reward_key="critic/score/mean",
    ),
    RunSpec(
        group="组2",
        label="GSPO (Ours)",
        run_path="astrid_tuning_llm/verl-qwen3-4b-oct/kwhhaomr",
        reward_key="critic/acc/mean",
    ),
]

# History columns that main() requests for every run, in addition to each
# run's reward_key and the candidate step columns.
DEFAULT_METRICS = [
    "avg_score/16384/four_sets",
    "response_length/mean",
]


def _ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents.

    An already existing directory is accepted silently.
    """
    p.mkdir(exist_ok=True, parents=True)


def _pick_x_key(df: pd.DataFrame, prefer: Optional[str] = None) -> str:
    """Return the name of the column to use as the x axis (training step).

    Args:
        df: Run history; only its column names are inspected.
        prefer: Optional column name that is tried before the built-in ones.
            A falsy value (``None`` or ``""``) is ignored.

    Returns:
        The first candidate that exists in ``df.columns``.

    Raises:
        ValueError: If none of the candidates is a column of ``df``.
    """
    # Step fields commonly seen in wandb histories (they differ between code bases).
    builtin_names = (
        "trainer/global_step",
        "global_step",
        "_step",
        "step",
        "train/global_step",
    )
    ordered = ((prefer,) if prefer else ()) + builtin_names
    chosen = next((name for name in ordered if name in df.columns), None)
    if chosen is None:
        raise ValueError(f"找不到可用的 step 列，现有列：{list(df.columns)[:50]}")
    return chosen


def _safe_numeric(df: pd.DataFrame, cols: Iterable[str]) -> pd.DataFrame:
    """Convert the listed columns of ``df`` to numeric dtype in place.

    Values that cannot be parsed become NaN (``errors="coerce"``). Names in
    ``cols`` that are not columns of ``df`` are skipped.

    Returns:
        The same ``df`` object that was passed in (it is mutated, not copied).
    """
    for name in cols:
        if name not in df.columns:
            continue
        converted = pd.to_numeric(df[name], errors="coerce")
        df[name] = converted
    return df


def fetch_run_history(
    api,  # wandb.Api
    run_path: str,
    keys: List[str],
) -> pd.DataFrame:
    """Download the history of one run and keep only the requested columns.

    The history is read with ``scan_history`` to get as many rows as possible
    (per the original author's note, ``history()`` may be sampled/truncated).

    Args:
        api: Object whose ``run(run_path)`` returns a run exposing
            ``scan_history(page_size=...)``.
        run_path: Run identifier passed to ``api.run``.
        keys: Column names to keep.

    Returns:
        * an empty frame with ``keys`` as columns when no row was returned;
        * a copy restricted to the ``keys`` that exist, in ``keys`` order;
        * the complete frame when none of ``keys`` exists.
    """
    run = api.run(run_path)
    # Important: per the original author's note, passing a non-existent key via
    # scan_history(keys=...) can in some cases yield no rows at all. So nothing is
    # passed here; everything is scanned and the columns are filtered afterwards.
    records = [
        dict(entry)
        for entry in run.scan_history(page_size=1000)
        if entry is not None
    ]
    if not records:
        return pd.DataFrame(columns=keys)
    frame = pd.DataFrame(records)
    wanted = [name for name in keys if name in frame.columns]
    return frame[wanted].copy() if wanted else frame


def maybe_smooth(series: pd.Series, window: int) -> pd.Series:
    """Return the rolling mean of ``series``, or ``series`` itself if ``window <= 1``.

    The rolling mean needs at least ``max(1, window // 3)`` observations, so the
    first few points of the result can be NaN for larger windows.
    """
    if window > 1:
        min_obs = max(1, window // 3)
        rolled = series.rolling(window=window, min_periods=min_obs)
        return rolled.mean()
    return series


def thousands_formatter(x, pos):
    """Format a tick value in thousands with one decimal, e.g. ``1500 -> '1.5k'``.

    Every value is rendered with the ``k`` suffix (``0 -> '0.0k'``). ``pos`` is
    accepted for the matplotlib ``FuncFormatter`` call signature and is not used.
    """
    in_thousands = x / 1000
    return "{:.1f}k".format(in_thousands)


# Plot styling shared by every panel.
_COLOR_BASELINE = "#00468B"  # dark blue, baseline runs
_COLOR_OURS = "#AE1029"  # crimson, runs whose label contains "(Ours)"
_FONT_TITLE = 32
_FONT_LABEL = 24
_FONT_TICK = 20
_FONT_LEGEND = 28

# Step columns requested from every run so that _pick_x_key can choose among them.
_STEP_KEYS = (
    "_step",
    "trainer/global_step",
    "global_step",
    "step",
    "train/global_step",
)

# (group id as used in DEFAULT_RUNS, title drawn above that group's panels).
_GROUPS = (
    ("组1", "Qwen3-4B"),
    ("组2", "Llama-OctoThinker-3B"),
)

# (history column, panel title); every group gets one panel per entry, in this order.
_PANEL_METRICS = (
    ("response_length/mean", r"$L$"),
    ("avg_score/16384/four_sets", "ACC"),
)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--out-dir",
        type=str,
        default="_wandb_plots",
        help="输出目录（会创建一个时间戳子目录）",
    )
    parser.add_argument(
        "--x-key",
        type=str,
        default="",
        help="强制使用的 x 轴列名（默认自动探测，如 trainer/global_step 或 _step）",
    )
    parser.add_argument(
        "--smooth",
        type=int,
        default=5,
        help="滑动平均窗口（1 表示不平滑）",
    )
    parser.add_argument(
        "--max-points",
        type=int,
        default=0,
        help="每条曲线最多保留多少点（0 表示不截断；>0 会等距采样，便于画图更快）",
    )
    return parser.parse_args()


def _history_keys(extra_x_key: str = "") -> List[str]:
    """Return the sorted, de-duplicated list of history columns to keep.

    It contains the fixed metrics, every run's reward_key, the candidate step
    columns and, when given, the user-requested x-axis column. Without the
    latter a custom ``--x-key`` would be filtered out before it could be used.
    """
    names = {*DEFAULT_METRICS, *_STEP_KEYS}
    names.update(run.reward_key for run in DEFAULT_RUNS)
    if extra_x_key:
        names.add(extra_x_key)
    return sorted(names)


def _prepare_run_frame(
    df: pd.DataFrame,
    reward_key: str,
    keys: List[str],
    prefer_x_key: Optional[str],
    max_points: int,
) -> pd.DataFrame:
    """Clean one run's history: select columns, coerce to numbers, sort, subsample.

    The chosen x-axis column and the reward key are recorded in ``attrs`` of
    the returned frame under ``"x_key"`` and ``"reward_key"``.
    """
    # Keep only the wanted columns (fetch_run_history returns the full frame
    # when none of the keys matched).
    df = df[[c for c in keys if c in df.columns]].copy()
    x_key = _pick_x_key(df, prefer=prefer_x_key)
    metric_cols = [c for c in (*DEFAULT_METRICS, reward_key) if c in df.columns]
    df = _safe_numeric(df, [x_key, *metric_cols])
    df = df.dropna(subset=[x_key]).sort_values(x_key)
    # `> 0` rather than truthiness: a negative value must not reach np.linspace.
    if max_points > 0 and len(df) > max_points:
        # Evenly spaced subsample that keeps the first and the last row.
        idx = np.linspace(0, len(df) - 1, max_points).astype(int)
        df = df.iloc[idx].copy()
    df.attrs["x_key"] = x_key
    df.attrs["reward_key"] = reward_key
    return df


def _safe_file_stem(label: str) -> str:
    """Replace characters other than alphanumerics and ``-_() `` by ``_`` and strip."""
    return "".join(ch if ch.isalnum() or ch in "-_() " else "_" for ch in label).strip()


def _configure_matplotlib():
    """Select the Agg backend, apply the global rcParams and return ``pyplot``."""
    import matplotlib

    matplotlib.use("Agg")  # non-interactive backend; figures are only written to files
    import matplotlib.pyplot as plt  # noqa: WPS433

    plt.rcParams.update(
        {
            # Font type 42 embeds TrueType fonts in PDF/PS output.
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
            "pdf.use14corefonts": False,
            # Only DejaVu Sans is configured; no CJK fallback font is set here.
            "font.sans-serif": ["DejaVu Sans"],
            "font.family": "sans-serif",
            "axes.unicode_minus": False,
            "figure.facecolor": "white",
            "axes.facecolor": "#FAFAFA",
            "savefig.facecolor": "white",
            "axes.linewidth": 1.8,
        }
    )
    return plt


def _variation_band(y_array: np.ndarray) -> Optional[np.ndarray]:
    """Return the half-width of the shaded band around a curve, or ``None``.

    The band is 1.2 times a moving average of the absolute point-to-point
    change. No band is produced for curves too short for an averaging window
    of at least two points (fewer than 20 points).
    """
    n_points = len(y_array)
    window = min(5, n_points // 10) if n_points > 10 else 1
    if window <= 1:
        return None
    step_change = np.abs(np.diff(y_array, prepend=y_array[0]))
    return np.convolve(step_change, np.ones(window) / window, mode="same") * 1.2


def _draw_panel(ax, runs, frames: Dict[str, pd.DataFrame], metric_key: str, title: str, smooth: int) -> None:
    """Style one axis and draw ``metric_key`` for every run in ``runs``.

    Args:
        ax: Matplotlib axis to draw on.
        runs: Run specs to plot in this panel.
        frames: Prepared history per ``run_path``.
        metric_key: History column plotted on the y axis.
        title: Panel title.
        smooth: Rolling-mean window passed to ``maybe_smooth``.
    """
    from matplotlib.ticker import FuncFormatter, MaxNLocator  # noqa: WPS433

    ax.set_facecolor("#FAFAFA")
    for side in ("left", "bottom", "top", "right"):
        spine = ax.spines[side]
        spine.set_linewidth(1.8)
        spine.set_color("#666666")
        spine.set_visible(True)
    ax.tick_params(axis="both", labelcolor="#333333", labelsize=_FONT_TICK, length=6, width=1.5)
    ax.grid(True, axis="both", alpha=0.3, color="#CCCCCC", linewidth=0.8, linestyle="-", zorder=0)
    ax.set_axisbelow(True)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=6, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=False))

    for run in runs:
        df = frames[run.run_path]
        x_key = df.attrs.get("x_key", "_step")
        if metric_key not in df.columns or x_key not in df.columns:
            continue
        # Many metrics (avg_score in particular) only have values at evaluation
        # steps and are NaN elsewhere. Matplotlib breaks the line at NaN, so the
        # NaN rows are dropped first and only the existing points are connected.
        df_xy = _safe_numeric(df[[x_key, metric_key]].copy(), [x_key, metric_key])
        df_xy = df_xy.dropna(subset=[x_key, metric_key]).sort_values(x_key)
        if df_xy.empty:
            continue
        y = maybe_smooth(df_xy[metric_key], window=smooth)
        color = _COLOR_OURS if "(Ours)" in run.label else _COLOR_BASELINE

        y_array = np.array(y)
        band = _variation_band(y_array)
        if band is not None:
            ax.fill_between(
                df_xy[x_key],
                y_array - band,
                y_array + band,
                color=color,
                alpha=0.2,
                zorder=1,
            )
        ax.plot(
            df_xy[x_key],
            y,
            label=run.label,
            linewidth=3.0,
            color=color,
            linestyle="-",
            alpha=1.0,
            zorder=4,
        )

    ax.set_title(title, fontsize=_FONT_TITLE, fontweight="bold", pad=20)
    ax.set_xlabel("Training Step", fontsize=_FONT_LABEL, fontweight="bold", labelpad=10)
    # Response lengths are shown in thousands ("1.5k").
    if metric_key == "response_length/mean":
        ax.yaxis.set_major_formatter(FuncFormatter(thousands_formatter))


def _legend_entries(axes_row):
    """Collect one (handle, label) pair per distinct label across all axes.

    Looking at every axis (instead of only the first one) keeps the legend
    complete when the first panel happens to have no curve.
    """
    by_label = {}
    for ax in axes_row:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            by_label.setdefault(label, handle)
    return list(by_label.values()), list(by_label.keys())


def _group_centers(boxes, panels_per_group: int) -> List[float]:
    """Return the horizontal centre (figure coordinates) of each group of panels."""
    return [
        (boxes[start].x0 + boxes[start + panels_per_group - 1].x1) / 2.0
        for start in range(0, len(boxes), panels_per_group)
    ]


def _add_group_titles(fig, axes_row, titles, panels_per_group: int) -> None:
    """Write one title above each group of adjacent panels."""
    boxes = [ax.get_position() for ax in axes_row]
    # Place the text a little above the top edge of the axes.
    y = max(box.y1 for box in boxes) + 0.075
    for title, x in zip(titles, _group_centers(boxes, panels_per_group)):
        fig.text(x, y, title, ha="center", va="bottom", fontsize=_FONT_TITLE, fontweight="bold")


def _add_group_separators(fig, axes_row, panels_per_group: int) -> None:
    """Draw a vertical dashed line between every two neighbouring groups."""
    from matplotlib.lines import Line2D  # noqa: WPS433

    boxes = [ax.get_position() for ax in axes_row]
    centers = _group_centers(boxes, panels_per_group)
    y0 = min(box.y0 for box in boxes) - 0.08
    y1 = max(box.y1 for box in boxes) + 0.08
    for left, right in zip(centers, centers[1:]):
        x = (left + right) / 2.0 - 0.01
        fig.add_artist(
            Line2D(
                [x, x],
                [y0, y1],
                transform=fig.transFigure,
                linestyle="--",
                linewidth=5,
                color="#666666",
                alpha=1.0,
                zorder=10,
            )
        )


def main() -> None:
    """Download the configured wandb runs, export them as CSV and plot them.

    Steps:
        1. Parse the command line and create ``<out-dir>/<timestamp>/``.
        2. Fetch each run in ``DEFAULT_RUNS``, clean it and write one CSV per run.
        3. Draw one row of panels (every metric for every group) with a shared
           legend, group titles and group separators.
        4. Save the figure as PDF and PNG and print the output locations.
    """
    args = _parse_args()

    # Imported lazily so that the module can be imported (and --help shown)
    # in an environment without wandb.
    import wandb  # noqa: WPS433

    # WANDB_API_KEY is not strictly required (a prior `wandb login` also works).
    if not os.environ.get("WANDB_API_KEY"):
        print("提示：未检测到环境变量 WANDB_API_KEY，如果你已 `wandb login` 可忽略。")

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = Path(args.out_dir).expanduser().resolve() / stamp
    _ensure_dir(out_dir)

    prefer_x_key = args.x_key or None
    keys = _history_keys(args.x_key)

    print(f"连接 wandb API，准备拉取 {len(DEFAULT_RUNS)} 个 runs 的曲线…")
    api = wandb.Api()

    per_run_df: Dict[str, pd.DataFrame] = {}
    for run in DEFAULT_RUNS:
        print(f"- 拉取 {run.label}: {run.run_path}")
        raw = fetch_run_history(api=api, run_path=run.run_path, keys=keys)
        df = _prepare_run_frame(raw, run.reward_key, keys, prefer_x_key, args.max_points)
        used_x_key = df.attrs["x_key"]
        if prefer_x_key and used_x_key != prefer_x_key:
            # Make the fallback visible instead of silently ignoring --x-key.
            print(f"  警告：{run.run_path} 中没有 x 轴列 {prefer_x_key!r}，已回退为 {used_x_key!r}")
        per_run_df[run.run_path] = df
        df.to_csv(out_dir / f"{run.group}_{_safe_file_stem(run.label)}.csv", index=False)

    plt = _configure_matplotlib()

    # One row; for each group one panel per metric, groups side by side.
    panels_per_group = len(_PANEL_METRICS)
    fig, axes = plt.subplots(
        nrows=1,
        ncols=len(_GROUPS) * panels_per_group,
        figsize=(30.0, 10.0),
        sharex=False,
        sharey=False,
        squeeze=False,
    )
    axes = axes[0]  # squeeze=False returns a 2-D array; use its single row

    panels = [
        (group_id, metric_key, title)
        for group_id, _ in _GROUPS
        for metric_key, title in _PANEL_METRICS
    ]
    for ax, (group_id, metric_key, title) in zip(axes, panels):
        runs_in_group = [run for run in DEFAULT_RUNS if run.group == group_id]
        _draw_panel(ax, runs_in_group, per_run_df, metric_key, title, args.smooth)

    # Single legend for the whole figure.
    handles, labels = _legend_entries(axes)
    if handles:
        legend = fig.legend(
            handles,
            labels,
            loc="lower center",
            bbox_to_anchor=(0.5, -0.0),
            ncol=2,
            frameon=True,
            framealpha=0.95,
            edgecolor="#888888",
            fancybox=True,
            shadow=False,
            prop={"weight": "bold", "size": _FONT_LEGEND},
        )
        legend.get_frame().set_linewidth(1.5)

    # The titles and separators read the final axes positions, so the layout
    # must be fixed before they are added.
    plt.subplots_adjust(wspace=0.28, top=0.72, bottom=0.22, left=0.05, right=0.99)
    _add_group_titles(fig, axes, [title for _, title in _GROUPS], panels_per_group)
    _add_group_separators(fig, axes, panels_per_group)

    pdf_path = out_dir / "wandb_training_curves.pdf"
    fig.savefig(
        pdf_path,
        format="pdf",
        dpi=600,
        bbox_inches="tight",
        metadata={"Creator": "matplotlib", "Producer": "matplotlib"},
    )

    # A PNG copy as well, for quick preview.
    png_path = out_dir / "wandb_training_curves.png"
    fig.savefig(png_path, dpi=200, bbox_inches="tight")

    plt.close(fig)

    print("\n完成：")
    print(f"- PDF：{pdf_path}")
    print(f"- PNG：{png_path}")
    print(f"- CSV：{out_dir}/*.csv")


# Run the download-and-plot pipeline only when executed as a script.
if __name__ == "__main__":
    main()
