
from __future__ import annotations

import argparse
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import pandas as pd
import numpy as np


@dataclass(frozen=True)
class RunSpec:
    """Immutable description of one wandb run to fetch and plot."""

    group: str  # grouping tag for the run (e.g. "beta" in DEFAULT_RUNS)
    label: str  # legend label; also used to derive the per-run CSV file name
    run_path: str  # entity/project/run_id
    reward_key: str  # metric key recorded on the run's DataFrame attrs as "reward_key"


# Runs for the two beta values being compared
DEFAULT_RUNS: List[RunSpec] = [
    RunSpec(
        group="beta",
        label="beta=0.3",
        run_path="astrid_tuning_llm/verl-qwen3-4b-oct/96i8jqw8",
        reward_key="critic/acc/mean",
    ),
    RunSpec(
        group="beta",
        label="beta=0.6",
        run_path="astrid_tuning_llm/verl-qwen3-4b-oct/fn0l0hij",
        reward_key="critic/acc/mean",
    ),
]

# Metric keys kept from each run's history and coerced to numeric in main()
DEFAULT_METRICS = [
    "critic/acc/mean",
    "response_length/mean",
    "actor/entropy",
    "avg_score/16384/four_sets",
]


def _ensure_dir(p: Path) -> None:
    p.mkdir(parents=True, exist_ok=True)


def _pick_x_key(df: pd.DataFrame, prefer: Optional[str] = None) -> str:
    # wandb 常见 step 字段（不同代码库可能不同）
    candidates = []
    if prefer:
        candidates.append(prefer)
    candidates += [
        "trainer/global_step",
        "global_step",
        "_step",
        "step",
        "train/global_step",
    ]
    for k in candidates:
        if k in df.columns:
            return k
    raise ValueError(f"找不到可用的 step 列，现有列：{list(df.columns)[:50]}")


def _safe_numeric(df: pd.DataFrame, cols: Iterable[str]) -> pd.DataFrame:
    for c in cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")
    return df


def fetch_run_history(
    api,  # wandb.Api
    run_path: str,
    keys: List[str],
) -> pd.DataFrame:
    """Download the history of one run via ``scan_history`` and return it as a table.

    ``scan_history`` is used to pull the history as completely as possible
    (per the original note, ``history()`` may be sampled/truncated).

    Args:
        api: Object exposing ``run(run_path)`` (a ``wandb.Api``).
        run_path: Run identifier in ``entity/project/run_id`` form.
        keys: Column names to keep from the scanned rows.

    Returns:
        - An empty frame with ``keys`` as columns when no rows were scanned.
        - A copy restricted to the requested ``keys`` that exist, in ``keys`` order.
        - The full, unfiltered frame when none of ``keys`` exists in the data.
    """
    run = api.run(run_path)
    # Important: the original author observed that passing keys to
    # scan_history(keys=...) can yield nothing when a key does not exist, so the
    # scan is unfiltered and the column selection happens afterwards.
    records = [
        dict(entry)
        for entry in run.scan_history(page_size=1000)
        if entry is not None
    ]
    if not records:
        return pd.DataFrame(columns=keys)

    frame = pd.DataFrame(records)
    wanted = [name for name in keys if name in frame.columns]
    if not wanted:
        return frame
    return frame[wanted].copy()


