import glob
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

import matplotlib

# Select the "Agg" backend before pyplot is imported below.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

# Global matplotlib settings that make the output look like a paper figure
# (kept consistent with the other plot scripts in this repository).
plt.rcParams["pdf.fonttype"] = 42
plt.rcParams["ps.fonttype"] = 42
plt.rcParams["pdf.use14corefonts"] = False
# Sans-serif family: "Arial" is listed first, "DejaVu Sans" second.
plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
plt.rcParams["font.family"] = "sans-serif"
# Default axes border width; the plotting code also sets spine widths to 1.8.
plt.rcParams["axes.linewidth"] = 1.8

# ========== Fixed input paths ==========
# Directory scanned for *.jsonl files; their "output" fields feed the scatter panel.
INPUT_DIR = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/baseline-gspo-dapo-math-minibsz32/valid"
# Passed to AutoTokenizer.from_pretrained() to tokenize the collected outputs.
TOKENIZER_PATH = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base"
# EOS-probability caches (.npz); each one is drawn as a curve in the right panel.
CACHE_PATHS = [
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/results/qwen3-4b-base.npz",
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/results/grpo-step430.npz",
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/results/gspo-step500.npz",
]
# Legend labels for the curves; main() requires as many labels as loaded caches.
CACHE_LABELS = ["Qwen3-4B-Base", "GRPO", "GSPO"]
# Destination of the combined PDF figure.
OUTPUT_PATH = "plots/motivation.pdf"
# =================================


@dataclass
class CacheBlob:
    """Contents of one EOS-probability cache (``.npz``) as returned by ``load_cache``."""

    # Plotted as the x-values ("Token index (step)") of the cache's curve.
    indices: np.ndarray  # (K,)
    # Plotted as the y-values; exponentiated first when meta["value_space"] == "ln".
    mean_eos_prob: np.ndarray  # (K,)
    # Loaded from the cache but not read by the plotting code in this file.
    counts: np.ndarray  # (K,)
    # Parsed "meta" entry; the plot reads the "metric" and "value_space" keys.
    meta: Dict[str, Any]


def _parse_meta(meta_arr: Any) -> Dict[str, Any]:
    """Decode the ``meta`` entry of an ``.npz`` cache into a plain dict.

    Accepted inputs:
      * ``None`` (entry missing) -> ``{}``.
      * A NumPy scalar, 0-d array or single-element array wrapping one of
        the values below (it is unwrapped first).
      * A ``dict`` (e.g. stored as a pickled object array) -> shallow copy.
      * ``str`` / ``bytes`` / ``bytearray`` containing a JSON object.

    Args:
        meta_arr: The raw value read from the cache, or ``None``.

    Returns:
        The decoded metadata. Any unsupported type, undecodable bytes,
        invalid JSON, or JSON that is not an object yields ``{}``; this
        function never raises for bad metadata.
    """
    meta_raw = meta_arr

    # Unwrap NumPy containers that hold exactly one element. Checking ``size``
    # rather than ``shape == ()`` also covers arrays saved with shape (1,).
    if getattr(meta_raw, "size", None) == 1 and callable(getattr(meta_raw, "item", None)):
        try:
            meta_raw = meta_raw.item()
        except (TypeError, ValueError):
            return {}

    # Metadata stored as a Python dict needs no JSON round trip.
    if isinstance(meta_raw, dict):
        return dict(meta_raw)

    if isinstance(meta_raw, (bytes, bytearray)):
        try:
            meta_raw = bytes(meta_raw).decode("utf-8")
        except UnicodeDecodeError:
            return {}

    if not isinstance(meta_raw, str):
        return {}

    try:
        parsed = json.loads(meta_raw)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass.
        return {}
    return parsed if isinstance(parsed, dict) else {}


