from matplotlib.ticker import FuncFormatter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
import argparse
import glob
import matplotlib
matplotlib.use('Agg')  # Use the non-interactive Agg backend for high-quality PDF output

# PDF / PostScript output settings
plt.rcParams['pdf.fonttype'] = 42  # Embed TrueType (Type 42) fonts so text stays editable in the PDF
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['pdf.use14corefonts'] = False

# Experiments to plot.
# - Each key is looked up in the merged analysis JSON (global n-gram metrics).
# - "title" is the legend label.
# - "wandb_runs" lists wandb run paths (presumably entity/project/run_id) whose
#   "actor/entropy" history is fetched; several runs are merged into one curve.
# A key containing the substring "baseline" is drawn in COLOR_BASELINE, every
# other key in COLOR_OURS.
EXPERIMENTS = {
    "baseline-gspo-dapo-math-minibsz32": {
        "title": "GSPO",
        "wandb_runs": [
            "astrid_tuning_llm/verl-qwen3-4b-oct/e7mn9b4j",
            "astrid_tuning_llm/verl-qwen3-4b-oct/3cvtlkqw"
        ]
    },
    "skip-right-skip-limits10-gspo-dapo-math": {
        "title": "GSPO + LINE",
        "wandb_runs": [
            "astrid_tuning_llm/verl-qwen3-4b-oct/u2e2zp05"
        ]
    }
}

# Curve colours
COLOR_BASELINE = '#00468B'  # dark blue
COLOR_OURS = '#9B59B6'      # purple


def _pick_first_existing_key(record: dict, candidates):
    """Return the first entry of ``candidates`` that is a key of ``record``.

    Candidates are tried in order; None is returned when none of them is
    present.
    """
    return next((key for key in candidates if key in record), None)


def _get_diversity_key_and_label(record0: dict, n: int, diversity_metric: str):
    """Resolve where a diversity metric lives in a record and how to label it.

    Returns ``(key, y_label, kind)``:
      * ``key`` -- the record key for the n-gram metric, or None when
        ``record0`` does not contain it;
      * ``y_label`` -- the y-axis label text;
      * ``kind`` -- ``'count'`` or ``'ratio'``, used to pick y-axis range and
        tick formatting.

    Raises ValueError for an unsupported ``diversity_metric``.
    """
    metric_table = {
        'global_count': (
            f'step_global_distinct_{n}gram_count',
            "Global Distinct N-gram Count",
            'count',
        ),
        'global_ratio': (
            f'step_global_distinct_{n}gram_ratio',
            r'$\mathbf{R_{\mathrm{global}}(\mathcal{T})}$',
            'ratio',
        ),
    }
    if diversity_metric not in metric_table:
        raise ValueError(f"Unknown diversity_metric: {diversity_metric}")

    record_key, label_text, value_kind = metric_table[diversity_metric]
    found = _pick_first_existing_key(record0, [record_key])
    return found, label_text, value_kind


