import argparse
import glob
import json
import os
from typing import Iterable, List, Optional, Tuple

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Make the output look like a paper figure (consistent with the other plot
# scripts in this repository).
plt.rcParams.update(
    {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pdf.use14corefonts": False,
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        "axes.unicode_minus": False,
    }
)


def _iter_jsonl_outputs(path: str) -> Iterable[str]:
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except Exception as e:
                print(f"警告: JSONL 解析失败，跳过 {path}:{line_no}，错误: {e}")
                continue
            if not isinstance(item, dict):
                continue
            out = item.get("output")
            if isinstance(out, str) and out.strip():
                yield out


def collect_outputs(input_dir: str) -> Tuple[List[str], List[str]]:
    """Read every ``*.jsonl`` file directly under ``input_dir`` and gather the ``output`` texts.

    Files are processed in sorted path order and a per-file summary line is
    printed after each one.

    Args:
        input_dir: Directory that contains the ``*.jsonl`` files.

    Returns:
        A ``(texts, sources)`` pair of equal-length lists: ``texts[i]`` is an
        extracted output string and ``sources[i]`` is the base name of the
        file it came from.
    """
    texts: List[str] = []
    sources: List[str] = []
    # Escape the directory part so glob metacharacters in its name
    # ("[", "]", "*", "?") are matched literally; only "*.jsonl" is a pattern.
    pattern = os.path.join(glob.escape(input_dir), "*.jsonl")
    for p in sorted(glob.glob(pattern)):
        if not os.path.isfile(p):
            # Skip directories (or other non-files) that happen to end in ".jsonl".
            continue
        base = os.path.basename(p)
        count_before = len(texts)
        for out in _iter_jsonl_outputs(p):
            texts.append(out)
            sources.append(base)
        print(f"读取 {base}: 新增 {len(texts) - count_before} 条样本")

    return texts, sources


def _distinct_ngram_count(token_ids: List[int], n: int) -> int:
    if n <= 0:
        return 0
    if len(token_ids) < n:
        return 0
    s = set()
    # 用 tuple 做 hash
    for i in range(0, len(token_ids) - n + 1):
        s.add(tuple(token_ids[i : i + n]))
    return len(s)


def compute_length_and_distinct_ngrams(
    texts: List[str],
    tokenizer,
    n: int = 10,
    batch_size: int = 128,
) -> Tuple[List[int], List[int]]:
    """Tokenize ``texts`` in batches and measure each sample.

    Args:
        texts: Strings to tokenize.
        tokenizer: Callable invoked as ``tokenizer(batch, padding=False,
            truncation=False, add_special_tokens=False)``; its result is
            indexed with ``["input_ids"]`` to get one token-id sequence per text.
        n: n-gram size passed to ``_distinct_ngram_count``.
        batch_size: Number of texts handed to the tokenizer per call.

    Returns:
        A ``(lengths, distinct_counts)`` pair, both in the order of ``texts``:
        the token count of each text and its number of distinct n-grams.
    """
    lengths: List[int] = []
    distinct_counts: List[int] = []

    for start in range(0, len(texts), batch_size):
        encoded = tokenizer(
            texts[start : start + batch_size],
            padding=False,
            truncation=False,
            add_special_tokens=False,
        )
        for token_ids in encoded["input_ids"]:
            lengths.append(len(token_ids))
            distinct_counts.append(_distinct_ngram_count(token_ids, n=n))

    return lengths, distinct_counts


def plot_coverage_length_scatter(
    lengths: List[int],
    distinct_counts: List[int],
    output_path: str,
    n: int = 10,
    title: Optional[str] = None,
    max_points: Optional[int] = None,
    seed: int = 42,
    max_len: float = 8000.0,
) -> None:
    """Save a scatter plot of distinct n-gram count versus sequence length.

    Args:
        lengths: Token length of each sample (x axis).
        distinct_counts: Distinct n-gram count of each sample (y axis); must
            have the same length as ``lengths``.
        output_path: Destination file; its parent directory is created if needed.
        n: n-gram size. Not used by the plotting code; kept so existing
            callers that pass it keep working.
        title: Optional figure title. No title is drawn when it is ``None``
            or empty.
        max_points: When positive and smaller than the number of remaining
            points, that many points are drawn, sampled without replacement.
        seed: Seed for the down-sampling random generator.
        max_len: Samples longer than this are dropped, and it is also the
            right end of the x axis and of the ``y = x`` reference line.

    Raises:
        ValueError: If either input list is empty, the two lists differ in
            length, or ``max_len`` is not positive.
    """
    if not lengths or not distinct_counts:
        raise ValueError("没有可用数据点（lengths/distinct_counts 为空）")
    if len(lengths) != len(distinct_counts):
        raise ValueError("lengths 与 distinct_counts 长度不一致")
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    # Length cap: keeps a few very long samples from stretching the x axis.
    x = np.array(lengths, dtype=np.float32)
    y = np.array(distinct_counts, dtype=np.float32)
    # Drop over-long points instead of clipping them; clipping would pile
    # them up into a vertical line at max_len.
    mask = x <= max_len
    x = x[mask]
    y = y[mask]

    # Optional down-sampling so a huge point count does not bloat the output file.
    if max_points is not None and max_points > 0 and len(x) > max_points:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(x), size=max_points, replace=False)
        x = x[idx]
        y = y[idx]

    BG = "#FAFAFA"
    COLOR_POINTS = "#00468B"  # dark blue
    COLOR_REF = "#AE1029"  # crimson

    fig, ax = plt.subplots(1, 1, figsize=(9, 7))
    try:
        ax.set_facecolor(BG)
        for side in ["left", "bottom", "top", "right"]:
            ax.spines[side].set_linewidth(1.8)
            ax.spines[side].set_color("#666666")
            ax.spines[side].set_visible(True)
        ax.grid(True, axis="both", alpha=0.3, color="#CCCCCC", linewidth=0.8, linestyle="-")
        ax.set_axisbelow(True)
        ax.tick_params(axis="both", labelcolor="#333333", labelsize=20, length=6, width=1.5)

        ax.scatter(x, y, s=18, alpha=0.35, color=COLOR_POINTS, edgecolors="none", zorder=3)

        # Reference line y = x (visual hint of the theoretical upper bound),
        # drawn only up to max_len, the right end of the x axis.
        line_x = np.linspace(0.0, max_len, 200)
        ax.plot(
            line_x,
            line_x,
            linestyle="--",
            linewidth=2.5,
            color=COLOR_REF,
            alpha=0.9,
            label=r"$y=x$",
            zorder=2,
        )

        ax.set_xlabel("Sequence Length (tokens)", fontsize=26, fontweight="bold", labelpad=10)
        ax.set_ylabel(r"$C_{\mathrm{context}}(\tau)$", fontsize=26, fontweight="bold", labelpad=10)
        ax.set_xlim(left=0, right=max_len)
        ax.set_ylim(bottom=0)

        # Draw a title only when the caller supplies one, so the default
        # output stays title-free as before.
        if title:
            ax.set_title(title, fontsize=20, fontweight="bold", pad=16)

        # Legend (small but clear).
        leg = ax.legend(
            loc="upper left",
            frameon=True,
            framealpha=0.95,
            edgecolor="#888888",
            fancybox=True,
            prop={"size": 22},
        )
        leg.get_frame().set_linewidth(1.2)

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        # Operate on this figure explicitly rather than pyplot's current figure.
        fig.tight_layout()
        fig.savefig(
            output_path,
            dpi=600,
            bbox_inches="tight",
            metadata={"Creator": "matplotlib", "Producer": "matplotlib"},
        )
    finally:
        # Release the figure even if styling or saving fails.
        plt.close(fig)
    print(f"已保存: {output_path}  (N={len(x)})")


