from matplotlib.ticker import FuncFormatter
import numpy as np
import matplotlib.pyplot as plt
import json
import os
import argparse
import glob
import matplotlib
matplotlib.use('Agg')  # 使用Agg后端以支持高质量PDF输出

# Optional dependency: scipy's Savitzky-Golay filter is used for smoothing when
# it can be imported. HAS_SCIPY records whether the import succeeded so that
# callers can check it before using savgol_filter.
try:
    from scipy.signal import savgol_filter
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

# PDF/PS export settings: font type 42 (TrueType) keeps text editable in the
# exported file, and output is not restricted to the 14 core PDF fonts.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'pdf.use14corefonts': False,
})

# Global config: maps an experiment key (as found in the results JSON) to the
# title shown in figures; keys without an entry fall back to the raw key at the
# call sites (MY_TITLES.get(key, key)).
# Every label that contains a backslash (mathtext such as \Delta or \text) is a
# raw string, so the backslash reaches matplotlib untouched and Python never
# sees an invalid escape sequence.
MY_TITLES = {
    "baseline-dapo-math-redo": "GRPO High clip ratio",
    "baseline-grpo-dapo-math-minibsz32": "GRPO",
    "baseline-gspo-dapo-math-minibsz32": "GSPO",
    # Alternative label for the next entry: r"$\Delta L = 500$"
    "skip-right-skip-limits10-gspo-dapo-math": "GSPO + LIE",
    "skip-right-skip-limits10-dapo-math": "GRPO High clip ratio + LIE",
    "skip-right-skip-limits10-grpo-dapo-math": "GRPO + LIE",
    "skip-right-skip-limits10-gspo-dapo-math-add2k": "GSPO + LIE (add2k)",
    "skip-right-skip-limits10-gspo-dapo-math-wo-repetition": "GSPO + Length",
    "skip-right-distinct-hard-gspo-dapo-math": "GSPO + LIE (distinct ratio hard)",
    "skip-right-skip-limits10-gspo-dapo-math-wo-repetition-redo": r"GSPO + $R_{\text{len}}$",
    "llama-baseline-gspo-deepmath": "GSPO",
    "llama-add1k-gspo-deepmath": "GSPO +LIE",
    "qwen3-4b-polaris-add1k-gspo": "GSPO + LIE",
    "qwen3-4b-polaris-baseline-gspo": "GSPO",
    "skip-right-skip-limits10-gspo-dapo-math-add1k5": r"$\Delta L = 1k$",
    "ours-gspo-dapo-math-add8k5": r"$\Delta L = 8k$",
    "ours-gspo-dapo-math-add1k-fixed": r"$\Delta L = 500 + $ avg (fixed)",
    "ours-gspo-dapo-math-add600": r"$\Delta L = 100$",
    "sft-gspo-ours-dapo-math-max12k": "SFT GSPO + LIE",
    "sft4k-gspo-dapo-math-minibsz32-max12k": "SFT GSPO",
    "gspo_repetition": "GSPO + Redundancy Penalty",
    "skip-right-distinct_bonus-gspo-dapo-math": r"GSPO + Maximize $C_{\text{context}}$",
}

# Global config: figure groups. Each entry maps a group name to the ordered
# list of experiment keys plotted together for that group. The commented-out
# entries are alternative groupings that are currently disabled.
GROUPS = {
    # "Baselines": ["baseline-grpo-dapo-math-minibsz32", "baseline-gspo-dapo-math-minibsz32"],
    # "Qwen3-4b-Base-ours": ["skip-right-skip-limits10-gspo-dapo-math", "skip-right-skip-limits10-grpo-dapo-math"],
    # "GSPO": ["baseline-gspo-dapo-math-minibsz32", "skip-right-skip-limits10-gspo-dapo-math"],
    # "Ablation": ["baseline-gspo-dapo-math-minibsz32", "skip-right-skip-limits10-gspo-dapo-math-wo-repetition-redo"],
    # "llama": ["llama-baseline-gspo-deepmath", "llama-add1k-gspo-deepmath"],
    "Reward_ablation": [
        "baseline-gspo-dapo-math-minibsz32",
        "skip-right-skip-limits10-gspo-dapo-math",
        "gspo_repetition",
        "skip-right-skip-limits10-gspo-dapo-math-wo-repetition-redo",
    ],
    "distinct_bonus": [
        "baseline-gspo-dapo-math-minibsz32",
        "skip-right-skip-limits10-gspo-dapo-math",
        "skip-right-distinct_bonus-gspo-dapo-math",
    ],
    # "qwen3-4b": [ "qwen3-4b-polaris-baseline-gspo","qwen3-4b-polaris-add1k-gspo"],
    # "SFT": ["skip-right-skip-limits10-gspo-dapo-math","sft4k-gspo-dapo-math-minibsz32-max12k", "sft-gspo-ours-dapo-math-max12k"],
    # "delta_L":["baseline-gspo-dapo-math-minibsz32", "ours-gspo-dapo-math-add600", "skip-right-skip-limits10-gspo-dapo-math", "skip-right-skip-limits10-gspo-dapo-math-add1k5","ours-gspo-dapo-math-add8k5", ]
}

# 定义格式化函数，将大数字转换为k格式


def thousands_formatter(x, pos):
    """Format a tick value in thousands with one decimal and a 'k' suffix.

    The conversion is applied to every value regardless of magnitude
    (e.g. 1500 -> '1.5k', 0 -> '0.0k'). ``pos`` is accepted so the function
    can be wrapped in a FuncFormatter; it is not used.
    """
    in_thousands = x / 1000
    return '{:.1f}k'.format(in_thousands)


def millions_formatter(x, pos):
    """Format a tick value in millions with two decimals and an 'M' suffix.

    ``pos`` is accepted so the function can be wrapped in a FuncFormatter;
    it is not used.
    """
    in_millions = x / 1_000_000
    return '{:.2f}M'.format(in_millions)

# 定义格式化函数：显示原始整数（带逗号），用于 global distinct count


def int_commas_formatter(x, pos):
    """Format a tick value as a rounded integer with thousands separators.

    Values that cannot be rounded to an int (e.g. NaN or non-numeric input)
    are returned as ``str(x)`` instead of raising. ``pos`` is not used.
    """
    try:
        nearest = int(round(x))
    except Exception:
        # Best-effort: a tick formatter should never break rendering.
        return str(x)
    return format(nearest, ',')


def _compact_count_formatter_factory(max_abs_value: float):
    """
    Build a tick formatter whose unit is chosen from the axis magnitude.

    The unit is decided once from ``max_abs_value`` (``None`` counts as 0):
    - >= 1e6: values are shown in millions with two decimals and an 'M' suffix
    - >= 1e3: values are shown in thousands with one decimal and a 'k' suffix
    - otherwise: values are shown as rounded integers

    The returned callable has the ``(x, pos)`` signature expected by
    FuncFormatter; a tick value that cannot be converted to float is returned
    as ``str(x)``.
    """
    magnitude = 0.0 if max_abs_value is None else float(max_abs_value)

    # Pick the rendering rule up front; every tick of one axis shares a unit.
    if magnitude >= 1_000_000:
        def _render(value):
            return f'{value/1_000_000:.2f}M'
    elif magnitude >= 1_000:
        def _render(value):
            return f'{value/1_000:.1f}k'
    else:
        def _render(value):
            return f'{int(round(value))}'

    def _fmt(x, pos):
        try:
            value = float(x)
        except Exception:
            return str(x)
        return _render(value)

    return _fmt