def _load_and_merge_results(data_paths):
    """Load several analysis JSON files and merge them into a single dict.

    Every file is expected to hold ``{exp_name: [records...]}``. Missing files,
    unreadable/invalid JSON and non-dict top levels are reported and skipped.
    When an experiment key appears in more than one file, the file loaded
    later overwrites the earlier entry (a warning lists the clashing keys).
    """
    merged: dict = {}
    for path in data_paths:
        if not os.path.exists(path):
            print(f"警告: 找不到数据文件 {path}，跳过")
            continue

        try:
            with open(path, "r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except Exception as err:
            print(f"加载 JSON 失败: {path}，错误: {err}")
            continue

        if not isinstance(payload, dict):
            print(f"警告: {path} 的顶层不是 dict（期望 {{exp_name: [..]}}），跳过")
            continue

        clashing = sorted(merged.keys() & payload.keys())
        if clashing:
            print(
                f"警告: {path} 与已有数据存在重复 experiment key，将覆盖: {clashing}")
        merged.update(payload)

    return merged


def _pick_x_key(df: pd.DataFrame, prefer: str = None) -> str:
    """Choose the column of ``df`` to use as the training-step axis.

    ``prefer`` (when truthy) is tried first, followed by a fixed list of common
    step column names. Raises ValueError listing up to 50 of the available
    columns when nothing matches.
    """
    known_step_columns = (
        "trainer/global_step",
        "global_step",
        "_step",
        "step",
        "train/global_step",
    )
    search_order = ([prefer] if prefer else []) + list(known_step_columns)

    chosen = next((col for col in search_order if col in df.columns), None)
    if chosen is None:
        raise ValueError(f"找不到可用的 step 列，现有列：{list(df.columns)[:50]}")
    return chosen


def fetch_run_history(api, run_path: str, keys: list) -> pd.DataFrame:
    """Download the full logged history of one wandb run as a DataFrame.

    ``api`` is used as ``api.run(run_path)``; the returned run must expose
    ``name``, ``state`` and ``scan_history(page_size=...)``. The history is
    scanned WITHOUT a server-side ``keys`` filter (so rows are not dropped);
    ``keys`` is applied only afterwards to select columns.

    Returns:
        * only the columns of ``keys`` that exist, when at least one does;
        * all columns, when none of ``keys`` exist (the caller decides);
        * an empty DataFrame with ``keys`` as columns when the run has no rows;
        * an empty DataFrame on any error (the error is printed, not raised).
    """
    try:
        run = api.run(run_path)
        print(f"  ✓ 成功找到 run: {run.name} (状态: {run.state})")

        records = []
        for row in run.scan_history(page_size=1000):
            if row is None:
                continue
            records.append(dict(row))
            # Progress message every 1000 non-empty rows.
            if len(records) % 1000 == 0:
                print(f"    已扫描 {len(records)} 条记录...")

        if not records:
            print(f"  ⚠ run 存在但没有历史数据")
            return pd.DataFrame(columns=keys)

        history = pd.DataFrame(records)
        print(f"  ✓ 成功获取 {len(history)} 条原始记录")
        print(f"  可用列: {list(history.columns)[:20]}...")  # first 20 columns only

        if "actor/entropy" not in history.columns:
            print(f"  ⚠ 警告: 数据中没有 'actor/entropy' 列")
            print(
                f"  可用的 entropy 相关列: {[c for c in history.columns if 'entropy' in c.lower()]}")

        present = [c for c in keys if c in history.columns]
        if not present:
            print(f"  ⚠ 数据中没有找到所需的列: {keys}")
            # Hand back everything and let the caller sort it out.
            return history

        selected = history[present].copy()
        print(f"  ✓ 返回 {len(selected)} 条记录，包含列: {present}")
        return selected
    except Exception as e:
        message = str(e)
        print(f"  ✗ 错误: {message}")
        lowered = message.lower()
        if "not found" in lowered or "could not find" in lowered:
            print(f"    提示: Run ID 可能不正确，或者你没有访问权限")
            print(f"    请检查 wandb 网页上该 run 的完整路径是否正确")
        return pd.DataFrame()


def _safe_numeric(df: pd.DataFrame, cols: list) -> pd.DataFrame:
    """Coerce the listed columns of ``df`` to numeric dtype, in place.

    Unparseable entries become NaN; names in ``cols`` that are not columns of
    ``df`` are ignored. The same DataFrame object is returned for chaining.
    """
    existing = [name for name in cols if name in df.columns]
    for name in existing:
        df[name] = pd.to_numeric(df[name], errors="coerce")
    return df


def _compact_count_formatter_factory(max_abs_value: float):
    """Build a tick formatter that abbreviates large counts.

    The unit is fixed once from ``max_abs_value`` (None is treated as 0):
      - ``>= 1e6``: millions, 2 decimals, ``M`` suffix;
      - ``>= 1e3``: thousands, 1 decimal, ``k`` suffix;
      - otherwise: the value rounded to an integer.
    Tick values that cannot be converted to float are returned via ``str``.
    """
    magnitude = 0.0 if max_abs_value is None else float(max_abs_value)

    if magnitude >= 1_000_000:
        divisor, template = 1_000_000, '{:.2f}M'
    elif magnitude >= 1_000:
        divisor, template = 1_000, '{:.1f}k'
    else:
        divisor, template = None, None

    def _fmt(x, pos):
        try:
            value = float(x)
        except Exception:
            return str(x)
        if divisor is None:
            return f'{int(round(value))}'
        return template.format(value / divisor)

    return _fmt


def _ratio_percent_formatter_factory(ylim):
    """Build a formatter that shows ratios (typically 0..1) as percentages.

    The number of decimals adapts to the y-axis range ``ylim`` (in percent):
      - 2 decimals if the largest |bound| < 1% or the span < 0.5%;
      - 1 decimal if the largest |bound| < 10% or the span < 3%;
      - 0 decimals otherwise.
    A non-numeric ``ylim`` is treated as a zero range (2 decimals). Tick values
    that cannot be converted to float are returned via ``str``.
    """
    try:
        low, high = float(ylim[0]), float(ylim[1])
        span_pct = abs((high - low) * 100.0)
        max_pct = max(abs(low), abs(high)) * 100.0
    except Exception:
        span_pct = max_pct = 0.0

    if max_pct < 1.0 or span_pct < 0.5:
        decimals = 2
    elif max_pct < 10.0 or span_pct < 3.0:
        decimals = 1
    else:
        decimals = 0

    def _fmt(x, pos):
        try:
            return f'{float(x) * 100.0:.{decimals}f}'
        except Exception:
            return str(x)

    return _fmt


def _add_top_headroom(ylim, frac: float = 0.08, cap_upper=None):
    """Raise only the upper y-limit by ``frac`` of the current span.

    The span is floored at 1e-9, so a degenerate range still grows a little.
    If ``cap_upper`` is given and convertible to float, the new upper bound is
    clipped to it. ``ylim`` is returned unchanged when its bounds are not
    numeric.
    """
    try:
        bottom, top = float(ylim[0]), float(ylim[1])
    except Exception:
        return ylim

    raised_top = top + max(top - bottom, 1e-9) * float(frac)

    if cap_upper is not None:
        try:
            ceiling = float(cap_upper)
        except Exception:
            ceiling = None
        if ceiling is not None:
            raised_top = min(ceiling, raised_top)

    return (bottom, raised_top)


# wandb metric that holds the policy entropy curve.
_ENTROPY_KEY = "actor/entropy"


def _fetch_entropy_runs(api, run_paths):
    """Fetch per-run ``(steps, entropy)`` arrays for the given wandb run paths.

    Each run is coerced to numeric, rows with NaN step/entropy are dropped and
    the rest is sorted by step. A run with no rows, no recognizable step column
    or no ``actor/entropy`` column is reported and skipped instead of aborting
    the whole script.

    Returns:
        (all_steps, all_entropy): two parallel lists of 1-D numpy arrays.
    """
    all_steps, all_entropy = [], []
    for run_path in run_paths:
        print(f"\n尝试获取 run: {run_path}")
        keys = [_ENTROPY_KEY, "_step", "trainer/global_step",
                "global_step", "step", "train/global_step"]
        df = fetch_run_history(api, run_path, keys)

        if df.empty:
            print(f"警告: {run_path} 没有数据")
            continue

        # fetch_run_history may return *all* columns when none of ``keys``
        # matched, so both the step and the entropy column must be verified
        # before they are indexed below.
        try:
            x_key = _pick_x_key(df)
        except ValueError as e:
            print(f"  ✗ {e}，跳过此 run")
            continue
        print(f"  使用的 step 列: {x_key}")
        if _ENTROPY_KEY not in df.columns:
            print(f"  ✗ 数据中没有 '{_ENTROPY_KEY}' 列，跳过此 run")
            continue
        print(f"  原始数据列: {list(df.columns)}")
        print(f"  原始数据行数: {len(df)}")

        df = _safe_numeric(df, [x_key, _ENTROPY_KEY])
        rows_before = len(df)
        # Stable sort: rows sharing a step keep the order wandb returned them
        # in, which the multi-run merge relies on to keep the last value.
        df = df.dropna(subset=[x_key, _ENTROPY_KEY]).sort_values(
            x_key, kind="mergesort")

        print(f"  去除 NaN 后行数: {len(df)}")
        if len(df) < rows_before:
            print(
                f"  警告: 丢失了 {rows_before - len(df)} 行数据（包含 NaN）")

        if df.empty:
            print(f"  ✗ 数据为空，跳过此 run")
            continue

        step_values = df[x_key].to_numpy()
        entropy_values = df[_ENTROPY_KEY].to_numpy()
        all_steps.append(step_values)
        all_entropy.append(entropy_values)
        print(
            f"  ✓ 数据范围: step {step_values.min():.0f} - {step_values.max():.0f}, {len(step_values)} 个点")
        print(
            f"  entropy 范围: {entropy_values.min():.4f} - {entropy_values.max():.4f}")

    return all_steps, all_entropy


def _merge_entropy_runs(all_steps, all_entropy):
    """Combine one or more runs' entropy curves into ``{"steps", "entropy"}``.

    A single run is returned as-is. With several runs, each run is sorted and
    de-duplicated (keeping the LAST value of a repeated step), then linearly
    interpolated onto the union of all steps, but only inside that run's own
    step range. At every step the result is the plain mean of the runs that
    cover it, so non-overlapping runs are simply concatenated and overlapping
    runs are averaged with equal weights. Returns None when there is no data.
    """
    if not all_steps:
        return None
    if len(all_steps) == 1:
        return {"steps": all_steps[0], "entropy": all_entropy[0]}

    processed = []
    for steps, entropy in zip(all_steps, all_entropy):
        order = np.argsort(steps, kind="stable")
        steps_sorted = steps[order]
        entropy_sorted = entropy[order]
        # An element survives when the next step is strictly larger (or it is
        # the final element), i.e. the last sample of each equal-step run.
        keep_last = np.append(np.diff(steps_sorted) > 0, True)
        steps_sorted = steps_sorted[keep_last]
        entropy_sorted = entropy_sorted[keep_last]
        processed.append((steps_sorted, entropy_sorted))
        print(
            f"  Run {len(processed)}: step {steps_sorted.min():.0f} - {steps_sorted.max():.0f}, {len(steps_sorted)} 个点")

    # Two runs overlap when their [first, last] step ranges intersect.
    ranges = [(s[0], s[-1]) for s, _ in processed]
    has_overlap = any(
        max(a[0], b[0]) <= min(a[1], b[1])
        for i, a in enumerate(ranges)
        for b in ranges[i + 1:]
    )

    base_steps = np.unique(np.concatenate([s for s, _ in processed]))
    if has_overlap:
        print(f"  检测到 run 之间有重叠，合并数据")
        print(
            f"  创建完整 step 序列: step {int(base_steps.min())} - {int(base_steps.max())}, 共 {len(base_steps)} 个点")
    else:
        print(f"  检测到 run 之间没有重叠，直接连接数据")

    # Accumulate sum and coverage count so the mean weighs every run equally
    # (a running pairwise average would over-weight later runs).
    total = np.zeros(len(base_steps))
    coverage = np.zeros(len(base_steps))
    for steps_sorted, entropy_sorted in processed:
        in_range = (base_steps >= steps_sorted[0]) & (
            base_steps <= steps_sorted[-1])
        total[in_range] += np.interp(
            base_steps[in_range], steps_sorted, entropy_sorted)
        coverage[in_range] += 1

    valid = coverage > 0
    if not np.any(valid):
        print(f"  警告: 合并后没有有效数据")
        return None

    merged_steps = base_steps[valid]
    print(
        f"  最终合并数据范围: step {merged_steps.min():.0f} - {merged_steps.max():.0f}, {len(merged_steps)} 个点")
    return {"steps": merged_steps, "entropy": total[valid] / coverage[valid]}


def _count_series(records, n, max_step):
    """Return ``(steps, counts)`` arrays of the global distinct n-gram count.

    Records are sorted by ``'step'`` and truncated to ``step <= max_step``; a
    record lacking the count key contributes 0 (as before). Returns None when
    ``records`` is empty or its first record carries no count key.
    """
    if not records:
        return None
    res = sorted(records, key=lambda r: r['step'])
    count_key, _, _ = _get_diversity_key_and_label(res[0], n, 'global_count')
    if not count_key:
        return None
    steps = np.array([r['step'] for r in res])
    counts = np.array([r.get(count_key, 0) for r in res])
    mask = steps <= max_step
    return steps[mask], counts[mask]


def _padded_ylim(values, default):
    """Compute y-limits covering ``values`` with padding.

    Adds a 10% margin on both sides (lower bound clamped at 0), then 8% extra
    headroom on top. A (near-)constant series uses 10% of ``max(|v|, 1)`` as
    the margin. Returns ``default`` when ``values`` is empty.
    """
    if len(values) == 0:
        return default
    lo, hi = min(values), max(values)
    span = float(hi - lo)
    if abs(span) < 1e-9:
        margin = max(abs(float(hi)), 1.0) * 0.1
    else:
        margin = span * 0.1
    return _add_top_headroom((max(0, lo - margin), hi + margin),
                             frac=0.08, cap_upper=None)


def _plot_curve_with_band(ax, steps, values, color, label):
    """Draw ``values`` vs ``steps`` on ``ax`` with a translucent band.

    The band half-width is 1.2x a moving average (window up to 5) of the
    absolute step-to-step changes. It is a visual volatility cue, not a
    statistical interval, and it is only drawn for series of 20+ points.
    """
    steps = np.asarray(steps)
    values = np.asarray(values)
    window = min(5, len(values) // 10) if len(values) > 10 else 1
    if window > 1:
        jumps = np.abs(np.diff(np.concatenate([[values[0]], values])))
        half_width = np.convolve(
            jumps, np.ones(window) / window, mode='same') * 1.2
        ax.fill_between(
            steps,
            values - half_width,
            values + half_width,
            color=color,
            alpha=0.2,
            zorder=1,
        )
    ax.plot(
        steps,
        values,
        color=color,
        linewidth=3.0,
        linestyle='-',
        alpha=1.0,
        zorder=4,
        label=label,
    )


def _style_axis(ax, tick_size):
    """Apply the shared subplot look: background, spines, ticks, grid, locators."""
    ax.set_facecolor('#FAFAFA')
    for side in ['left', 'bottom', 'top', 'right']:
        ax.spines[side].set_linewidth(1.8)
        ax.spines[side].set_color('#666666')
        ax.spines[side].set_visible(True)
    ax.tick_params(axis='both', labelcolor='#333333',
                   labelsize=tick_size, length=6, width=1.5)
    ax.grid(True, axis='both', alpha=0.3, color='#CCCCCC',
            linewidth=0.8, linestyle='-', zorder=0)
    ax.set_axisbelow(True)
    ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))
    ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))


