import json
import os
import argparse
import glob
import matplotlib
matplotlib.use('Agg')  # 使用Agg后端以支持高质量PDF输出
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from matplotlib.ticker import FuncFormatter

# SciPy is optional: it provides Savitzky-Golay smoothing and Pearson correlation.
# When it is missing, both names are still defined (as None) so that any code
# that forgets to check HAS_SCIPY fails with a clear "NoneType is not callable"
# instead of a NameError.
try:
    from scipy.signal import savgol_filter
    from scipy.stats import pearsonr
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False
    savgol_filter = None
    pearsonr = None

# PDF / PostScript output settings applied globally at import time.
plt.rcParams['pdf.fonttype'] = 42  # TrueType fonts, so text stays editable in the PDF
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['pdf.use14corefonts'] = False

# Global configuration: maps an experiment key to its display title.
# Titles containing LaTeX are raw strings so that backslashes such as "\Delta"
# are not parsed as (invalid) Python escape sequences.
MY_TITLES = {
    "baseline-dapo-math-redo": "GRPO High clip ratio",
    "baseline-grpo-dapo-math-minibsz32": "GRPO",
    "baseline-gspo-dapo-math-minibsz32": "GSPO",
    "skip-right-skip-limits10-gspo-dapo-math": "GSPO (Ours)",
    "skip-right-skip-limits10-dapo-math": "GRPO High clip ratio + ours",
    "skip-right-skip-limits10-grpo-dapo-math": "GRPO + ours",
    "skip-right-skip-limits10-gspo-dapo-math-add2k": "GSPO + ours (add2k)",
    "skip-right-skip-limits10-gspo-dapo-math-wo-repetition": "GSPO + Length",
    "skip-right-distinct-hard-gspo-dapo-math": "GSPO + ours (distinct ratio hard)",
    "skip-right-skip-limits10-gspo-dapo-math-wo-repetition-redo": "GSPO + $R_{\\text{len}}$",
    "llama-baseline-gspo-deepmath": "GSPO",
    "llama-add1k-gspo-deepmath": "GSPO (Ours)",
    "qwen3-4b-polaris-add1k-gspo": "GSPO (Ours)",
    "qwen3-4b-polaris-baseline-gspo": "GSPO",
    "skip-right-skip-limits10-gspo-dapo-math-add1k5": r"$\Delta L = 1k$",
    "ours-gspo-dapo-math-add8k5": r"$\Delta L = 8k$",
    "ours-gspo-dapo-math-add1k-fixed": r"$\Delta L = 500 + $ avg (fixed)",
    "ours-gspo-dapo-math-add600": r"$\Delta L = 100$",
}

# Tick formatter that renders raw values in units of one thousand.
def thousands_formatter(x, pos):
    """Return ``x`` divided by 1000 with one decimal and a ``k`` suffix.

    Every value gets the suffix, including small ones (0 becomes ``'0.0k'``).
    ``pos`` is the tick position required by the ``FuncFormatter`` callback
    signature; it is not used.
    """
    in_thousands = x / 1000
    return '{:.1f}k'.format(in_thousands)

def percent_formatter(x, pos):
    """Return the fraction ``x`` as a whole-number percentage without a ``%`` sign.

    ``pos`` is the tick position required by the ``FuncFormatter`` callback
    signature; it is not used.
    """
    as_percent = x * 100
    return '{:.0f}'.format(as_percent)