def _smooth_data(data, window_size=None, method='savgol'):
    """
    Smooth a 1-D series and return an array of the same length.

    Args:
        data: 1-D sequence of numbers.
        window_size: Smoothing window length. ``None`` selects the default
            window of 5. The window is clamped to ``len(data)`` and lowered to
            the nearest odd value; if fewer than 3 points remain in the window
            the data is returned unsmoothed.
        method: 'savgol' (Savitzky-Golay filter) or 'moving_avg' (moving
            average). 'savgol' falls back to 'moving_avg' when scipy is not
            available or the filter raises. Any other value returns the data
            unchanged.

    Returns:
        A numpy array with the same length as ``data``.
    """
    data = np.array(data)
    n_points = len(data)
    if n_points < 3:
        return data

    window = 5 if window_size is None else int(window_size)
    # Keep the window within the series: a kernel longer than the data makes
    # np.convolve(mode='same') return more samples than were passed in, and
    # savgol_filter rejects it outright.
    window = min(window, n_points)
    if window % 2 == 0:
        # Both smoothers assume a centred (odd-length) window.
        window -= 1
    if window < 3:
        return data

    if method == 'savgol':
        if HAS_SCIPY:
            try:
                # Polynomial order must stay below the window length.
                poly_order = min(3, window - 1)
                return savgol_filter(data, window, poly_order)
            except Exception:
                # Best-effort: fall back to the moving average below.
                method = 'moving_avg'
        else:
            # Without scipy, use the numpy implementation instead of silently
            # returning the raw data.
            method = 'moving_avg'

    if method == 'moving_avg':
        kernel = np.ones(window) / window
        smoothed = np.convolve(data, kernel, mode='same')
        # The convolution zero-pads at the borders; keep the original first
        # and last values there instead of the biased averages.
        half = window // 2
        smoothed[:half] = data[:half]
        smoothed[-half:] = data[-half:]
        return smoothed

    return data


def _add_top_headroom(ylim, frac: float = 0.08, cap_upper=None):
    """
    Raise only the upper y-limit by a fraction of the current span.

    Args:
        ylim: ``(lower, upper)`` pair. If either entry cannot be converted to
            float, ``ylim`` is returned unchanged.
        frac: Fraction of the span (``upper - lower``, at least 1e-9) added to
            the upper bound.
        cap_upper: Optional ceiling (e.g. 1.0) for the new upper bound; a value
            that cannot be converted to float is ignored.

    Returns:
        ``(lower, new_upper)`` as floats, or the original ``ylim`` when it
        could not be parsed.
    """
    try:
        lower = float(ylim[0])
        upper = float(ylim[1])
    except Exception:
        return ylim

    extent = max(upper - lower, 1e-9)
    raised = upper + extent * float(frac)

    if cap_upper is None:
        return (lower, raised)
    try:
        ceiling = float(cap_upper)
    except Exception:
        # An unusable cap is ignored rather than treated as an error.
        return (lower, raised)
    return (lower, min(ceiling, raised))


# 定义百分比格式化函数（不显示%符号）
def percent_formatter(x, pos):
    """Format a fraction as a whole-number percentage without the '%' sign.

    For example 0.5 -> '50'. ``pos`` is accepted so the function can be
    wrapped in a FuncFormatter; it is not used.
    """
    as_percent = x * 100
    return format(as_percent, '.0f')


def _ratio_percent_formatter_factory(ylim):
    """
    Build a percentage tick formatter whose precision adapts to the y-range.

    Ratios are shown as percentages without the '%' sign. With small values or
    a narrow range, whole-number percentages would make several ticks show the
    same label, so the number of decimals is chosen from ``ylim``:
    - largest bound below 1% or range below 0.5%: two decimals
    - largest bound below 10% or range below 3%: one decimal
    - otherwise: no decimals

    If ``ylim`` cannot be read as two floats, two decimals are used. The
    returned callable has the ``(x, pos)`` signature expected by FuncFormatter;
    a tick value that cannot be converted to float is returned as ``str(x)``.
    """
    try:
        low, high = float(ylim[0]), float(ylim[1])
        span_pct = abs((high - low) * 100.0)
        max_pct = max(abs(low), abs(high)) * 100.0
    except Exception:
        span_pct = max_pct = 0.0

    if max_pct < 1.0 or span_pct < 0.5:
        decimals = 2
    elif max_pct < 10.0 or span_pct < 3.0:
        decimals = 1
    else:
        decimals = 0
    template = '{:.%df}' % decimals

    def _fmt(x, pos):
        try:
            percent = float(x) * 100.0
        except Exception:
            return str(x)
        return template.format(percent)

    return _fmt


def _pick_first_existing_key(record: dict, candidates):
    """Return the first entry of ``candidates`` that is a key of ``record``, or None if none is."""
    return next((name for name in candidates if name in record), None)


def _get_diversity_key_and_label(record0: dict, n: int, diversity_metric: str):
    """
    Resolve the record key, y-axis label and kind for a diversity metric.

    Args:
        record0: A sample record used to detect which key naming is present.
        n: n-gram size used in the key names.
        diversity_metric: One of
            - 'count' / 'ratio': per-step average distinct n-gram metrics. The
              key is ``distinct_{n}gram_{metric}`` when present in ``record0``,
              otherwise ``{n}gram_{metric}``.
            - 'global_count' / 'global_ratio': global distinct metrics. The key
              is ``step_global_distinct_{n}gram_{count|ratio}``, or None when
              ``record0`` does not contain it.

    Returns:
        ``(key, y_label, kind)`` where ``kind`` is 'count' or 'ratio' and is
        used by callers to pick the y-axis range and tick format.

    Raises:
        ValueError: If ``diversity_metric`` is not one of the values above.
    """
    if diversity_metric == 'global_count':
        key = _pick_first_existing_key(
            record0, [f'step_global_distinct_{n}gram_count'])
        return key, r'$\mathbf{C_{\mathrm{global}}(\mathcal{T})}$', 'count'

    if diversity_metric == 'global_ratio':
        key = _pick_first_existing_key(
            record0, [f'step_global_distinct_{n}gram_ratio'])
        return key, r'$\mathbf{R_{\mathrm{global}}(\mathcal{T})}$', 'ratio'

    if diversity_metric not in ('count', 'ratio'):
        raise ValueError(f"Unknown diversity_metric: {diversity_metric}")

    # Trajectory-level metrics: prefer the 'distinct_'-prefixed key name.
    prefixed = f'distinct_{n}gram_{diversity_metric}'
    key = prefixed if prefixed in record0 else f'{n}gram_{diversity_metric}'
    if diversity_metric == 'count':
        y_label = r'$\mathbf{C_{\mathrm{context}}(\tau)}$'
    else:
        y_label = r'$\mathbf{R_{\mathrm{context}}(\tau)}$'
    return key, y_label, diversity_metric