def main():
    """Command-line entry point.

    Parses the arguments, loads the tokenizer, collects the ``output`` texts
    from the ``*.jsonl`` files in ``--input-dir``, computes the token length
    and distinct n-gram count of each text, and saves the scatter plot.

    Raises:
        FileNotFoundError: If ``--input-dir`` is not an existing directory.
        RuntimeError: If ``transformers`` cannot be imported, or no output
            text is found in the directory.
    """
    parser = argparse.ArgumentParser(description="Plot coverage-length scatter plot from a valid directory")
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="valid 目录路径（包含 *.jsonl，每行 item 必须含 output 字段）",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base",
        help="Tokenizer/model 路径（用于计算 token length 与 n-gram）",
    )
    parser.add_argument("--ngram", type=int, default=10, help="n-gram 大小，默认 10")
    parser.add_argument("--batch-size", type=int, default=128, help="tokenize batch size")
    parser.add_argument("--max-points", type=int, default=None, help="最多绘制点数（超出则随机下采样）")
    parser.add_argument("--seed", type=int, default=42, help="下采样随机种子")
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出图路径（.pdf/.png）。默认输出到 input-dir 同级: coverage_length_scatter.pdf",
    )
    parser.add_argument("--title", type=str, default=None, help="自定义标题")
    args = parser.parse_args()

    # Fail fast on values that cannot give a meaningful result: a zero batch
    # size makes range() raise, a negative one yields no batches at all, and
    # a non-positive n-gram size makes every distinct count 0.
    if args.ngram <= 0:
        parser.error("--ngram 必须为正整数")
    if args.batch_size <= 0:
        parser.error("--batch-size 必须为正整数")

    if not os.path.isdir(args.input_dir):
        raise FileNotFoundError(f"目录不存在: {args.input_dir}")

    output_path = args.output
    if output_path is None:
        # Default: write next to the input directory. Strip trailing
        # separators first so dirname() returns the parent rather than the
        # directory itself; use the platform's separators, not just "/".
        separators = os.sep + (os.altsep or "")
        parent = os.path.dirname(args.input_dir.rstrip(separators)) or "."
        output_path = os.path.join(parent, "coverage_length_scatter.pdf")

    try:
        # Deferred import so that --help works even in a broken environment.
        from transformers import AutoTokenizer
    except Exception as e:
        raise RuntimeError(
            "导入 transformers 失败：当前环境可能存在 transformers/huggingface_hub 版本不兼容。\n"
            "建议在可用环境中运行，或修复依赖版本。\n"
            f"原始错误: {e}"
        ) from e

    print("加载 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    print("收集所有 outputs...")
    texts, _sources = collect_outputs(args.input_dir)
    texts = [t for t in texts if isinstance(t, str) and t.strip()]
    print(f"总共收集到 {len(texts)} 条样本")

    if not texts:
        raise RuntimeError("未从目录中抽取到任何 output 文本，请检查 jsonl 是否含 output 字段。")

    print("计算 token length 与 distinct n-gram...")
    lengths, distinct_counts = compute_length_and_distinct_ngrams(
        texts=texts,
        tokenizer=tokenizer,
        n=args.ngram,
        batch_size=args.batch_size,
    )

    print("绘图...")
    plot_coverage_length_scatter(
        lengths=lengths,
        distinct_counts=distinct_counts,
        output_path=output_path,
        n=args.ngram,
        title=args.title,
        max_points=args.max_points,
        seed=args.seed,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