def _ratio_percent_formatter_factory(ylim):
    """Build a tick formatter that shows a ratio as a percentage number.

    With an integer percentage, a narrow axis (for example 0.01 to 0.02) would
    label many ticks with the same number. The number of decimals is therefore
    chosen from the y-axis range passed in ``ylim``:

    - largest bound below 1% or span below 0.5%  -> two decimals
    - largest bound below 10% or span below 3%   -> one decimal
    - otherwise                                  -> no decimals

    If ``ylim`` cannot be read as two numbers, both the span and the largest
    bound are treated as 0, which selects two decimals.

    Returns:
        A ``(x, pos)`` callable suitable for ``FuncFormatter``. It returns
        ``str(x)`` when ``x`` cannot be converted to a float.
    """
    try:
        lower = float(ylim[0])
        upper = float(ylim[1])
    except Exception:
        span_pct = 0.0
        max_pct = 0.0
    else:
        span_pct = abs((upper - lower) * 100.0)
        max_pct = max(abs(lower), abs(upper)) * 100.0

    # Most specific rule first: the two-decimal case is a subset of the
    # one-decimal case.
    if max_pct < 1.0 or span_pct < 0.5:
        template = '{:.2f}'
    elif max_pct < 10.0 or span_pct < 3.0:
        template = '{:.1f}'
    else:
        template = '{:.0f}'

    def _fmt(x, pos):
        try:
            percent = float(x) * 100.0
            return template.format(percent)
        except Exception:
            return str(x)

    return _fmt

def _smooth_data(data, window_size=None, method='savgol'):
    """Smooth a 1-D series and return an array of the same length.

    Args:
        data: One-dimensional sequence of numbers.
        window_size: Width of the smoothing window in samples. ``None`` uses
            the default of 5. The value is clamped to the series length and
            reduced to the nearest odd number, so it is always valid for both
            smoothing methods.
        method: ``'savgol'`` (Savitzky-Golay filter, needs SciPy) or
            ``'moving_avg'`` (numpy moving average). ``'savgol'`` falls back to
            the moving average when SciPy is unavailable or the filter fails.

    Returns:
        A numpy array with the smoothed values. Series shorter than 3 samples,
        effective windows smaller than 3, and unknown methods return the data
        unsmoothed.
    """
    data = np.array(data)
    n_points = len(data)
    if n_points < 3:
        return data

    window_size = 5 if window_size is None else int(window_size)
    # A window wider than the series is invalid for savgol_filter and makes
    # np.convolve(mode='same') return max(len(data), len(kernel)) samples,
    # i.e. more points than went in.
    window_size = min(window_size, n_points)
    # Keep the window odd so it is centred on each sample.
    if window_size % 2 == 0:
        window_size -= 1
    if window_size < 3:
        return data

    if method == 'savgol':
        if HAS_SCIPY:
            try:
                # Savitzky-Golay smooths while preserving peak shape.
                poly_order = min(3, window_size - 1)
                return savgol_filter(data, window_size, poly_order)
            except Exception:
                # Best effort: fall through to the moving average below.
                pass
        method = 'moving_avg'

    if method == 'moving_avg':
        kernel = np.ones(window_size) / window_size
        smoothed = np.convolve(data, kernel, mode='same')
        # The convolution zero-pads at the edges, which drags the ends towards
        # zero; keep the original first and last samples instead.
        half = window_size // 2
        smoothed[:half] = data[:half]
        smoothed[-half:] = data[-half:]
        return smoothed

    return data

def _add_top_headroom(ylim, frac: float = 0.08, cap_upper=None):
    """Raise only the upper y-limit so the top-left annotation box has room.

    Args:
        ylim: ``(lower, upper)`` pair. If it cannot be read as two numbers it
            is returned unchanged.
        frac: Extra space to add, as a fraction of the current span (a span
            below 1e-9 is treated as 1e-9).
        cap_upper: Optional ceiling for the new upper limit (for example 1.0).
            A value that cannot be converted to float is ignored.

    Returns:
        ``(lower, new_upper)`` as floats.
    """
    try:
        lower = float(ylim[0])
        upper = float(ylim[1])
    except Exception:
        return ylim

    extent = upper - lower
    if extent < 1e-9:
        extent = 1e-9
    raised = upper + extent * float(frac)

    if cap_upper is None:
        return (lower, raised)

    try:
        raised = min(float(cap_upper), raised)
    except Exception:
        pass
    return (lower, raised)

def _pick_first_existing_key(record: dict, candidates):
    """Return the first entry of ``candidates`` that is a key of ``record``, else None."""
    return next((name for name in candidates if name in record), None)