def _load_and_merge_results(data_paths):
    """
    Load several JSON files and merge them into one dict.

    Each file is expected to hold a top-level dict shaped like
    ``{exp_name: [records...]}``. Files that are missing, fail to load, or whose
    top level is not a dict are reported on stdout and skipped. When the same
    experiment key appears in more than one file, a warning is printed and the
    file loaded later wins.

    Returns:
        The merged dict (empty when nothing could be loaded).
    """
    combined: dict = {}

    for p in data_paths:
        if not os.path.exists(p):
            print(f"警告: 找不到数据文件 {p}，跳过")
            continue

        try:
            with open(p, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except Exception as e:
            print(f"加载 JSON 失败: {p}，错误: {e}")
            continue

        if not isinstance(payload, dict):
            print(f"警告: {p} 的顶层不是 dict（期望 {{exp_name: [..]}}），跳过")
            continue

        # Later files override earlier ones on key clashes: simple and predictable.
        overlap = combined.keys() & payload.keys()
        if overlap:
            print(
                f"警告: {p} 与已有数据存在重复 experiment key，将覆盖: {sorted(list(overlap))}")
        for exp_name, records in payload.items():
            combined[exp_name] = records

    return combined


def plot_two_figs_count_ratio_row(
    data_path,
    output_dir='plots',
    ngram_sizes=None,
    dpi=600,
    layout_mode='1x4',
    max_step=None,
):
    """
    Render per-group PDF figures of the count/ratio diversity metrics.

    For every n-gram size and every entry of the module-level GROUPS, the
    experiments of the group that have data are plotted. Y-axis ranges
    (diversity, length, accuracy) are shared across the experiments of a group.

    Layouts:
        - '2x2': one figure per experiment with four smoothed panels
          (top-left C_context, top-right R_context, bottom-left Length,
          bottom-right Accuracy). Saved as
          ``local_count_ratio_2x2_{group}_{exp_key}_{n}gram.pdf``.
        - any other value (default '1x4'): one figure per group with a single
          row; each experiment takes two columns (count, ratio). Every subplot
          overlays three curves: the diversity metric on the left axis and
          Length and Accuracy on twin axes. Length tick labels are shown only
          on the first column and Accuracy tick labels only on the last one.
          Saved as ``local_count_ratio_row_{group}_{n}gram.pdf``.

    Args:
        data_path: Path of a results JSON file, or a list/tuple of such paths.
        output_dir: Directory for the generated PDFs (created if needed).
        ngram_sizes: Iterable of n-gram sizes; ``None`` means ``[10]``.
        dpi: Resolution passed to ``savefig``.
        layout_mode: '1x4' or '2x2' (see above).
        max_step: If given, only records with ``step <= max_step`` are used.

    Returns:
        None. Problems (no data, empty group) are reported on stdout.

    Note:
        ``plt.rcParams['font.family']`` and ``plt.rcParams['axes.linewidth']``
        are modified and not restored, so later figures are affected as well.
    """
    if ngram_sizes is None:
        ngram_sizes = [10]

    # A single path and a list/tuple of paths are both accepted.
    if isinstance(data_path, (list, tuple)):
        data_paths = list(data_path)
    else:
        data_paths = [data_path]

    all_results = _load_and_merge_results(data_paths)
    if not all_results:
        print("错误: 没有成功加载任何数据")
        return

    # Optional truncation: keep only records up to max_step.
    if max_step is not None:
        print(f"截取数据到 step <= {max_step}")
        filtered_results = {}
        for exp_key, records in all_results.items():
            filtered_records = [
                r for r in records if r.get('step', 0) <= max_step]
            if filtered_records:
                filtered_results[exp_key] = filtered_records
        all_results = filtered_results
        if not all_results:
            print("错误: 过滤后没有剩余数据")
            return

    os.makedirs(output_dir, exist_ok=True)

    # Colours (kept consistent with the other figures).
    COLOR_NGRAM = '#00468B'   # dark blue (Diversity)
    COLOR_LENGTH = '#9B59B6'  # purple (Length)
    COLOR_ACC = '#AE1029'     # crimson (Accuracy)

    groups = GROUPS

    metric_labels = {
        'count': r'$\mathbf{C_{\mathrm{context}}(\tau)}$',
        'ratio': r'$\mathbf{R_{\mathrm{context}}(\tau)}$',
    }

    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['axes.linewidth'] = 1.8

    # Enlarged font sizes used by the figures of this function.
    FONT_TICK = 26
    FONT_XLABEL = 30
    FONT_METRIC_TITLE = 36
    FONT_GROUP_TITLE = 42
    FONT_LEGEND = 32

    def _padded_ylim(values, cap_upper=None):
        """Return (lower, upper) around ``values`` with margins and top headroom.

        The lower bound never goes below 0; ``cap_upper`` (e.g. 1.0 for ratios)
        limits the upper bound. A flat series gets margins relative to its
        magnitude (at least 1.0) so the range never collapses.
        """
        vmin, vmax = min(values), max(values)
        vrange = float(vmax - vmin)
        if abs(vrange) < 1e-9:
            base = max(abs(float(vmax)), 1.0)
            margin = base * 0.08
            margin_low = base * 0.04
        else:
            margin = vrange * 0.2
            margin_low = vrange * 0.1
        lower = max(0, vmin - margin_low)
        upper = vmax + margin
        if cap_upper is not None:
            upper = min(cap_upper, upper)
        return _add_top_headroom((lower, upper), frac=0.08, cap_upper=cap_upper)

    def _compute_len_acc_ylims(exp_keys: list):
        """Compute shared Length and Accuracy y-ranges across ``exp_keys``."""
        all_lengths, all_accs = [], []
        for exp_key in exp_keys:
            if exp_key not in all_results:
                continue
            res = all_results[exp_key]
            all_lengths.extend(r.get('avg_token_length', 0) for r in res)
            all_accs.extend(r.get('accuracy', 0) for r in res)

        len_ylim = _padded_ylim(all_lengths) if all_lengths else (0, 4000)
        acc_ylim = _padded_ylim(
            all_accs, cap_upper=1.0) if all_accs else (0, 1.0)
        return len_ylim, acc_ylim

    def _compute_metric_ylim(n: int, metric: str, exp_keys: list):
        """Compute the shared y-range of ``metric`` across ``exp_keys``.

        Returns ``(ylim, kind)``, or ``(None, None)`` when no values exist.
        """
        vals = []
        for exp_key in exp_keys:
            if exp_key not in all_results:
                continue
            res = sorted(all_results[exp_key], key=lambda x: x['step'])
            div_key, _, _ = _get_diversity_key_and_label(res[0], n, metric)
            if not div_key:
                continue
            vals.extend(r.get(div_key, 0) for r in res)

        if not vals:
            return None, None

        # The kind ('count'/'ratio') depends only on the metric name, so an
        # empty record is enough to look it up.
        _, _, div_kind = _get_diversity_key_and_label({}, n, metric)
        cap = 1.0 if div_kind == 'ratio' else None
        return _padded_ylim(vals, cap_upper=cap), div_kind

    def _style_ax(ax):
        """Apply the common background, spine, tick and grid styling."""
        ax.set_facecolor('#FAFAFA')
        for side in ['left', 'bottom', 'top', 'right']:
            ax.spines[side].set_linewidth(1.8)
            ax.spines[side].set_color('#000000')
            ax.spines[side].set_visible(True)
        ax.tick_params(axis='both', labelcolor='#000000',
                       labelsize=FONT_TICK, length=6, width=1.5)
        ax.grid(True, axis='both', alpha=0.3, color='#CCCCCC',
                linewidth=0.8, linestyle='-', zorder=0)
        ax.set_axisbelow(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))
        ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

    def _shade_band(ax, steps, values, color, scale, zorder, clip=None):
        """Draw a translucent band around a curve as an approximate spread.

        The half-width is a moving average of the absolute step-to-step
        differences times ``scale``. Nothing is drawn for series with fewer
        than 20 points (the averaging window would be 1). ``clip`` optionally
        bounds the band to ``(low, high)``.
        """
        window = min(5, len(values) // 10) if len(values) > 10 else 1
        if window <= 1:
            return
        centre = np.array(values)
        std_approx = np.convolve(
            np.abs(np.diff(np.concatenate([[values[0]], values]))),
            np.ones(window) / window,
            mode='same',
        ) * scale
        lower = centre - std_approx
        upper = centre + std_approx
        if clip is not None:
            lower = np.maximum(clip[0], lower)
            upper = np.minimum(clip[1], upper)
        ax.fill_between(steps, lower, upper, color=color,
                        alpha=0.2, zorder=zorder)

    def _new_twin(ax):
        """Create a twin y-axis sharing x with ``ax`` and style its spines."""
        twin = ax.twinx()
        twin.spines['top'].set_linewidth(1.8)
        twin.spines['left'].set_visible(False)
        twin.spines['top'].set_color('#000000')
        return twin

    def _finish_twin(twin, ylim, formatter, show_axis: bool):
        """Set range/ticks of a twin axis and show or hide its right-hand scale."""
        twin.set_ylim(ylim)
        twin.yaxis.set_major_formatter(FuncFormatter(formatter))
        twin.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))
        if show_axis:
            twin.spines['right'].set_linewidth(1.8)
            twin.spines['right'].set_color('#000000')
            twin.tick_params(axis='y', labelcolor='#000000',
                             labelsize=FONT_TICK, length=6, width=1.5)
        else:
            twin.set_yticklabels([])
            twin.tick_params(axis='y', right=False, left=False,
                             labelleft=False, labelright=False)
            twin.spines['right'].set_visible(False)

    def _plot_one(
        ax,
        n: int,
        exp_key: str,
        metric: str,
        div_ylim,
        div_kind: str,
        len_ylim,
        acc_ylim,
        show_len_axis: bool,
        show_acc_axis: bool,
    ):
        """Draw one three-curve subplot and return its (diversity, length, accuracy) lines."""
        res = sorted(all_results[exp_key], key=lambda x: x['step'])
        steps = [r['step'] for r in res]
        div_key, _, _ = _get_diversity_key_and_label(res[0], n, metric)
        if not div_key:
            print(f"警告: {exp_key} 缺少 {metric} 对应 key，已用 0 填充")
            divs = [0 for _ in res]
        else:
            divs = [r.get(div_key, 0) for r in res]

        _style_ax(ax)
        # The metric label is used as the subplot title (instead of a y-label).
        ax.set_title(metric_labels[metric], fontsize=FONT_METRIC_TITLE,
                     fontweight='bold', pad=14, color='#000000')
        ax.set_xlabel('Training Step', fontsize=FONT_XLABEL,
                      fontweight='bold', labelpad=10, color='#000000')

        # --- Left axis: diversity metric ---
        _shade_band(ax, steps, divs, COLOR_NGRAM, scale=1.2, zorder=1)
        ln_div = ax.plot(
            steps,
            divs,
            color=COLOR_NGRAM,
            linewidth=3.0,
            linestyle='-',
            alpha=1.0,
            zorder=4,
            label='Diversity Metric',
        )[0]

        if div_ylim is not None:
            ax.set_ylim(div_ylim)

        if div_kind == 'ratio':
            ax.yaxis.set_major_formatter(FuncFormatter(
                _ratio_percent_formatter_factory(div_ylim if div_ylim else (0, 1))))
        elif metric == 'global_count':
            max_abs = max(abs(float(div_ylim[0])), abs(
                float(div_ylim[1]))) if div_ylim else 0.0
            ax.yaxis.set_major_formatter(FuncFormatter(
                _compact_count_formatter_factory(max_abs)))
        else:
            ax.yaxis.set_major_formatter(FuncFormatter(thousands_formatter))

        # --- Twin axis 1: Length ---
        ax2 = _new_twin(ax)
        lengths = [r.get('avg_token_length', 0) for r in res]
        _shade_band(ax2, steps, lengths, COLOR_LENGTH, scale=1.2, zorder=2)
        ln_len = ax2.plot(
            steps,
            lengths,
            color=COLOR_LENGTH,
            linewidth=3.0,
            alpha=1.0,
            linestyle='-',
            zorder=5,
            label=r'$\mathbf{L}$',
        )[0]
        _finish_twin(ax2, len_ylim, thousands_formatter, show_len_axis)

        # --- Twin axis 2: Accuracy (band clipped to [0, 1]) ---
        ax3 = _new_twin(ax)
        accs = [r.get('accuracy', 0) for r in res]
        _shade_band(ax3, steps, accs, COLOR_ACC, scale=0.5, zorder=3,
                    clip=(0, 1.0))
        ln_acc = ax3.plot(
            steps,
            accs,
            color=COLOR_ACC,
            linewidth=3.0,
            alpha=1.0,
            linestyle='-',
            zorder=6,
            label='Accuracy',
        )[0]
        _finish_twin(ax3, acc_ylim, percent_formatter, show_acc_axis)

        return ln_div, ln_len, ln_acc

    def _plot_simple(ax, steps, data, ylim, ylabel, color, formatter=None):
        """Draw one smoothed single-metric panel of the 2x2 layout."""
        data_smooth = _smooth_data(data)

        ax.set_facecolor('#FFFFFF')
        for side in ['left', 'bottom', 'top', 'right']:
            ax.spines[side].set_linewidth(2.0)
            ax.spines[side].set_color('#000000')
            ax.spines[side].set_visible(True)
        ax.tick_params(axis='both', labelcolor='#000000',
                       labelsize=FONT_TICK, length=8, width=2.0)
        ax.grid(True, axis='both', alpha=0.25, color='#999999',
                linewidth=1.0, linestyle='--', zorder=0)
        ax.set_axisbelow(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))
        ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

        # A slightly wider white line under the curve increases contrast.
        ax.plot(steps, data_smooth, color='white',
                linewidth=5.0, zorder=3, alpha=0.8)
        ax.plot(steps, data_smooth, color=color, linewidth=4.0, zorder=4, alpha=1.0,
                solid_capstyle='round', solid_joinstyle='round')

        ax.set_xlabel('Training Step', fontsize=FONT_XLABEL,
                      fontweight='bold', labelpad=5, color='#000000')
        # The 2x2 layout labels the y-axis instead of using a subplot title.
        ax.set_ylabel(ylabel, fontsize=FONT_METRIC_TITLE,
                      fontweight='bold', labelpad=10, color='#000000')
        ax.set_ylim(ylim)
        if formatter:
            ax.yaxis.set_major_formatter(FuncFormatter(formatter))
        ax.tick_params(axis='y', labelcolor='#000000')

    def _add_experiment_titles(fig, axes_row, exp_keys):
        """Place each experiment's title centred above its two columns."""
        # Axes positions are stable once subplots_adjust has been applied.
        positions = [ax.get_position() for ax in axes_row]
        if not positions:
            return
        y = max(p.y1 for p in positions) + 0.075

        for i, exp_key in enumerate(exp_keys):
            col_start = i * 2
            if col_start >= len(positions):
                continue
            p0 = positions[col_start]
            # Normally a count and a ratio column exist; fall back to the
            # single column if the second one is missing.
            p1 = positions[col_start + 1] if col_start + \
                1 < len(positions) else p0
            x_center = (p0.x0 + p1.x1) / 2.0
            title = MY_TITLES.get(exp_key, exp_key)
            fig.text(x_center, y, title, ha='center', va='bottom',
                     fontsize=FONT_GROUP_TITLE, fontweight='bold')

    def _add_vertical_separators(fig, axes_row, num_experiments):
        """Draw a dashed vertical line between the column pairs of adjacent experiments."""
        positions = [ax.get_position() for ax in axes_row]
        if not positions:
            return
        y0 = min(p.y0 for p in positions) - 0.08
        y1 = max(p.y1 for p in positions) + 0.08

        for i in range(num_experiments - 1):
            col_end = (i + 1) * 2 - 1
            col_start = (i + 1) * 2
            if col_end < len(positions) and col_start < len(positions):
                x = (positions[col_end].x1 + positions[col_start].x0) / 2.0
                line = plt.Line2D(
                    [x, x],
                    [y0, y1],
                    transform=fig.transFigure,
                    linestyle='--',
                    linewidth=5,
                    color='#666666',
                    alpha=1.0,
                    zorder=10,
                )
                fig.add_artist(line)

    def _add_global_legend(fig, handles):
        """Add one shared legend (Diversity / Length / Accuracy) at the bottom."""
        if not handles:
            return
        legend = fig.legend(
            handles,
            [h.get_label() for h in handles],
            loc='lower center',
            bbox_to_anchor=(0.5, 0.05),
            ncol=3,
            frameon=True,
            framealpha=0.95,
            edgecolor='#888888',
            fancybox=True,
            shadow=False,
            prop={'weight': 'bold', 'size': FONT_LEGEND},
        )
        legend.get_frame().set_linewidth(1.5)

    def _save_pdf(fig, filename):
        """Save the current figure as a PDF in ``output_dir`` and close it."""
        out_path = os.path.join(output_dir, filename)
        plt.savefig(
            out_path,
            format='pdf',
            dpi=dpi,
            bbox_inches='tight',
            metadata={'Creator': 'matplotlib',
                      'Producer': 'matplotlib'},
        )
        plt.close(fig)
        print(f"已生成PDF图表: {out_path}")

    for n in ngram_sizes:
        for group_name, exp_keys in groups.items():
            # Keep only experiments that have at least one record; an empty
            # record list would otherwise fail on res[0] further down.
            valid_keys = [k for k in exp_keys if all_results.get(k)]
            if not valid_keys:
                print(f"警告: group={group_name}, n={n} 没有可用的实验数据，跳过")
                continue

            # Ranges shared by all experiments of the group.
            len_ylim, acc_ylim = _compute_len_acc_ylims(valid_keys)
            count_ylim, count_kind = _compute_metric_ylim(
                n, 'count', valid_keys)
            ratio_ylim, ratio_kind = _compute_metric_ylim(
                n, 'ratio', valid_keys)

            if layout_mode == '2x2':
                # One 2x2 figure per experiment.
                base_width = 18.0
                base_height = 14.0
                ratio_formatter = _ratio_percent_formatter_factory(ratio_ylim)

                for exp_key in valid_keys:
                    fig, axes = plt.subplots(2, 2, figsize=(
                        base_width, base_height), squeeze=True)
                    ax_count, ax_ratio = axes[0, 0], axes[0, 1]
                    ax_length, ax_acc = axes[1, 0], axes[1, 1]

                    res = sorted(all_results[exp_key], key=lambda x: x['step'])
                    steps = [r['step'] for r in res]

                    count_key, _, _ = _get_diversity_key_and_label(
                        res[0], n, 'count')
                    ratio_key, _, _ = _get_diversity_key_and_label(
                        res[0], n, 'ratio')
                    counts = [r.get(count_key, 0)
                              if count_key else 0 for r in res]
                    ratios = [r.get(ratio_key, 0)
                              if ratio_key else 0 for r in res]
                    lengths = [r.get('avg_token_length', 0) for r in res]
                    accs = [r.get('accuracy', 0) for r in res]

                    exp_label = MY_TITLES.get(exp_key, exp_key)

                    _plot_simple(ax_count, steps, counts, count_ylim,
                                 r'$C_{\mathrm{context}}(\tau)$', COLOR_NGRAM,
                                 thousands_formatter)
                    _plot_simple(ax_ratio, steps, ratios, ratio_ylim,
                                 r'$R_{\mathrm{context}}(\tau)$', COLOR_NGRAM,
                                 ratio_formatter)
                    _plot_simple(ax_length, steps, lengths, len_ylim,
                                 '$L$', COLOR_LENGTH, thousands_formatter)
                    _plot_simple(ax_acc, steps, accs, acc_ylim, 'Accuracy',
                                 COLOR_ACC, percent_formatter)

                    fig.suptitle(exp_label, fontsize=FONT_GROUP_TITLE,
                                 fontweight='bold', y=0.98)

                    plt.subplots_adjust(
                        wspace=0.30, hspace=0.55, top=0.92, bottom=0.12, left=0.10, right=0.95)

                    _save_pdf(
                        fig,
                        f"local_count_ratio_2x2_{group_name.lower()}_{exp_key}_{n}gram.pdf",
                    )

            else:
                # Single-row layout: two columns (count, ratio) per experiment.
                num_cols = 2 * len(valid_keys)

                # Narrower subplots and tighter spacing as the number of
                # experiments grows.
                if len(valid_keys) <= 2:
                    subplot_width = 9.0
                    wspace = 0.15
                elif len(valid_keys) <= 4:
                    subplot_width = 8.0
                    wspace = 0.12
                else:
                    subplot_width = 7.0
                    wspace = 0.10

                # The figure is one row high; its height is the per-subplot
                # width minus 2.0.
                subplot_height = subplot_width - 2.0
                fig, axes_grid = plt.subplots(1, num_cols, figsize=(
                    subplot_width * num_cols, subplot_height), squeeze=False)
                axes = list(axes_grid[0])

                cols = []
                for exp_key in valid_keys:
                    cols.append((exp_key, 'count', count_ylim, count_kind))
                    cols.append((exp_key, 'ratio', ratio_ylim, ratio_kind))

                legend_handles = None
                for col_idx, (ax, (exp_key, metric, div_ylim, div_kind)) in enumerate(zip(axes, cols)):
                    ln_div, ln_len, ln_acc = _plot_one(
                        ax,
                        n=n,
                        exp_key=exp_key,
                        metric=metric,
                        div_ylim=div_ylim,
                        div_kind=div_kind,
                        len_ylim=len_ylim,
                        acc_ylim=acc_ylim,
                        # Length scale only on the first column,
                        # Accuracy scale only on the last column.
                        show_len_axis=(col_idx == 0),
                        show_acc_axis=(col_idx == num_cols - 1),
                    )
                    if legend_handles is None:
                        legend_handles = [ln_div, ln_len, ln_acc]

                # Leave room at the top for experiment titles and at the
                # bottom for the legend.
                plt.subplots_adjust(wspace=wspace, top=0.72,
                                    bottom=0.22, left=0.05, right=0.99)
                _add_experiment_titles(fig, axes, valid_keys)
                _add_vertical_separators(fig, axes, len(valid_keys))
                _add_global_legend(fig, legend_handles)
                _save_pdf(
                    fig,
                    f"local_count_ratio_row_{group_name.lower()}_{n}gram.pdf",
                )