def plot_global_metrics_and_entropy(
    data_path,
    output_dir='plots',
    ngram_sizes=None,
    dpi=600,
    max_step=600,
):
    """Plot the global distinct n-gram count and the policy entropy side by side.

    For every n in ``ngram_sizes`` one PDF named
    ``global_metrics_entropy_{n}gram.pdf`` is written to ``output_dir``. The
    left panel shows the global distinct n-gram count from the analysis
    JSON file(s). The right panel shows the ``actor/entropy`` curve fetched
    from wandb. Both panels draw one curve per entry of ``EXPERIMENTS``.

    Args:
        data_path: a JSON path, or a list/tuple of paths that are merged
            (later files win on duplicate experiment keys).
        output_dir: output directory, created if missing.
        ngram_sizes: n-gram sizes to plot; defaults to ``[10]``.
        dpi: resolution passed to ``savefig``.
        max_step: only steps ``<= max_step`` are plotted and used for the
            y-limits; the x-axis spans ``[0, max_step + 10]``.

    Returns early (after printing an error) when no JSON data could be loaded.
    """
    if ngram_sizes is None:
        ngram_sizes = [10]

    # Accept a single JSON path or several.
    if isinstance(data_path, (list, tuple)):
        data_paths = list(data_path)
    else:
        data_paths = [data_path]

    all_results = _load_and_merge_results(data_paths)
    if not all_results:
        print("错误: 没有成功加载任何数据")
        return

    os.makedirs(output_dir, exist_ok=True)

    # wandb is imported here rather than at module top, so the module itself
    # can be imported without it.
    import wandb

    if not os.environ.get("WANDB_API_KEY"):
        print("提示：未检测到环境变量 WANDB_API_KEY，如果你已 `wandb login` 可忽略。")

    api = wandb.Api()

    # exp_key -> {"steps": ndarray, "entropy": ndarray}
    wandb_data = {}
    for exp_key, exp_config in EXPERIMENTS.items():
        all_steps, all_entropy = _fetch_entropy_runs(
            api, exp_config["wandb_runs"])
        merged = _merge_entropy_runs(all_steps, all_entropy)
        if merged is not None:
            wandb_data[exp_key] = merged

    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['axes.linewidth'] = 1.8

    FONT_TICK = 24
    FONT_XLABEL = 28
    FONT_TITLE = 32
    FONT_LEGEND = 30
    x_right = max_step + 10  # leave a little empty space after the last step

    for n in ngram_sizes:
        # Truncated series per experiment, computed once and reused for both
        # the y-limits and the curves. Entropy is collected independently of
        # the JSON so every plotted entropy curve is inside the y-limits.
        count_series = {}
        entropy_series = {}
        for exp_key in EXPERIMENTS:
            if exp_key in all_results:
                series = _count_series(all_results[exp_key], n, max_step)
                if series is not None:
                    count_series[exp_key] = series
            if exp_key in wandb_data:
                steps = np.asarray(wandb_data[exp_key]["steps"])
                entropy = np.asarray(wandb_data[exp_key]["entropy"])
                mask = steps <= max_step
                entropy_series[exp_key] = (steps[mask], entropy[mask])

        count_ylim = _padded_ylim(
            [c for _, counts in count_series.values() for c in counts],
            default=(0, 1000))
        entropy_ylim = _padded_ylim(
            [e for _, ent in entropy_series.values() for e in ent],
            default=(0, 1.0))

        fig, (ax_count, ax_entropy) = plt.subplots(
            1, 2, figsize=(16.0, 7.0), squeeze=True)

        # Left panel: global distinct n-gram count.
        _style_axis(ax_count, FONT_TICK)
        ax_count.set_title('Global Distinct N-gram Count',
                           fontsize=FONT_TITLE, fontweight='bold', pad=14)
        ax_count.set_xlabel(
            'Training Step', fontsize=FONT_XLABEL, fontweight='bold', labelpad=10)
        ax_count.set_xlim(0, x_right)
        ax_count.set_ylim(count_ylim)
        max_abs = max(abs(float(count_ylim[0])), abs(float(count_ylim[1])))
        ax_count.yaxis.set_major_formatter(FuncFormatter(
            _compact_count_formatter_factory(max_abs)))

        # Right panel: entropy.
        _style_axis(ax_entropy, FONT_TICK)
        ax_entropy.set_title('Entropy', fontsize=FONT_TITLE,
                             fontweight='bold', pad=14)
        ax_entropy.set_xlabel(
            'Training Step', fontsize=FONT_XLABEL, fontweight='bold', labelpad=10)
        ax_entropy.set_xlim(0, x_right)
        ax_entropy.set_ylim(entropy_ylim)

        for exp_key, exp_config in EXPERIMENTS.items():
            # Keys containing "baseline" use the baseline colour.
            color = COLOR_BASELINE if "baseline" in exp_key else COLOR_OURS
            label = exp_config["title"]
            if exp_key in count_series:
                steps, counts = count_series[exp_key]
                _plot_curve_with_band(ax_count, steps, counts, color, label)
            if exp_key in entropy_series:
                steps, entropy = entropy_series[exp_key]
                _plot_curve_with_band(ax_entropy, steps, entropy, color, label)

        # One shared legend entry per label, gathered from both panels so an
        # experiment present in only one panel still appears.
        legend_handles, legend_labels, seen_labels = [], [], set()
        for ax in (ax_count, ax_entropy):
            for handle, text in zip(*ax.get_legend_handles_labels()):
                if text not in seen_labels:
                    seen_labels.add(text)
                    legend_handles.append(handle)
                    legend_labels.append(text)

        if legend_handles:
            fig.legend(
                legend_handles,
                legend_labels,
                loc='lower center',
                bbox_to_anchor=(0.5, -0.1),
                ncol=2,
                frameon=True,
                framealpha=0.95,
                edgecolor='#888888',
                fancybox=True,
                shadow=False,
                prop={'weight': 'bold', 'size': FONT_LEGEND},
            )

        fig.subplots_adjust(wspace=0.20, top=0.85,
                            bottom=0.20, left=0.10, right=0.95)

        save = f"global_metrics_entropy_{n}gram.pdf"
        fig.savefig(
            os.path.join(output_dir, save),
            format='pdf',
            dpi=dpi,
            bbox_inches='tight',
            metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
        )
        plt.close(fig)
        print(f"已生成PDF图表: {os.path.join(output_dir, save)}")