def _get_diversity_key_and_label(record0: dict, n: int, diversity_metric: str):
    """Resolve the record key, axis label and kind for a diversity metric.

    Supported values of ``diversity_metric``:

    - ``'count'`` / ``'ratio'``: per-step average distinct n-gram metric. The
      key ``distinct_{n}gram_{metric}`` is used when present in ``record0``,
      otherwise ``{n}gram_{metric}`` is returned without checking it exists.
    - ``'global_count'`` / ``'global_ratio'``: global distinct metric, read
      from ``step_global_distinct_{n}gram_{count|ratio}``; the key is ``None``
      when ``record0`` does not contain it.

    Returns:
        ``(key, y_label, kind)`` where ``kind`` is ``'count'`` or ``'ratio'``
        and drives the y-axis range and tick formatting.

    Raises:
        ValueError: If ``diversity_metric`` is not one of the four values above.
    """
    if diversity_metric == 'count' or diversity_metric == 'ratio':
        prefixed = f'distinct_{n}gram_{diversity_metric}'
        if prefixed in record0:
            key = prefixed
        else:
            key = f'{n}gram_{diversity_metric}'
        # Trajectory-level notation.
        if diversity_metric == 'count':
            y_label = r'$C_{\mathrm{context}}(\tau)$'
        else:
            y_label = r'$R_{\mathrm{context}}(\tau)$'
        return key, y_label, diversity_metric

    # Global-level notation; count and ratio share the same symbols.
    if diversity_metric == 'global_count':
        candidates = [f'step_global_distinct_{n}gram_count']
        y_label = r'$C_{\mathrm{global}}(\mathcal{T})$'
        kind = 'count'
    elif diversity_metric == 'global_ratio':
        candidates = [f'step_global_distinct_{n}gram_ratio']
        y_label = r'$R_{\mathrm{global}}(\mathcal{T})$'
        kind = 'ratio'
    else:
        raise ValueError(f"Unknown diversity_metric: {diversity_metric}")

    return _pick_first_existing_key(record0, candidates), y_label, kind