def plot_two_runs_comparison(
    data_path,
    output_dir='plots',
    ngram_sizes=[10],
    dpi=600,
    run_pairs=None,
    max_step=None,
    layout_mode='1x4',
):
    """
    Plot multi-run comparison figures, one PDF per (n-gram size, run group).

    Every figure has four panels, each overlaying the smoothed curve of every
    run in the group:

    - layout_mode='1x4': a single row with C_context (count), R_context
      (ratio), Length and Accuracy; the metric name is the panel title.
    - layout_mode='2x2': top row C_context / R_context, bottom row
      Length / Accuracy; the metric name is the y-axis label.

    Any other ``layout_mode`` value falls through to the 1x4 layout.

    Args:
        data_path: A single JSON path, or a list/tuple of JSON paths, passed
            on to ``_load_and_merge_results``.
        output_dir: Directory the PDFs are written to (created if missing).
        ngram_sizes: n-gram sizes to plot. The default list is only iterated,
            never mutated.
        dpi: DPI forwarded to ``savefig``.
        run_pairs: Optional list of run groups. Each item is a list or tuple
            in one of these shapes:
              * ``([run1, run2, ...], group_name)``
              * ``(run1, run2, ..., group_name)`` with three or more
                elements; the last element is the group name
              * ``(run1, run2)`` legacy two-run form; the second run key is
                also used as the group name
            Items matching none of these are skipped with a warning. When
            ``None``, every entry of the global ``GROUPS`` mapping is used.
        max_step: When given, only records with ``step <= max_step`` are kept.
        layout_mode: ``'1x4'`` or ``'2x2'``.
    """
    # Accept either one JSON path or several.
    if isinstance(data_path, (list, tuple)):
        data_paths = list(data_path)
    else:
        data_paths = [data_path]

    all_results = _load_and_merge_results(data_paths)
    if not all_results:
        print("错误: 没有成功加载任何数据")
        return

    # Optional truncation so that all runs share the same step horizon.
    if max_step is not None:
        print(f"截取数据到 step <= {max_step}")
        filtered_results = {}
        for exp_key, records in all_results.items():
            filtered_records = [
                r for r in records if r.get('step', 0) <= max_step]
            if filtered_records:
                filtered_results[exp_key] = filtered_records
        all_results = filtered_results
        if not all_results:
            print("错误: 过滤后没有剩余数据")
            return

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # Palette cycled over the runs of a group (index = position in the group).
    RUN_COLORS = [
        '#00468B',  # dark blue
        '#9B59B6',  # purple
        '#AE1029',  # crimson
        '#8c564b',  # brown
    ]

    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['axes.linewidth'] = 1.8

    # Font sizes (deliberately large for paper figures).
    FONT_TICK = 26
    FONT_XLABEL = 30
    FONT_TITLE = 36
    FONT_LEGEND = 32

    def _padded_ylim(values, upper_cap=None):
        """Return (low, high) around ``values`` with a 10% margin.

        The lower bound is clamped at 0; the upper bound is clamped at
        ``upper_cap`` when one is given. A flat series gets a margin derived
        from its magnitude (at least 0.1) so the axis never collapses.
        """
        lo, hi = min(values), max(values)
        span = float(hi - lo)
        if abs(span) < 1e-9:
            margin = max(abs(float(hi)), 1.0) * 0.1
        else:
            margin = span * 0.1
        top = hi + margin
        if upper_cap is not None:
            top = min(upper_cap, top)
        return (max(0, lo - margin), top)

    def _parse_run_group(item):
        """Normalise one ``run_pairs`` item to (run_keys, group_name), or None."""
        if not isinstance(item, (list, tuple)) or len(item) < 2:
            return None
        if len(item) == 2:
            run_keys, group_name = item
            if isinstance(run_keys, (list, tuple)):
                return list(run_keys), group_name
            # Legacy form: both elements are run keys and the second one is
            # reused as the group name.
            return [item[0], item[1]], group_name
        # (run1, run2, ..., group_name): the last element names the group.
        return list(item[:-1]), item[-1]

    def _style_ax(ax):
        """Apply the shared look: white background, black frame, dashed grid."""
        ax.set_facecolor('#FFFFFF')
        for side in ['left', 'bottom', 'top', 'right']:
            ax.spines[side].set_linewidth(2.0)
            ax.spines[side].set_color('#000000')
            ax.spines[side].set_visible(True)
        ax.tick_params(axis='both', labelcolor='#000000',
                       labelsize=FONT_TICK, length=8, width=2.0)
        ax.grid(True, axis='both', alpha=0.25, color='#999999',
                linewidth=1.0, linestyle='--', zorder=0)
        ax.set_axisbelow(True)
        ax.xaxis.set_major_locator(
            plt.MaxNLocator(nbins=6, integer=True))
        ax.yaxis.set_major_locator(
            plt.MaxNLocator(nbins=5, integer=False))

    def _plot_data_to_ax(ax, runs, data_key, ylim, ylabel, formatter,
                         show_legend=False):
        """Draw ``data_key`` of every run onto ``ax``.

        Returns the main line handles when ``show_legend`` is true, otherwise
        an empty list.
        """
        _style_ax(ax)
        handles = []
        for idx, run_data in enumerate(runs):
            color = RUN_COLORS[idx % len(RUN_COLORS)]
            data_smooth = _smooth_data(run_data[data_key])
            steps = run_data['steps']

            # Shaded band: a moving average of the absolute step-to-step
            # change of the smoothed curve, used as a rough spread estimate.
            # Only drawn when the series is long enough for a window > 1.
            data_smooth_array = np.array(data_smooth)
            window = min(5, len(data_smooth) //
                         10) if len(data_smooth) > 10 else 1
            if window > 1:
                std_approx = np.convolve(
                    np.abs(np.diff(np.concatenate(
                        [[data_smooth[0]], data_smooth]))),
                    np.ones(window) / window,
                    mode='same'
                ) * 1.2
                ax.fill_between(
                    steps,
                    data_smooth_array - std_approx,
                    data_smooth_array + std_approx,
                    color=color,
                    alpha=0.2,
                    zorder=1,
                )

            # A slightly wider white line underneath keeps overlapping
            # curves visually separable.
            ax.plot(steps, data_smooth, color='white',
                    linewidth=5.0, zorder=3, alpha=0.8)
            line = ax.plot(
                steps,
                data_smooth,
                color=color,
                linewidth=4.0,
                label=run_data['label'] if show_legend else '',
                zorder=4,
                alpha=1.0,
                solid_capstyle='round',
                solid_joinstyle='round'
            )[0]
            if show_legend:
                handles.append(line)

        ax.set_xlabel('Training Step', fontsize=FONT_XLABEL,
                      fontweight='bold', labelpad=5, color='#000000')
        # 2x2 puts the metric name on the y axis; 1x4 uses it as the title.
        if layout_mode == '2x2':
            ax.set_ylabel(ylabel, fontsize=FONT_TITLE,
                          fontweight='bold', labelpad=10, color='#000000')
        else:
            ax.set_title(ylabel, fontsize=FONT_TITLE,
                         fontweight='bold', pad=10, color='#000000')
        ax.set_ylim(ylim)
        if formatter:
            ax.yaxis.set_major_formatter(FuncFormatter(formatter))
        ax.tick_params(axis='y', labelcolor='#000000')
        return handles

    # Resolve which groups of runs to plot.
    run_groups = []
    if run_pairs is None:
        for group_name, exp_keys in GROUPS.items():
            valid_keys = [k for k in exp_keys if k in all_results]
            if valid_keys:
                run_groups.append((valid_keys, group_name))
            else:
                print(f"警告: group={group_name} 没有有效 run，跳过")
    else:
        for item in run_pairs:
            parsed = _parse_run_group(item)
            if parsed is None:
                print(f"警告: 无法解析 run_pairs 项: {item}，跳过")
            else:
                run_groups.append(parsed)

    for n in ngram_sizes:
        for run_keys, group_name in run_groups:
            valid_run_keys = [k for k in run_keys if k in all_results]
            if not valid_run_keys:
                print(f"警告: group={group_name} 没有有效的 run，跳过")
                continue

            # Collect the per-run series.
            all_runs_data = []
            for run_key in valid_run_keys:
                res = sorted(all_results[run_key], key=lambda x: x['step'])
                if not res:
                    # Nothing to plot; res[0] below would raise IndexError.
                    continue

                count_key, _, _ = _get_diversity_key_and_label(
                    res[0], n, 'count')
                ratio_key, _, _ = _get_diversity_key_and_label(
                    res[0], n, 'ratio')

                all_runs_data.append({
                    'key': run_key,
                    'label': MY_TITLES.get(run_key, run_key),
                    'steps': [r['step'] for r in res],
                    # A missing metric key degrades to an all-zero series.
                    'counts': [r.get(count_key, 0) if count_key else 0
                               for r in res],
                    'ratios': [r.get(ratio_key, 0) if ratio_key else 0
                               for r in res],
                    'lengths': [r.get('avg_token_length', 0) for r in res],
                    'accs': [r.get('accuracy', 0) for r in res],
                })

            if not all_runs_data:
                print(f"警告: group={group_name} 没有有效的 run，跳过")
                continue

            # Shared y ranges across all runs of the group; ratio and
            # accuracy are additionally capped at 1.0.
            count_ylim = _padded_ylim(
                [v for run in all_runs_data for v in run['counts']])
            ratio_ylim = _padded_ylim(
                [v for run in all_runs_data for v in run['ratios']],
                upper_cap=1.0)
            len_ylim = _padded_ylim(
                [v for run in all_runs_data for v in run['lengths']])
            acc_ylim = _padded_ylim(
                [v for run in all_runs_data for v in run['accs']],
                upper_cap=1.0)

            if layout_mode == '2x2':
                fig, axes = plt.subplots(
                    2, 2, figsize=(18.0, 14.0), squeeze=True)
                ax_count, ax_ratio = axes[0, 0], axes[0, 1]
                ax_length, ax_acc = axes[1, 0], axes[1, 1]
            else:
                # Each panel is 9 wide by 7 tall (figure inches).
                subplot_width = 9.0
                subplot_height = subplot_width - 2.0
                fig, axes = plt.subplots(1, 4, figsize=(
                    subplot_width * 4, subplot_height), squeeze=True)
                ax_count, ax_ratio, ax_length, ax_acc = axes

            # Only the first panel registers labels; its handles feed the
            # figure-level legend.
            handles = _plot_data_to_ax(
                ax_count, all_runs_data, 'counts', count_ylim,
                r'$C_{\mathrm{context}}(\tau)$', thousands_formatter,
                show_legend=True
            )
            _plot_data_to_ax(
                ax_ratio, all_runs_data, 'ratios', ratio_ylim,
                r'$R_{\mathrm{context}}(\tau)$',
                lambda x, pos: _ratio_percent_formatter_factory(
                    ratio_ylim)(x, pos),
                show_legend=False
            )
            _plot_data_to_ax(
                ax_length, all_runs_data, 'lengths', len_ylim,
                '$L$', thousands_formatter, show_legend=False
            )
            _plot_data_to_ax(
                ax_acc, all_runs_data, 'accs', acc_ylim,
                'Accuracy', percent_formatter, show_legend=False
            )

            num_runs = len(all_runs_data)
            if layout_mode == '2x2':
                plt.subplots_adjust(wspace=0.25, hspace=0.35,
                                    top=0.92, bottom=0.20,
                                    left=0.10, right=0.95)
                legend_anchor = (0.5, 0.04)
            else:
                # More runs -> more bottom margin reserved for the legend.
                bottom_margin = 0.16 + max(0, (num_runs - 2) * 0.02)
                plt.subplots_adjust(wspace=0.20, top=0.92,
                                    bottom=bottom_margin,
                                    left=0.07, right=0.97)
                legend_anchor = (0.5, -0.15)

            fig.legend(
                handles, [h.get_label() for h in handles],
                loc='lower center',
                bbox_to_anchor=legend_anchor,
                ncol=min(num_runs, 6),  # at most 6 columns
                frameon=True,
                framealpha=0.95,
                edgecolor='#888888',
                fancybox=True,
                shadow=False,
                prop={'weight': 'bold', 'size': FONT_LEGEND},
            )

            save = f"two_runs_comparison_{group_name}_{n}gram_{layout_mode}.pdf"
            plt.savefig(
                os.path.join(output_dir, save),
                format='pdf',
                dpi=dpi,
                bbox_inches='tight',
                metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
            )
            plt.close(fig)
            print(
                f"已生成PDF图表: {os.path.join(output_dir, save)} (包含 {num_runs} 个 run)")