def load_cache(cache_path: str) -> CacheBlob:
    """Load one EOS-probability cache (``.npz``) into a ``CacheBlob``.

    The arrays ``indices``, ``mean_eos_prob`` and ``counts`` are required;
    the ``meta`` entry is optional and decoded with ``_parse_meta``.

    Args:
        cache_path: Path to the ``.npz`` file.

    Returns:
        A ``CacheBlob`` holding the three arrays and the parsed metadata.

    Raises:
        KeyError: If one of the required arrays is missing from the file.
        OSError: If the file cannot be opened.
    """
    # SECURITY NOTE: allow_pickle=True can execute arbitrary code while
    # unpickling object arrays. Only load cache files from trusted sources.
    # The context manager closes the underlying zip file once the arrays
    # have been read into memory.
    with np.load(cache_path, allow_pickle=True) as z:
        # ``files`` lists the stored array names; avoids relying on
        # NpzFile.get(), which not every NumPy release provides.
        meta_entry = z["meta"] if "meta" in z.files else None
        return CacheBlob(
            indices=z["indices"],
            mean_eos_prob=z["mean_eos_prob"],
            counts=z["counts"],
            meta=_parse_meta(meta_entry),
        )


def _iter_jsonl_outputs(path: str) -> Iterable[str]:
    """Yield the ``output`` text of every usable record in a JSONL file.

    Blank lines are ignored. Lines that fail to parse are reported on stdout
    and skipped. Records that are not JSON objects, or whose ``output`` is
    missing, not a string, or whitespace-only, are skipped silently. The
    yielded text is the stored value, not a stripped copy.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            payload = raw_line.strip()
            if payload:
                try:
                    record = json.loads(payload)
                except Exception as e:
                    print(f"警告: JSONL 解析失败，跳过 {path}:{line_no}，错误: {e}")
                else:
                    # Only JSON objects can carry an "output" field.
                    text = record.get("output") if isinstance(record, dict) else None
                    if isinstance(text, str) and text.strip():
                        yield text


def collect_outputs(input_dir: str) -> Tuple[List[str], List[str]]:
    """Gather the ``output`` text of every sample in ``input_dir``/*.jsonl.

    Files are visited in sorted path order and a per-file count is printed.

    Returns:
        ``(texts, sources)``: two parallel lists, where ``sources[i]`` is the
        base name of the file that ``texts[i]`` was read from.
    """
    all_texts: List[str] = []
    all_sources: List[str] = []
    pattern = os.path.join(input_dir, "*.jsonl")
    for jsonl_path in sorted(glob.glob(pattern)):
        # A directory can match the pattern too; only regular files are read.
        if os.path.isfile(jsonl_path):
            file_name = os.path.basename(jsonl_path)
            extracted = list(_iter_jsonl_outputs(jsonl_path))
            all_texts.extend(extracted)
            all_sources.extend([file_name] * len(extracted))
            print(f"读取 {file_name}: 新增 {len(extracted)} 条样本")

    return all_texts, all_sources


def _distinct_ngram_count(token_ids: List[int], n: int) -> int:
    """Count the distinct contiguous n-grams in ``token_ids``.

    Returns 0 when ``n`` is not positive or the sequence is shorter than ``n``.
    """
    total = len(token_ids)
    if n <= 0 or total < n:
        return 0
    # One window per start offset; the set collapses repeated n-grams.
    windows = {tuple(token_ids[start : start + n]) for start in range(total - n + 1)}
    return len(windows)


def compute_length_and_distinct_ngrams(
    texts: List[str],
    tokenizer,
    n: int = 10,
    batch_size: int = 128,
) -> Tuple[List[int], List[int]]:
    """Tokenize ``texts`` in batches and measure each resulting sequence.

    ``tokenizer`` is called with a list of texts plus ``padding=False``,
    ``truncation=False`` and ``add_special_tokens=False``; its result is
    indexed with ``"input_ids"``.

    Returns:
        ``(lengths, distinct_counts)``: for every tokenized sequence, its
        token count and its number of distinct n-grams of order ``n``.
    """
    token_lengths: List[int] = []
    ngram_counts: List[int] = []

    for start in range(0, len(texts), batch_size):
        encoded = tokenizer(
            texts[start : start + batch_size],
            padding=False,
            truncation=False,
            add_special_tokens=False,
        )
        for token_ids in encoded["input_ids"]:
            token_lengths.append(len(token_ids))
            ngram_counts.append(_distinct_ngram_count(token_ids, n=n))

    return token_lengths, ngram_counts


def _format_thousands(value: float, pos: Optional[int] = None) -> str:
    """Tick formatter: render ``value`` in units of one thousand (2000 -> "2k").

    ``pos`` is the tick position passed by ``FuncFormatter``; it is unused.
    """
    if value == 0:
        return "0"
    # "g" keeps fractional thousands ("0.5k", "2.5k"); a fixed ".0f" would
    # round them to a misleading whole number such as "0k" or "2k".
    return f"{value / 1000:g}k"


def _format_probability_tick(value: float, pos: Optional[int] = None) -> str:
    """Tick formatter: label values up to 1.0 with one decimal, blank above."""
    if value > 1.0 + 1e-6:  # small tolerance against floating-point error
        return ""
    if abs(value) < 1e-6:  # avoids a "-0.0" label
        return "0.0"
    return f"{value:.1f}"


def _style_axes(ax: Any, background: str, spine_color: str, grid_color: str, tick_color: str) -> None:
    """Apply the shared background, spine, grid and tick styling to ``ax``."""
    ax.set_facecolor(background)
    for side in ("left", "bottom", "top", "right"):
        spine = ax.spines[side]
        spine.set_linewidth(1.8)
        spine.set_color(spine_color)
        spine.set_visible(True)
    ax.grid(True, axis="both", alpha=0.3, color=grid_color, linewidth=0.8, linestyle="-", zorder=0)
    ax.set_axisbelow(True)
    ax.tick_params(axis="both", labelcolor=tick_color, labelsize=28, length=6, width=1.5)


def _add_panel_label(ax: Any, label: str) -> None:
    """Write a bold sub-figure label such as "(a)" centred below ``ax``."""
    ax.text(
        0.5,
        -0.22,
        label,
        transform=ax.transAxes,
        fontsize=34,
        fontweight="bold",
        va="top",
        ha="center",
    )


def _curve_y_top(curves: List[np.ndarray], headroom: float = 1.4, fallback: float = 1.4) -> float:
    """Return the y-axis upper limit for the given curves.

    The limit is the largest finite value across all curves multiplied by
    ``headroom``. ``fallback`` is returned when there is no positive finite
    value (no curves, empty curves, all-NaN/inf curves, or an all-zero peak),
    so the caller never passes an invalid limit to ``set_ylim``.
    """
    peaks: List[float] = []
    for ys in curves:
        flat = np.asarray(ys, dtype=np.float64).ravel()
        finite = flat[np.isfinite(flat)]
        if finite.size:
            peaks.append(float(finite.max()))
    peak = max(peaks, default=0.0)
    return peak * headroom if peak > 0.0 else fallback


def plot_combined_figure(
    lengths: List[int],
    distinct_counts: List[int],
    cache_blobs: List[CacheBlob],
    cache_labels: List[str],
    output_path: str,
    n: int = 10,
    max_points: Optional[int] = None,
    seed: int = 42,
    max_len: float = 8000.0,
) -> None:
    """Draw the combined two-panel figure and save it as a PDF.

    Left panel (a): scatter of ``distinct_counts`` against ``lengths`` with a
    dashed ``y = x`` reference line. Right panel (b): one curve per cache
    blob, ``mean_eos_prob`` against ``indices``.

    Args:
        lengths: Sequence lengths in tokens (x-values of the scatter).
        distinct_counts: Values plotted on the scatter's y-axis; must be
            parallel to ``lengths``.
        cache_blobs: Caches to draw as curves. ``metric`` and ``value_space``
            are read from the meta of the FIRST blob only and applied to all
            curves; when ``value_space`` is ``"ln"`` the values are
            exponentiated before plotting.
        cache_labels: Legend labels, paired with ``cache_blobs`` by position
            (extra entries on either side are ignored).
        output_path: Destination PDF path; its directory is created if needed.
        n: Not used by the plot; kept so existing callers keep working.
        max_points: If positive and smaller than the number of scatter
            points, a random subset of that size is drawn instead.
        seed: Seed for the subsampling random generator.
        max_len: Points with a length above this value are dropped; it is
            also the right x-limit and the extent of the reference line.
    """
    BG = "#FAFAFA"
    SPINE_COLOR = "#666666"
    GRID_COLOR = "#CCCCCC"
    TICK_COLOR = "#333333"
    COLOR_POINTS = "#00468B"  # dark blue
    COLOR_REF = "#AE1029"  # crimson
    COLORS = ["#00468B", "#9B59B6", "#AE1029", "#FF7F0E", "#2CA02C"]  # curve colours, cycled

    # 1x2 subplot layout.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 10))
    try:
        # ========== Left: scatter plot ==========
        x = np.array(lengths, dtype=np.float32)
        y = np.array(distinct_counts, dtype=np.float32)
        keep = x <= max_len
        x = x[keep]
        y = y[keep]

        if max_points is not None and max_points > 0 and len(x) > max_points:
            rng = np.random.default_rng(seed)
            picked = rng.choice(len(x), size=max_points, replace=False)
            x = x[picked]
            y = y[picked]

        _style_axes(ax1, BG, SPINE_COLOR, GRID_COLOR, TICK_COLOR)

        ax1.scatter(x, y, s=18, alpha=0.35, color=COLOR_POINTS, edgecolors="none", zorder=3)

        line_x = np.linspace(0.0, max_len, 200)
        ax1.plot(line_x, line_x, linestyle="--", linewidth=2.5, color=COLOR_REF, alpha=0.9, label=r"$y=x$", zorder=2)

        ax1.set_xlabel("Sequence Length (tokens)", fontsize=32, fontweight="bold", labelpad=10)
        ax1.set_ylabel(r"$C_{\mathrm{context}}(\tau)$", fontsize=32, fontweight="bold", labelpad=10)
        ax1.set_xlim(left=0, right=max_len)
        ax1.set_ylim(bottom=0)

        # Both left-panel axes are labelled in thousands ("k"); each axis
        # gets its own formatter instance.
        ax1.xaxis.set_major_formatter(ticker.FuncFormatter(_format_thousands))
        ax1.yaxis.set_major_formatter(ticker.FuncFormatter(_format_thousands))

        leg1 = ax1.legend(
            loc="upper left",
            frameon=True,
            framealpha=0.95,
            edgecolor="#888888",
            fancybox=True,
            prop={"size": 28},
        )
        leg1.get_frame().set_linewidth(1.2)

        _add_panel_label(ax1, "(a)")

        # ========== Right: EOS probability curves ==========
        _style_axes(ax2, BG, SPINE_COLOR, GRID_COLOR, TICK_COLOR)

        first_meta = cache_blobs[0].meta if cache_blobs else {}
        metric = str(first_meta.get("metric", "cumulative_product"))
        value_space = str(first_meta.get("value_space", "prob"))
        exp_y = value_space == "ln"

        if metric == "stop_at_t":
            ylab_base = "P(stop at step=t) (mean)"
        else:
            ylab_base = "P(end by step=t)"

        all_ys: List[np.ndarray] = []
        for i, (b, lab) in enumerate(zip(cache_blobs, cache_labels)):
            xs = np.asarray(b.indices)
            ys = np.asarray(b.mean_eos_prob)
            if exp_y:
                ys = np.exp(ys)
            all_ys.append(ys)
            ax2.plot(xs, ys, linewidth=3.0, color=COLORS[i % len(COLORS)], label=f"{lab}", zorder=4)

        ax2.set_xlabel("Token index (step)", fontsize=32, fontweight="bold", labelpad=10)
        ax2.set_ylabel(ylab_base, fontsize=32, fontweight="bold", labelpad=10)

        # Upper y-limit: 1.4x the data maximum, leaving headroom so the
        # legend does not cover the curves (1.4 when there is no usable data).
        ax2.set_ylim(bottom=0.0, top=_curve_y_top(all_ys))

        # Right-panel x-axis is labelled in thousands ("k").
        ax2.xaxis.set_major_formatter(ticker.FuncFormatter(_format_thousands))

        # Only 0.0 .. 1.0 is ticked and labelled, even if the axis extends
        # above 1.0 because of the headroom.
        ax2.yaxis.set_major_formatter(ticker.FuncFormatter(_format_probability_tick))
        ax2.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

        _add_panel_label(ax2, "(b)")

        # Legend in the upper-left corner of the right panel.
        handles, labs = ax2.get_legend_handles_labels()
        if handles:
            legend = ax2.legend(
                handles,
                labs,
                loc="upper left",
                frameon=True,
                framealpha=0.95,
                edgecolor="#888888",
                fancybox=True,
                shadow=False,
                prop={"weight": "bold", "size": 20},
            )
            legend.get_frame().set_linewidth(1.5)

        # ========== Save ==========
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        # Reserve room at the bottom for the (a)/(b) labels and a little at
        # the top. The figure object is used explicitly rather than
        # pyplot's implicit "current figure".
        fig.tight_layout(rect=(0.0, 0.15, 1.0, 0.95))
        fig.savefig(
            output_path,
            dpi=600,
            bbox_inches="tight",
            format="pdf",
            metadata={"Creator": "matplotlib", "Producer": "matplotlib"},
        )
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"已保存合并图像: {output_path}")


def main():
    """Run the pipeline: tokenize outputs, load EOS caches, draw the figure.

    Steps:
      1. Load the tokenizer from ``TOKENIZER_PATH``.
      2. Collect ``output`` texts from the JSONL files in ``INPUT_DIR``.
      3. Compute token lengths and distinct 10-gram counts.
      4. Load the caches in ``CACHE_PATHS``; a missing file is skipped
         together with its label so the remaining curves stay correctly
         labelled.

    Raises:
        ValueError: If ``CACHE_PATHS`` and ``CACHE_LABELS`` differ in length,
            or if none of the cache files exists.
        RuntimeError: If ``transformers`` cannot be imported, or if no
            output text was found.
    """
    print("=" * 60)
    print("开始生成合并图...")
    print("=" * 60)

    # Fail fast on a misconfigured label list, before the expensive
    # tokenization work is done.
    if len(CACHE_PATHS) != len(CACHE_LABELS):
        raise ValueError(f"缓存文件数量 ({len(CACHE_PATHS)}) 与标签数量 ({len(CACHE_LABELS)}) 不一致")

    # Load the tokenizer. The import is local and the catch deliberately
    # broad: a version mismatch between transformers and huggingface_hub can
    # surface as something other than ImportError.
    try:
        from transformers import AutoTokenizer
    except Exception as e:
        raise RuntimeError(
            "导入 transformers 失败：当前环境可能存在 transformers/huggingface_hub 版本不兼容。\n"
            "建议在可用环境中运行，或修复依赖版本。\n"
            f"原始错误: {e}"
        ) from e

    print(f"\n[1/4] 加载 tokenizer: {TOKENIZER_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

    # Collect the scatter-plot data.
    print(f"\n[2/4] 收集 outputs: {INPUT_DIR}")
    texts, _sources = collect_outputs(INPUT_DIR)
    # Defensive re-filter; keeps only non-blank strings.
    texts = [t for t in texts if isinstance(t, str) and t.strip()]
    print(f"总共收集到 {len(texts)} 条样本")

    if not texts:
        raise RuntimeError("未从目录中抽取到任何 output 文本，请检查 jsonl 是否含 output 字段。")

    print("\n[3/4] 计算 token length 与 distinct n-gram...")
    lengths, distinct_counts = compute_length_and_distinct_ngrams(
        texts=texts,
        tokenizer=tokenizer,
        n=10,
        batch_size=128,
    )

    # Load the EOS-probability caches. Paths and labels are walked in
    # lockstep so that skipping a missing file also skips its label.
    print("\n[4/4] 加载 EOS 概率缓存...")
    cache_blobs = []
    cache_labels = []
    for cache_path, label in zip(CACHE_PATHS, CACHE_LABELS):
        if not os.path.exists(cache_path):
            print(f"警告: 缓存文件不存在，跳过: {cache_path}")
            continue
        cache_blobs.append(load_cache(cache_path))
        cache_labels.append(label)
        print(f"  已加载: {os.path.basename(cache_path)}")

    if not cache_blobs:
        raise ValueError("未加载到任何 EOS 概率缓存文件，请检查 CACHE_PATHS。")

    # Draw the combined figure.
    print("\n开始绘图...")
    plot_combined_figure(
        lengths=lengths,
        distinct_counts=distinct_counts,
        cache_blobs=cache_blobs,
        cache_labels=cache_labels,
        output_path=OUTPUT_PATH,
        n=10,
        max_points=None,
        seed=42,
    )

    print("\n" + "=" * 60)
    print("完成！")
    print("=" * 60)


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