def _load_and_merge_results(data_paths):
    """Load several JSON result files and merge them into one dict.

    Each file is expected to hold ``{exp_name: [records...]}``. Files that are
    missing, fail to load, or whose top level is not a dict are reported on
    stdout and skipped. When two files share an experiment key, the file
    loaded later wins and a warning lists the overwritten keys.

    Returns:
        The merged dict; empty when nothing could be loaded.
    """
    combined: dict = {}
    for path in data_paths:
        if not os.path.exists(path):
            print(f"警告: 找不到数据文件 {path}，跳过")
            continue

        try:
            with open(path, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except Exception as e:
            print(f"加载 JSON 失败: {path}，错误: {e}")
            continue

        if not isinstance(payload, dict):
            print(f"警告: {path} 的顶层不是 dict（期望 {{exp_name: [..]}}），跳过")
            continue

        # Last writer wins on key collisions: simple and predictable.
        duplicated = sorted(name for name in payload if name in combined)
        if duplicated:
            print(f"警告: {path} 与已有数据存在重复 experiment key，将覆盖: {duplicated}")
        combined.update(payload)

    return combined

def plot_grpo_gspo_comparison(
    data_path,
    output_dir='plots',
    ngram_sizes=None,
    dpi=600,
    grpo_key=None,
    gspo_key=None,
    max_step=None,
):
    """Render the GRPO-vs-GSPO comparison figure and save one PDF per n-gram size.

    The figure is a single row of four panels (1x4), left to right:

    1. GRPO, trajectory-level distinct n-gram count
    2. GRPO, trajectory-level distinct n-gram ratio
    3. GSPO, trajectory-level distinct n-gram count
    4. GSPO, trajectory-level distinct n-gram ratio

    Every panel overlays three smoothed curves: the diversity metric (blue,
    left axis), the response length (purple, right axis) and the accuracy
    (red, second right axis). Only one right axis is labelled per panel: the
    count panels show the length axis, the ratio panels show the accuracy
    axis. Count panels also carry a box with Pearson correlations (length vs.
    count, count vs. accuracy) computed on the unsmoothed data when SciPy is
    available. Y-axis ranges are shared between GRPO and GSPO.

    Args:
        data_path: Path to a JSON file, or a list/tuple of such paths.
        output_dir: Directory for the generated PDFs; created if missing.
        ngram_sizes: Iterable of n-gram sizes to plot. ``None`` means ``[10]``.
        dpi: DPI passed to ``savefig``.
        grpo_key: Experiment key of the GRPO run. ``None`` uses
            ``"baseline-grpo-dapo-math-minibsz32"``.
        gspo_key: Experiment key of the GSPO run. ``None`` uses
            ``"baseline-gspo-dapo-math-minibsz32"``.
        max_step: When given, only records with ``step <= max_step`` are kept.

    Returns:
        None. Problems (no data, unknown or empty experiment) are printed to
        stdout and the function returns without producing a figure.
    """
    # Avoid a mutable default argument; None stands for the historical [10].
    if ngram_sizes is None:
        ngram_sizes = [10]

    # Accept either a single JSON path or several of them.
    if isinstance(data_path, (list, tuple)):
        data_paths = list(data_path)
    else:
        data_paths = [data_path]

    all_results = _load_and_merge_results(data_paths)
    if not all_results:
        print("错误: 没有成功加载任何数据")
        return

    # Optional truncation; experiments left without records are dropped.
    if max_step is not None:
        print(f"截取数据到 step <= {max_step}")
        filtered_results = {}
        for exp_key, records in all_results.items():
            kept = [r for r in records if r.get('step', 0) <= max_step]
            if kept:
                filtered_results[exp_key] = kept
        all_results = filtered_results
        if not all_results:
            print("错误: 过滤后没有剩余数据")
            return

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    if grpo_key is None:
        grpo_key = "baseline-grpo-dapo-math-minibsz32"
    if gspo_key is None:
        gspo_key = "baseline-gspo-dapo-math-minibsz32"

    for algo_name, key in (('GRPO', grpo_key), ('GSPO', gspo_key)):
        if key not in all_results:
            print(f"错误: 找不到 {algo_name} 实验 '{key}'")
            print(f"可用的实验: {sorted(all_results.keys())}")
            return
        # An empty record list would otherwise crash on res[0] further down.
        if not all_results[key]:
            print(f"错误: {algo_name} 实验 '{key}' 没有任何记录")
            return

    # Curve colours.
    COLOR_NGRAM = '#00468B'   # dark blue (diversity)
    COLOR_LENGTH = '#9B59B6'  # purple (response length)
    COLOR_ACC = '#AE1029'     # crimson (accuracy)

    # NOTE: these rcParams changes are global and persist after the call.
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['axes.linewidth'] = 1.8

    # Font sizes.
    FONT_TICK = 34
    FONT_XLABEL = 40
    FONT_TITLE = 48
    FONT_LEGEND = 48
    FONT_ANNOTATION = 34
    FONT_GROUP_TITLE = 48

    def _style_ax(ax):
        """Apply the shared look: white background, black frame, dashed grid."""
        ax.set_facecolor('#FFFFFF')
        for side in ['left', 'bottom', 'top', 'right']:
            ax.spines[side].set_linewidth(2.0)
            ax.spines[side].set_color('#000000')
            ax.spines[side].set_visible(True)
        ax.tick_params(axis='both', labelcolor='#000000', labelsize=FONT_TICK, length=8, width=2.0)
        ax.grid(True, axis='both', alpha=0.25, color='#999999', linewidth=1.0, linestyle='--', zorder=0)
        ax.set_axisbelow(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))
        ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

    def _padded_ylim(values, fallback, cap_upper=None):
        """Return a y-range around ``values`` with 10% margins plus top headroom."""
        if not values:
            return fallback
        lowest, highest = min(values), max(values)
        spread = float(highest - lowest)
        if abs(spread) < 1e-9:
            # Flat series: derive the margin from the magnitude instead.
            margin = max(abs(float(highest)), 1.0) * 0.1
        else:
            margin = spread * 0.1
        upper = highest + margin
        if cap_upper is not None:
            upper = min(cap_upper, upper)
        return _add_top_headroom((max(0, lowest - margin), upper), frac=0.08, cap_upper=cap_upper)

    def _pearson_or_none(xs, ys):
        """Return Pearson's r over the finite pairs of ``xs``/``ys``, or None if too few."""
        if len(xs) != len(ys) or len(xs) <= 1:
            return None
        xs_arr = np.array(xs)
        ys_arr = np.array(ys)
        valid = np.isfinite(xs_arr) & np.isfinite(ys_arr)
        if np.sum(valid) <= 1:
            return None
        rho, _ = pearsonr(xs_arr[valid], ys_arr[valid])
        return rho

    def _draw_series(target_ax, steps, values, color, label,
                     band_scale, band_zorder, line_zorder, clip=None):
        """Draw one curve: optional variability band, white halo, coloured line.

        The band is a smoothed absolute first difference scaled by
        ``band_scale``; it is only drawn for series of 20 points or more.
        ``clip`` optionally bounds the band to ``(low, high)``.
        Returns the coloured line artist.
        """
        values = np.array(values)
        n_points = len(values)
        window = min(5, n_points // 10) if n_points > 10 else 1
        if window > 1:
            jitter = np.abs(np.diff(np.concatenate([[values[0]], values])))
            band = np.convolve(jitter, np.ones(window) / window, mode='same') * band_scale
            band_low = values - band
            band_high = values + band
            if clip is not None:
                band_low = np.maximum(clip[0], band_low)
                band_high = np.minimum(clip[1], band_high)
            target_ax.fill_between(steps, band_low, band_high, color=color, alpha=0.2, zorder=band_zorder)

        # White halo underneath keeps the line readable where curves cross.
        target_ax.plot(steps, values, color='white', linewidth=5.0, zorder=3, alpha=0.8)
        return target_ax.plot(
            steps,
            values,
            color=color,
            linewidth=4.0,
            linestyle='-',
            alpha=1.0,
            zorder=line_zorder,
            label=label,
            solid_capstyle='round',
            solid_joinstyle='round',
        )[0]

    def _add_right_axis(base_ax, steps, values, ylim, color, label, tick_formatter,
                        show_axis, band_scale, band_zorder, line_zorder,
                        clip=None, outward=None):
        """Create a twin y-axis on ``base_ax``, draw one series on it, return the line.

        ``outward`` shifts the right spine outwards by that many points.
        When ``show_axis`` is false, the right spine, ticks and labels are hidden.
        """
        twin = base_ax.twinx()
        if outward is not None:
            twin.spines['right'].set_position(('outward', outward))
        twin.spines['top'].set_linewidth(2.0)
        twin.spines['left'].set_visible(False)
        twin.spines['top'].set_color('#000000')
        twin.spines['right'].set_linewidth(2.0)
        twin.spines['right'].set_color(color if show_axis else '#000000')

        line = _draw_series(
            twin, steps, values, color, label,
            band_scale=band_scale, band_zorder=band_zorder, line_zorder=line_zorder, clip=clip,
        )

        twin.set_ylim(ylim)
        twin.yaxis.set_major_formatter(FuncFormatter(tick_formatter))
        twin.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

        if show_axis:
            twin.tick_params(axis='y', labelcolor=color, labelsize=FONT_TICK, length=8, width=2.0)
        else:
            twin.set_yticklabels([])
            twin.tick_params(axis='y', right=False, left=False, labelleft=False, labelright=False)
            twin.spines['right'].set_visible(False)
        return line

    def _plot_one_subplot(
        ax,
        exp_key: str,
        n: int,
        metric: str,  # 'count' or 'ratio'
        div_ylim,
        len_ylim,
        acc_ylim,
        show_len_axis: bool = True,  # label the response-length right axis
        show_acc_axis: bool = True,  # label the accuracy right axis
    ):
        """Draw one panel (diversity, length, accuracy) and return the three line artists."""
        res = sorted(all_results[exp_key], key=lambda x: x['step'])
        steps = [r['step'] for r in res]

        div_key, div_label, div_kind = _get_diversity_key_and_label(res[0], n, metric)
        if not div_key:
            print(f"警告: {exp_key} 缺少 {metric} 对应 key，已用 0 填充")
            divs = [0 for _ in res]
        else:
            divs = [r.get(div_key, 0) for r in res]

        lengths = [r.get('avg_token_length', 0) for r in res]
        accs = [r.get('accuracy', 0) for r in res]

        # Pearson correlations: count panels only, on the raw (unsmoothed) data.
        corr_len_div = None
        corr_div_acc = None
        if metric == 'count' and HAS_SCIPY and pearsonr is not None:
            try:
                corr_len_div = _pearson_or_none(lengths, divs)
                corr_div_acc = _pearson_or_none(divs, accs)
            except Exception as e:
                print(f"计算相关系数时出错: {e}")

        _style_ax(ax)
        # The left spine and its tick labels take the diversity colour.
        ax.spines['left'].set_color(COLOR_NGRAM)
        ax.set_xlabel('Training Step', fontsize=FONT_XLABEL, fontweight='bold', labelpad=12, color='#000000')
        # The metric symbol is shown as the panel title instead of a y-label.
        ax.set_title(div_label, fontsize=FONT_TITLE, fontweight='bold', pad=14, color='#000000')

        # --- Left axis: diversity metric ---
        ln_div = _draw_series(
            ax, steps, _smooth_data(divs), COLOR_NGRAM, 'Diversity Metric',
            band_scale=1.2, band_zorder=1, line_zorder=4,
        )

        if div_ylim is not None:
            ax.set_ylim(div_ylim)

        if div_kind == 'ratio':
            ax.yaxis.set_major_formatter(FuncFormatter(_ratio_percent_formatter_factory(div_ylim if div_ylim else (0, 1))))
        else:
            ax.yaxis.set_major_formatter(FuncFormatter(thousands_formatter))
        ax.tick_params(axis='y', labelcolor=COLOR_NGRAM, labelsize=FONT_TICK, length=8, width=2.0)

        # X range with a 5% margin on both sides. It is set before the
        # annotation below reads the axis limits.
        if steps:
            x_min = min(steps)
            x_max = max(steps)
            x_margin = (x_max - x_min) * 0.05
            if x_margin <= 0:
                # Single-step run: identical limits would be degenerate.
                x_margin = 0.5
            ax.set_xlim(x_min - x_margin, x_max + x_margin)

        # --- Right axis 1: response length ---
        ln_len = _add_right_axis(
            ax, steps, _smooth_data(lengths), len_ylim, COLOR_LENGTH, 'Response Length',
            thousands_formatter, show_len_axis,
            band_scale=1.2, band_zorder=2, line_zorder=5,
        )

        # --- Right axis 2: accuracy ---
        # Its spine is pushed outwards when the length axis is shown so the
        # two right spines do not overlap; the band is clipped to [0, 1].
        ln_acc = _add_right_axis(
            ax, steps, _smooth_data(accs), acc_ylim, COLOR_ACC, 'Accuracy',
            percent_formatter, show_acc_axis,
            band_scale=0.5, band_zorder=3, line_zorder=6,
            clip=(0, 1.0), outward=60 if show_len_axis else None,
        )

        # Correlation box in the top-left corner, placed in data coordinates
        # once all axis limits are final.
        if corr_len_div is not None or corr_div_acc is not None:
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            text_x = xlim[0] + (xlim[1] - xlim[0]) * 0.05
            text_y = ylim[1] - (ylim[1] - ylim[0]) * 0.05

            text_lines = []
            if corr_len_div is not None:
                text_lines.append(f'$\\rho_{{L,C}} = {corr_len_div:.3f}$')
            if corr_div_acc is not None:
                text_lines.append(f'$\\rho_{{C,Acc}} = {corr_div_acc:.3f}$')

            ax.text(
                text_x, text_y, '\n'.join(text_lines),
                transform=ax.transData,
                fontsize=FONT_ANNOTATION,
                fontweight='bold',
                verticalalignment='top',
                horizontalalignment='left',
                bbox=dict(
                    boxstyle='round,pad=0.5',
                    facecolor='white',
                    edgecolor='black',
                    alpha=0.9,
                    linewidth=2.0
                ),
                zorder=100
            )

        return ln_div, ln_len, ln_acc

    for n in ngram_sizes:
        # Collect values from both experiments so all panels share y-ranges.
        all_counts = []
        all_ratios = []
        all_lengths = []
        all_accs = []

        for exp_key in [grpo_key, gspo_key]:
            res = sorted(all_results[exp_key], key=lambda x: x['step'])
            count_key, _, _ = _get_diversity_key_and_label(res[0], n, 'count')
            ratio_key, _, _ = _get_diversity_key_and_label(res[0], n, 'ratio')
            if count_key:
                all_counts.extend([r.get(count_key, 0) for r in res])
            if ratio_key:
                all_ratios.extend([r.get(ratio_key, 0) for r in res])
            all_lengths.extend([r.get('avg_token_length', 0) for r in res])
            all_accs.extend([r.get('accuracy', 0) for r in res])

        count_ylim = _padded_ylim(all_counts, (0, 1500))
        ratio_ylim = _padded_ylim(all_ratios, (0, 1.0), cap_upper=1.0)
        len_ylim = _padded_ylim(all_lengths, (0, 4000))
        acc_ylim = _padded_ylim(all_accs, (0, 1.0), cap_upper=1.0)

        # One wide figure; the four axes are positioned by hand (figure
        # coordinates, 0-1) so the gap between the GRPO and GSPO pairs can be
        # wider than the gap inside each pair.
        subplot_width = 9.0
        subplot_height = 10
        fig = plt.figure(figsize=(subplot_width * 6, subplot_height))

        small_gap = 0.04  # between panels 1-2 and between panels 3-4
        large_gap = 0.06  # between panels 2-3

        # Margins leave room for the left tick labels and the right axes.
        left_margin = 0.12
        right_margin = 0.10
        available_width = 1.0 - left_margin - right_margin
        total_gaps = small_gap * 2 + large_gap
        subplot_width_unit = (available_width - total_gaps) / 4.0

        bottom = 0.12
        height = 0.73

        # Left edge of each panel.
        pos_0_left = left_margin
        pos_1_left = left_margin + subplot_width_unit + small_gap
        pos_2_left = left_margin + subplot_width_unit * 2 + small_gap + large_gap
        pos_3_left = left_margin + subplot_width_unit * 3 + small_gap + large_gap + small_gap

        # Create all four base axes before any twin axes are added.
        axes = [
            fig.add_axes([left, bottom, subplot_width_unit, height])
            for left in (pos_0_left, pos_1_left, pos_2_left, pos_3_left)
        ]

        # Panel 1: GRPO count, length axis labelled. Its lines feed the legend.
        ln_div1, ln_len1, ln_acc1 = _plot_one_subplot(
            axes[0], grpo_key, n, 'count', count_ylim, len_ylim, acc_ylim,
            show_len_axis=True, show_acc_axis=False
        )

        # Panel 2: GRPO ratio, accuracy axis labelled.
        _plot_one_subplot(
            axes[1], grpo_key, n, 'ratio', ratio_ylim, len_ylim, acc_ylim,
            show_len_axis=False, show_acc_axis=True
        )

        # Panel 3: GSPO count, length axis labelled.
        _plot_one_subplot(
            axes[2], gspo_key, n, 'count', count_ylim, len_ylim, acc_ylim,
            show_len_axis=True, show_acc_axis=False
        )

        # Panel 4: GSPO ratio, accuracy axis labelled.
        _plot_one_subplot(
            axes[3], gspo_key, n, 'ratio', ratio_ylim, len_ylim, acc_ylim,
            show_len_axis=False, show_acc_axis=True
        )

        # Dashed vertical separator centred in the gap between panels 2 and 3.
        pos_1_right = pos_1_left + subplot_width_unit
        x_sep = (pos_1_right + pos_2_left) / 2.0
        y0 = bottom - 0.05
        y1 = bottom + height + 0.05
        line = plt.Line2D(
            [x_sep, x_sep],
            [y0, y1],
            transform=fig.transFigure,
            linestyle='--',
            linewidth=5,
            color='#666666',
            alpha=1.0,
            zorder=10,
        )
        fig.add_artist(line)

        # Group headings centred above each pair of panels.
        group_title_y = bottom + height + 0.075
        grpo_x_center = (pos_0_left + pos_1_right) / 2.0
        fig.text(grpo_x_center, group_title_y, 'GRPO', ha='center', va='bottom',
                 fontsize=FONT_GROUP_TITLE, fontweight='bold', color='#000000')

        pos_3_right = pos_3_left + subplot_width_unit
        gspo_x_center = (pos_2_left + pos_3_right) / 2.0
        fig.text(gspo_x_center, group_title_y, 'GSPO', ha='center', va='bottom',
                 fontsize=FONT_GROUP_TITLE, fontweight='bold', color='#000000')

        # Figure-level legend below the panels.
        fig.legend(
            [ln_div1, ln_len1, ln_acc1],
            ['Metric', 'Response Length', 'Accuracy'],
            loc='lower center',
            bbox_to_anchor=(0.5, -0.2),
            ncol=3,
            frameon=True,
            framealpha=0.95,
            edgecolor='#888888',
            fancybox=True,
            shadow=False,
            prop={'weight': 'bold', 'size': FONT_LEGEND},
        )

        # Axes were positioned manually, so no subplots_adjust is needed.
        save = f"grpo_gspo_comparison_{n}gram.pdf"
        plt.savefig(
            os.path.join(output_dir, save),
            format='pdf',
            dpi=dpi,
            bbox_inches='tight',
            pad_inches=0.1,  # keep a margin so y-axis labels are not cropped
            metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
        )
        plt.close(fig)
        print(f"已生成PDF图表: {os.path.join(output_dir, save)}")

# Command-line entry point.
# Input precedence: --inputs, then --input-glob, then the single --input path.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='生成 GRPO 和 GSPO 的对比图')
    parser.add_argument('--input', '-i', type=str, 
                        default='/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/analysis_results.json',
                        help='单个 JSON 数据文件路径（仅在未指定 --inputs 和 --input-glob 时使用）')
    parser.add_argument('--inputs', type=str, nargs='+',
                        help='多个分批收集的 JSON 文件路径（不同 experiment 可分散在不同文件）')
    parser.add_argument('--input-glob', type=str,
                        help='用通配符指定多个 JSON（例如: /path/to/batches/*.json）')
    parser.add_argument('--output-dir', '-o', type=str, 
                        default='/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/plots', 
                        help='输出目录')
    parser.add_argument('--ngrams', type=int, nargs='+', default=[10], help='要绘制的 n-gram 大小')
    parser.add_argument('--dpi', type=int, default=600, help='PDF清晰度（DPI）')
    parser.add_argument('--grpo-key', type=str, default=None,
                        help='GRPO 实验的 key（默认: baseline-grpo-dapo-math-minibsz32）')
    parser.add_argument('--gspo-key', type=str, default=None,
                        help='GSPO 实验的 key（默认: baseline-gspo-dapo-math-minibsz32）')
    parser.add_argument('--max-step', type=int, default=None,
                        help='如果指定，只显示 step <= max_step 的数据点（用于统一截取数据）')
    
    args = parser.parse_args()

    if args.inputs:
        data_paths = args.inputs
    elif args.input_glob:
        # Sorted so the merge order (later files override earlier ones) is deterministic.
        data_paths = sorted(glob.glob(args.input_glob))
        if not data_paths:
            print(f"错误: input-glob 没有匹配到任何文件: {args.input_glob}")
            # The exit() builtin is injected by the site module for interactive
            # use and may be absent; SystemExit always works.
            raise SystemExit(1)
    else:
        data_paths = args.input

    print("=" * 60)
    print("生成 GRPO 和 GSPO 的对比图...")
    print("=" * 60)
    plot_grpo_gspo_comparison(
        data_path=data_paths,
        output_dir=args.output_dir,
        ngram_sizes=args.ngrams,
        dpi=args.dpi,
        grpo_key=args.grpo_key,
        gspo_key=args.gspo_key,
        max_step=args.max_step,
    )