def plot_ngram_comparison_by_group(
    data_path,
    output_dir='plots',
    ngram_sizes=[6, 10, 15],
    dpi=600,
    max_step=None,
):
    """
    Compare local count and ratio curves across n-gram sizes, one PDF per group.

    For every entry of the global ``GROUPS`` mapping a 2 x N figure is
    written (N = number of n-gram sizes):

    - top row: count per n-gram size, all sharing one y range
    - bottom row: ratio per n-gram size, all sharing one y range
    - every panel overlays all experiments of the group

    Args:
        data_path: A single JSON path, or a list/tuple of JSON paths, passed
            on to ``_load_and_merge_results``.
        output_dir: Directory the PDFs are written to (created if missing).
        ngram_sizes: n-gram sizes, e.g. ``[6, 10, 15]``; one column each.
            The default list is only iterated, never mutated.
        dpi: DPI forwarded to ``savefig``.
        max_step: When given, only records with ``step <= max_step`` are kept.
    """
    num_ngrams = len(ngram_sizes)
    if num_ngrams == 0:
        # plt.subplots(2, 0) would raise; there is nothing to draw anyway.
        print("错误: ngram_sizes 为空，无法生成对比图")
        return

    # Accept either one JSON path or several.
    if isinstance(data_path, (list, tuple)):
        data_paths = list(data_path)
    else:
        data_paths = [data_path]

    all_results = _load_and_merge_results(data_paths)
    if not all_results:
        print("错误: 没有成功加载任何数据")
        return

    # Optional truncation so that all runs share the same step horizon.
    if max_step is not None:
        print(f"截取数据到 step <= {max_step}")
        filtered_results = {}
        for exp_key, records in all_results.items():
            filtered_records = [
                r for r in records if r.get('step', 0) <= max_step]
            if filtered_records:
                filtered_results[exp_key] = filtered_records
        all_results = filtered_results
        if not all_results:
            print("错误: 过滤后没有剩余数据")
            return

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # Palette cycled over the experiments of a group.
    RUN_COLORS = [
        '#00468B',  # dark blue
        '#9B59B6',  # purple
        '#AE1029',  # crimson
        '#2ca02c',  # green
        '#ff7f0e',  # orange
        '#9467bd',  # light purple
        '#8c564b',  # brown
        '#e377c2',  # pink
        '#7f7f7f',  # grey
        '#bcbd22',  # olive
    ]

    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['axes.linewidth'] = 1.8

    # Font sizes.
    FONT_TICK = 26
    FONT_XLABEL = 30
    FONT_YLABEL = 30
    FONT_TITLE = 36
    FONT_LEGEND = 32

    def _padded_ylim(values, upper_cap=None):
        """Return (low, high) around ``values`` with a 10% margin.

        The lower bound is clamped at 0; the upper bound is clamped at
        ``upper_cap`` when one is given. A flat series gets a margin derived
        from its magnitude (at least 0.1) so the axis never collapses.
        """
        lo, hi = min(values), max(values)
        span = float(hi - lo)
        if abs(span) < 1e-9:
            margin = max(abs(float(hi)), 1.0) * 0.1
        else:
            margin = span * 0.1
        top = hi + margin
        if upper_cap is not None:
            top = min(upper_cap, top)
        return (max(0, lo - margin), top)

    def _style_ax(ax):
        """Apply the shared look: white background, black frame, dashed grid."""
        ax.set_facecolor('#FFFFFF')
        for side in ['left', 'bottom', 'top', 'right']:
            ax.spines[side].set_linewidth(2.0)
            ax.spines[side].set_color('#000000')
            ax.spines[side].set_visible(True)
        ax.tick_params(axis='both', labelcolor='#000000',
                       labelsize=FONT_TICK, length=8, width=2.0)
        ax.grid(True, axis='both', alpha=0.25, color='#999999',
                linewidth=1.0, linestyle='--', zorder=0)
        ax.set_axisbelow(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))
        ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

    def _plot_data_to_ax(ax, all_runs_data, n, data_key, ylim, ylabel,
                         formatter, show_title=True, show_xlabel=True):
        """Draw ``data_key`` for n-gram size ``n`` of every experiment.

        Experiments without data for ``n`` are skipped. Returns a dict
        mapping experiment key -> main line handle, in plotting order.
        """
        _style_ax(ax)
        lines = {}

        # Colour index follows the experiment's position in the group, so an
        # experiment keeps its colour in every panel.
        for exp_idx, (exp_key, run_data) in enumerate(all_runs_data.items()):
            if n not in run_data['ngram_data']:
                continue

            color = RUN_COLORS[exp_idx % len(RUN_COLORS)]
            steps = run_data['steps']
            data_smooth = _smooth_data(run_data['ngram_data'][n][data_key])

            # Shaded band: a moving average of the absolute step-to-step
            # change of the smoothed curve, used as a rough spread estimate.
            data_smooth_array = np.array(data_smooth)
            window = min(5, len(data_smooth) //
                         10) if len(data_smooth) > 10 else 1
            if window > 1:
                std_approx = np.convolve(
                    np.abs(np.diff(np.concatenate(
                        [[data_smooth[0]], data_smooth]))),
                    np.ones(window) / window,
                    mode='same'
                ) * 1.2
                band_low = data_smooth_array - std_approx
                band_high = data_smooth_array + std_approx
                if data_key == 'ratios':
                    # Ratios get a band half as wide, clipped to [0, 1].
                    std_approx = std_approx * 0.5
                    band_low = np.maximum(0, data_smooth_array - std_approx)
                    band_high = np.minimum(
                        1.0, data_smooth_array + std_approx)
                ax.fill_between(
                    steps,
                    band_low,
                    band_high,
                    color=color,
                    alpha=0.2,
                    zorder=1,
                )

            # A slightly wider white line underneath keeps overlapping
            # curves visually separable.
            ax.plot(steps, data_smooth, color='white',
                    linewidth=5.0, zorder=3, alpha=0.8)
            lines[exp_key] = ax.plot(
                steps,
                data_smooth,
                color=color,
                linewidth=4.0,
                label=run_data['label'],
                zorder=4,
                alpha=1.0,
                solid_capstyle='round',
                solid_joinstyle='round'
            )[0]

        if show_title:
            ax.set_title(f'{n}-gram', fontsize=FONT_TITLE,
                         fontweight='bold', pad=10, color='#000000')
        if show_xlabel:
            ax.set_xlabel('Training Step', fontsize=FONT_XLABEL,
                          fontweight='bold', labelpad=5, color='#000000')
        ax.set_ylabel(ylabel, fontsize=FONT_YLABEL,
                      fontweight='bold', labelpad=10, color='#000000')
        ax.set_ylim(ylim)
        if formatter:
            ax.yaxis.set_major_formatter(FuncFormatter(formatter))
        ax.tick_params(axis='y', labelcolor='#000000')

        return lines

    # One figure per group.
    for group_name, exp_keys in GROUPS.items():
        # Collect count/ratio series per experiment and n-gram size.
        all_runs_data = {}
        for exp_key in exp_keys:
            if exp_key not in all_results:
                continue
            res = sorted(all_results[exp_key], key=lambda x: x['step'])
            if not res:
                # Nothing to plot; res[0] below would raise IndexError.
                continue

            ngram_data = {}
            for n in ngram_sizes:
                count_key, _, _ = _get_diversity_key_and_label(
                    res[0], n, 'count')
                ratio_key, _, _ = _get_diversity_key_and_label(
                    res[0], n, 'ratio')
                if count_key and ratio_key:
                    ngram_data[n] = {
                        'counts': [r.get(count_key, 0) for r in res],
                        'ratios': [r.get(ratio_key, 0) for r in res],
                    }
                else:
                    print(
                        f"警告: {exp_key} 缺少 {n}gram count 或 ratio 数据，跳过该 ngram")

            all_runs_data[exp_key] = {
                'label': MY_TITLES.get(exp_key, exp_key),
                'steps': [r['step'] for r in res],
                'ngram_data': ngram_data,
            }

        valid_keys = list(all_runs_data)
        if not valid_keys:
            print(f"警告: group={group_name} 没有可用的实验数据，跳过")
            continue

        # One y range for all count panels and one for all ratio panels,
        # computed over every experiment and every n-gram size.
        all_counts = []
        all_ratios = []
        for run_data in all_runs_data.values():
            for series in run_data['ngram_data'].values():
                all_counts.extend(series['counts'])
                all_ratios.extend(series['ratios'])

        if all_counts:
            unified_count_ylim = _add_top_headroom(
                _padded_ylim(all_counts), frac=0.08, cap_upper=None)
        else:
            unified_count_ylim = (0, 1000)

        if all_ratios:
            unified_ratio_ylim = _add_top_headroom(
                _padded_ylim(all_ratios, upper_cap=1.0),
                frac=0.08, cap_upper=1.0)
        else:
            unified_ratio_ylim = (0, 1.0)

        # squeeze=False always yields a 2-D array, so a single n-gram column
        # needs no special casing.
        subplot_width = 9.0
        subplot_height = subplot_width - 2.0
        fig, axes = plt.subplots(2, num_ngrams, figsize=(
            subplot_width * num_ngrams, subplot_height * 2), squeeze=False)
        axes_count = list(axes[0])  # top row: count
        axes_ratio = list(axes[1])  # bottom row: ratio

        # Top row: titles shown, x labels hidden. Legend handles are merged
        # over all count panels so that an experiment missing one n-gram size
        # still gets a legend entry.
        legend_lines = {}
        for ax, n in zip(axes_count, ngram_sizes):
            panel_lines = _plot_data_to_ax(
                ax, all_runs_data, n, 'counts', unified_count_ylim,
                r'$C_{\mathrm{context}}(\tau)$', thousands_formatter,
                show_title=True, show_xlabel=False
            )
            for exp_key, line in panel_lines.items():
                legend_lines.setdefault(exp_key, line)
        legend_handles = [legend_lines[k]
                          for k in valid_keys if k in legend_lines]

        # Bottom row: titles hidden, x labels shown.
        for ax, n in zip(axes_ratio, ngram_sizes):
            _plot_data_to_ax(
                ax, all_runs_data, n, 'ratios', unified_ratio_ylim,
                r'$R_{\mathrm{context}}(\tau)$',
                lambda x, pos: _ratio_percent_formatter_factory(
                    unified_ratio_ylim)(x, pos),
                show_title=False, show_xlabel=True
            )

        # Figure-level legend below the panels.
        if legend_handles:
            fig.legend(
                legend_handles, [h.get_label() for h in legend_handles],
                loc='lower center',
                bbox_to_anchor=(0.5, 0.01),
                ncol=min(len(valid_keys), 6),
                frameon=True,
                framealpha=0.95,
                edgecolor='#888888',
                fancybox=True,
                shadow=False,
                prop={'weight': 'bold', 'size': FONT_LEGEND},
            )

        # More experiments -> more bottom margin reserved for the legend.
        bottom_margin = 0.18 + max(0, (len(valid_keys) - 2) * 0.02)
        plt.subplots_adjust(
            wspace=0.25,
            hspace=0.35,  # extra row spacing so the rows do not overlap
            top=0.92,
            bottom=bottom_margin,
            left=0.08,
            right=0.97
        )

        ngram_str = '_'.join(map(str, ngram_sizes))
        save = f"ngram_comparison_{group_name.lower()}_{ngram_str}gram.pdf"
        plt.savefig(
            os.path.join(output_dir, save),
            format='pdf',
            dpi=dpi,
            bbox_inches='tight',
            metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
        )
        plt.close(fig)
        print(
            f"已生成PDF图表: {os.path.join(output_dir, save)} (包含 {len(valid_keys)} 个实验，{num_ngrams} 个 ngram)")