if __name__ == "__main__":
    # Command-line entry point. Input precedence: --inputs, then --input-glob,
    # then the single --input path (which has a built-in default).
    parser = argparse.ArgumentParser(
        description='绘制 global count 和 entropy 图表')
    parser.add_argument('--input', '-i', type=str,
                        default='/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/analysis_results.json')
    parser.add_argument('--inputs', type=str, nargs='+',
                        help='多个分批收集的 JSON 文件路径')
    parser.add_argument('--input-glob', type=str,
                        help='用通配符指定多个 JSON')
    parser.add_argument('--output-dir', '-o', type=str,
                        default='plots',
                        help='输出目录')
    parser.add_argument('--ngrams', type=int, nargs='+',
                        default=[10], help='要绘制的 n-gram 大小')
    parser.add_argument('--dpi', type=int, default=600, help='PDF清晰度（DPI）')

    args = parser.parse_args()

    if args.inputs:
        data_paths = args.inputs
    elif args.input_glob:
        data_paths = sorted(glob.glob(args.input_glob))
        if not data_paths:
            print(f"错误: input-glob 没有匹配到任何文件: {args.input_glob}")
            # The builtin ``exit`` is injected by the ``site`` module for the
            # interactive shell (absent under ``python -S``); SystemExit is
            # the portable way to stop a script with a status code.
            raise SystemExit(1)
    else:
        data_paths = args.input

    print("=" * 60)
    print("生成 global count 和 entropy 图表...")
    print("=" * 60)
    plot_global_metrics_and_entropy(
        data_path=data_paths,
        output_dir=args.output_dir,
        ngram_sizes=args.ngrams,
        dpi=args.dpi,
    )