def maybe_smooth(series: pd.Series, window: int) -> pd.Series:
    """Return a rolling-mean smoothed version of ``series``.

    A ``window`` of 1 or less disables smoothing and returns ``series`` itself.
    Otherwise a rolling mean over ``window`` points is computed, requiring at
    least ``max(1, window // 3)`` observations before a value is produced.
    """
    if window > 1:
        min_obs = max(1, window // 3)
        return series.rolling(window=window, min_periods=min_obs).mean()
    return series


def thousands_formatter(x, pos):
    """Format a tick value in thousands with one decimal and a ``k`` suffix.

    Every value is divided by 1000 regardless of its magnitude
    (e.g. 1500 -> "1.5k", 300 -> "0.3k"). ``pos`` is accepted so the function
    fits the two-argument signature used with ``FuncFormatter``; it is unused.
    """
    scaled = x / 1000
    return "{:.1f}k".format(scaled)


def main() -> None:
    """Fetch the configured wandb runs, export per-run CSVs and plot the curves.

    Steps:
      1. Parse CLI options (output dir, x-axis key, smoothing window, point cap,
         maximum step).
      2. Download the history of every run in ``DEFAULT_RUNS`` and write one CSV
         per run into a timestamped sub-directory of ``--out-dir``.
      3. Draw a 1x4 figure (one subplot per metric) and save it as PDF and PNG.

    Raises:
        ValueError: Propagated from ``_pick_x_key`` when a run has no usable
            step column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--out-dir",
        type=str,
        default="_wandb_plots",
        help="输出目录（会创建一个时间戳子目录）",
    )
    parser.add_argument(
        "--x-key",
        type=str,
        default="",
        help="强制使用的 x 轴列名（默认自动探测，如 trainer/global_step 或 _step）",
    )
    parser.add_argument(
        "--smooth",
        type=int,
        default=5,
        help="滑动平均窗口（1 表示不平滑）",
    )
    parser.add_argument(
        "--max-points",
        type=int,
        default=0,
        help="每条曲线最多保留多少点（0 表示不截断；>0 会等距采样，便于画图更快）",
    )
    parser.add_argument(
        "--max-step",
        type=int,
        default=600,
        help="只保留 x 轴取值不超过该值的数据点（默认 600）",
    )
    args = parser.parse_args()

    # Deferred import so the module can be imported in environments without wandb.
    import wandb  # noqa: WPS433

    # Sanity hint only: WANDB_API_KEY is optional (a prior `wandb login` also works).
    if not os.environ.get("WANDB_API_KEY"):
        print("提示：未检测到环境变量 WANDB_API_KEY，如果你已 `wandb login` 可忽略。")

    out_root = Path(args.out_dir).expanduser().resolve()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = out_root / ts
    _ensure_dir(out_dir)

    # Columns to keep: every metric plus the common step fields.
    keys = list(DEFAULT_METRICS) + [
        "_step",
        "trainer/global_step",
        "global_step",
        "step",
        "train/global_step",
    ]
    # A user-supplied x key must be kept as well; otherwise its column is
    # filtered out before _pick_x_key() runs and --x-key is silently ignored.
    if args.x_key:
        keys.append(args.x_key)
    keys = sorted(set(keys))

    print(f"连接 wandb API，准备拉取 {len(DEFAULT_RUNS)} 个 runs 的曲线…")
    api = wandb.Api()

    per_run_df: Dict[str, pd.DataFrame] = {}
    for r in DEFAULT_RUNS:
        print(f"- 拉取 {r.label}: {r.run_path}")
        df = fetch_run_history(api=api, run_path=r.run_path, keys=keys)
        # Keep only the wanted columns (fetch_run_history returns the full
        # frame when none of the keys is present).
        keep_cols = [c for c in keys if c in df.columns]
        df = df[keep_cols].copy()
        x_key = _pick_x_key(df, prefer=args.x_key if args.x_key else None)
        metric_cols = [c for c in DEFAULT_METRICS if c in df.columns]
        df = _safe_numeric(df, [x_key, *metric_cols])
        df = df.dropna(subset=[x_key]).sort_values(x_key)
        # Drop everything beyond the configured maximum step (default 600).
        df = df[df[x_key] <= args.max_step].copy()
        if args.max_points > 0 and len(df) > args.max_points:
            # Evenly spaced sub-sampling that keeps the first and last rows.
            idx = np.linspace(0, len(df) - 1, args.max_points).astype(int)
            df = df.iloc[idx].copy()

        df.attrs["x_key"] = x_key
        df.attrs["reward_key"] = r.reward_key
        per_run_df[r.run_path] = df

        # Export the per-run table; characters outside [alnum -_() ] become "_".
        safe_label = "".join(
            ch if ch.isalnum() or ch in "-_() " else "_" for ch in r.label).strip()
        csv_path = out_dir / f"{safe_label}.csv"
        df.to_csv(csv_path, index=False)

    # Plotting
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend; figures are only written to files
    import matplotlib.pyplot as plt  # noqa: WPS433
    from matplotlib.ticker import FuncFormatter  # noqa: WPS433

    # Embed Type 42 (TrueType) fonts in PDF/PS output instead of the 14 core fonts.
    plt.rcParams["pdf.fonttype"] = 42
    plt.rcParams["ps.fonttype"] = 42
    plt.rcParams["pdf.use14corefonts"] = False
    # Only DejaVu Sans is configured; no CJK fallback font is set, which is
    # sufficient because all titles and labels drawn below are ASCII/mathtext.
    plt.rcParams["font.sans-serif"] = [
        "DejaVu Sans",
    ]
    plt.rcParams["axes.unicode_minus"] = False
    plt.rcParams["figure.facecolor"] = "white"
    plt.rcParams["axes.facecolor"] = "#FAFAFA"
    plt.rcParams["savefig.facecolor"] = "white"
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["axes.linewidth"] = 1.8

    # Colours: one per beta value.
    COLOR_BETA_03 = '#00468B'  # dark blue (beta=0.3)
    COLOR_BETA_06 = '#AE1029'  # crimson (beta=0.6)

    # Font sizes (deliberately large).
    FONT_TITLE = 32
    FONT_LABEL = 24
    FONT_TICK = 20
    FONT_LEGEND = 28

    # Layout: 1 row x 4 columns, one subplot per (metric key, title) pair.
    metrics = [
        ("critic/acc/mean", "Critic Acc"),
        ("response_length/mean", r"$L$"),
        ("actor/entropy", "Entropy"),
        ("avg_score/16384/four_sets", "ACC"),
    ]

    fig, axes = plt.subplots(
        nrows=1,
        ncols=4,
        figsize=(30.0, 10.0),  # 1 row x 4 columns
        sharex=False,
        sharey=False,
        squeeze=False,
    )
    axes = axes[0]  # squeeze=False gives a 2-D array; take the single row

    # Draw the curves of every metric.
    for col_idx, (metric_key, ylabel) in enumerate(metrics):
        ax = axes[col_idx]

        # Shared styling for every subplot.
        ax.set_facecolor('#FAFAFA')
        for side in ['left', 'bottom', 'top', 'right']:
            ax.spines[side].set_linewidth(1.8)
            ax.spines[side].set_color('#666666')
            ax.spines[side].set_visible(True)
        ax.tick_params(axis='both', labelcolor='#333333',
                       labelsize=FONT_TICK, length=6, width=1.5)
        ax.grid(True, axis='both', alpha=0.3, color='#CCCCCC',
                linewidth=0.8, linestyle='-', zorder=0)
        ax.set_axisbelow(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))
        ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

        # One curve per run.
        for r in DEFAULT_RUNS:
            df = per_run_df[r.run_path]
            x_key = df.attrs.get("x_key", "_step")
            y_key = metric_key
            if y_key not in df.columns or x_key not in df.columns:
                continue
            # Many metrics (notably avg_score) only have values on evaluation
            # steps and are NaN elsewhere. matplotlib breaks the line at NaN,
            # so drop those rows first and connect only the points with values.
            df_xy = df[[x_key, y_key]].copy()
            df_xy = _safe_numeric(df_xy, [x_key, y_key]).dropna(
                subset=[x_key, y_key]).sort_values(x_key)
            if df_xy.empty:
                continue
            y = maybe_smooth(df_xy[y_key], window=args.smooth)

            # Colour choice: labels containing "0.3" are blue, everything else red.
            color = COLOR_BETA_03 if "0.3" in r.label else COLOR_BETA_06

            # Shaded band around the curve. This is a heuristic, not a true
            # standard deviation: a moving average of the absolute
            # point-to-point change of the (smoothed) curve, scaled by 1.2.
            # It is only drawn when there are enough points (window > 1).
            y_array = np.array(y)
            window = min(5, len(y) // 10) if len(y) > 10 else 1
            if window > 1:
                std_approx = np.convolve(
                    np.abs(np.diff(np.concatenate([[y_array[0]], y_array]))),
                    np.ones(window) / window,
                    mode='same'
                ) * 1.2
                ax.fill_between(
                    df_xy[x_key],
                    y_array - std_approx,
                    y_array + std_approx,
                    color=color,
                    alpha=0.2,
                    zorder=1,
                )

            ax.plot(
                df_xy[x_key],
                y,
                label=r.label,
                linewidth=3.0,
                color=color,
                linestyle='-',
                alpha=1.0,
                zorder=4,
            )

        # Title and axis label.
        ax.set_title(ylabel, fontsize=FONT_TITLE, fontweight='bold', pad=20)
        ax.set_xlabel("Training Step", fontsize=FONT_LABEL,
                      fontweight='bold', labelpad=10)

        # Use the "k" tick format on the y axis of the response-length subplot.
        if metric_key == "response_length/mean":
            ax.yaxis.set_major_formatter(FuncFormatter(thousands_formatter))

    # Figure-level legend. Entries are gathered from every subplot and
    # de-duplicated by label, so a run that lacks the first metric still gets
    # a legend entry (insertion order is preserved).
    legend_entries = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    if legend_entries:
        legend = fig.legend(
            list(legend_entries.values()),
            list(legend_entries.keys()),
            loc='lower center',
            bbox_to_anchor=(0.5, -0.0),
            ncol=2,
            frameon=True,
            framealpha=0.95,
            edgecolor='#888888',
            fancybox=True,
            shadow=False,
            prop={'weight': 'bold', 'size': FONT_LEGEND},
        )
        legend.get_frame().set_linewidth(1.5)

    # Layout margins (the original note says these mirror get_wandb_training.py).
    plt.subplots_adjust(wspace=0.28, top=0.80,
                        bottom=0.22, left=0.05, right=0.99)

    # Save as PDF.
    pdf_path = out_dir / "wandb_training_curves_beta.pdf"
    fig.savefig(
        pdf_path,
        format='pdf',
        dpi=600,
        bbox_inches='tight',
        metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
    )

    # Also save a PNG for quick preview.
    png_path = out_dir / "wandb_training_curves_beta.png"
    fig.savefig(png_path, dpi=200, bbox_inches='tight')

    plt.close(fig)

    print("\n完成：")
    print(f"- PDF：{pdf_path}")
    print(f"- PNG：{png_path}")
    print(f"- CSV：{out_dir}/*.csv")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