if __name__ == "__main__":
    # Command-line entry point: parse arguments, resolve the input JSON
    # file(s), then run the plotting functions in sequence.
    parser = argparse.ArgumentParser(
        description='按 groups 划分生成 count 和 ratio 图表')
    parser.add_argument('--input', '-i', type=str,
                        default='/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/analysis_results.json')
    parser.add_argument('--inputs', type=str, nargs='+',
                        help='多个分批收集的 JSON 文件路径（不同 experiment 可分散在不同文件）')
    parser.add_argument('--input-glob', type=str,
                        help='用通配符指定多个 JSON（例如: /path/to/batches/*.json）')
    parser.add_argument('--output-dir', '-o', type=str,
                        default='/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/plots', help='输出目录')
    parser.add_argument('--ngrams', type=int, nargs='+',
                        default=[10], help='要绘制的 n-gram 大小')
    parser.add_argument('--dpi', type=int, default=600, help='PDF清晰度（DPI）')
    parser.add_argument('--layout-mode', type=str, default='1x4', choices=['1x4', '2x2'],
                        help='布局模式：1x4（一行4列）或 2x2（2行2列）')
    # NOTE: the default is 620, so step truncation is active unless the
    # caller passes a different --max-step.
    parser.add_argument('--max-step', type=int, default=620,
                        help='如果指定，只显示 step <= max_step 的数据点（用于统一截取数据）')
    parser.add_argument('--plot-ngram-comparison', action='store_true',
                        help='是否生成不同 ngram 的对比图（按 group）')

    args = parser.parse_args()

    # Input precedence: --inputs, then --input-glob, then the single --input.
    if args.inputs:
        data_paths = args.inputs
    elif args.input_glob:
        data_paths = sorted(glob.glob(args.input_glob))
        if not data_paths:
            # SystemExit with a string prints it to stderr and exits with
            # status 1; unlike the site-provided exit(), it is always
            # available (e.g. under `python -S`).
            raise SystemExit(
                f"错误: input-glob 没有匹配到任何文件: {args.input_glob}")
    else:
        data_paths = args.input

    # Figure set 1: count / ratio rows per group.
    print("=" * 60)
    print("生成按 groups 划分的 count 和 ratio 图表...")
    print("=" * 60)
    plot_two_figs_count_ratio_row(
        data_path=data_paths,
        output_dir=args.output_dir,
        ngram_sizes=args.ngrams,
        dpi=args.dpi,
        layout_mode=args.layout_mode,
        max_step=args.max_step,
    )

    # Figure set 2: multi-run comparison per group.
    print("\n" + "=" * 60)
    print("生成两个 run 的对比图...")
    print("=" * 60)
    plot_two_runs_comparison(
        data_path=data_paths,
        output_dir=args.output_dir,
        ngram_sizes=args.ngrams,
        dpi=args.dpi,
        max_step=args.max_step,
        layout_mode=args.layout_mode,
    )

    # Figure set 3 (opt-in): comparison across n-gram sizes per group.
    if args.plot_ngram_comparison:
        print("\n" + "=" * 60)
        print("生成不同 ngram 的对比图（按 group）...")
        print("=" * 60)
        plot_ngram_comparison_by_group(
            data_path=data_paths,
            output_dir=args.output_dir,
            ngram_sizes=args.ngrams,
            dpi=args.dpi,
            max_step=args.max_step,
        )
