import json
import os
import glob
import matplotlib
matplotlib.use('Agg')  # 使用Agg后端以支持高质量PDF输出
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from matplotlib.ticker import FuncFormatter

# Configure PDF/PS output parameters
plt.rcParams['pdf.fonttype'] = 42  # TrueType fonts, so text stays editable in the PDF
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['pdf.use14corefonts'] = False

def thousands_formatter(x, pos):
    """Format a tick value as thousands with one decimal and a ``k`` suffix.

    ``pos`` is the tick position passed by matplotlib's ``FuncFormatter``;
    it is not used.
    """
    in_thousands = x / 1000
    return '{:.1f}k'.format(in_thousands)

def percent_formatter(x, pos):
    """Format a percentage tick value as a number with no decimals.

    No ``%`` sign is appended. ``pos`` is the tick position passed by
    matplotlib's ``FuncFormatter``; it is not used.
    """
    return format(x, '.0f')

def read_jsonl(file_path):
    """Load a JSON Lines file and return its records as a list.

    Lines that are empty or contain only whitespace are skipped; every other
    line is parsed with ``json.loads``.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]

def extract_step_number(filename):
    """Return the integer step embedded in a ``step_<N>_traindata`` filename.

    Returns ``None`` when the name does not contain that pattern.
    """
    # Imported locally: the module does not import ``re`` at the top level.
    import re
    found = re.search(r'step_(\d+)_traindata', filename)
    return int(found.group(1)) if found else None

def analyze_training_data(data_dir, max_length=8192):
    """Summarise the lengths of incorrect responses for every training step.

    Scans ``data_dir`` for ``step_*_traindata.jsonl`` files. Each record is
    expected to carry two parallel lists, ``final_scores`` and
    ``response_length``; records where either is not a list, or where the
    lists differ in length, are skipped. A response counts as incorrect when
    its score is below 1.0.

    Args:
        data_dir: Directory containing the per-step JSONL dumps.
        max_length: Length threshold; incorrect responses shorter than this
            are counted as "below max".

    Returns:
        ``None`` when no matching file exists. Otherwise a dict mapping step
        number to ``avg_incorrect_length``, ``ratio_below_max`` (fraction in
        [0, 1]) and ``total_incorrect``. Steps without any incorrect response
        are omitted.
    """
    matched_files = glob.glob(os.path.join(data_dir, 'step_*_traindata.jsonl'))
    matched_files.sort()

    if len(matched_files) == 0:
        print(f"警告: 在 {data_dir} 中未找到训练数据文件")
        return None

    results = {}

    for file_path in matched_files:
        file_name = os.path.basename(file_path)
        step = extract_step_number(file_name)
        if step is None:
            continue

        print(f"正在处理: {file_name}")

        # Lengths of every incorrect response in this step, plus how many of
        # them stopped before reaching max_length.
        wrong_lengths = []
        below_max = 0

        for record in read_jsonl(file_path):
            scores = record.get('final_scores', [])
            lengths = record.get('response_length', [])

            # Both fields must be lists of the same size to be paired up.
            if not (isinstance(scores, list) and isinstance(lengths, list)):
                continue
            if len(scores) != len(lengths):
                continue

            for score, length in zip(scores, lengths):
                if not score < 1.0:
                    continue  # correct response, not of interest here
                wrong_lengths.append(length)
                below_max += int(length < max_length)

        if not wrong_lengths:
            print(f"  Step {step}: 没有错误输出")
            continue

        n_wrong = len(wrong_lengths)
        mean_wrong_length = np.mean(wrong_lengths)
        share_below_max = below_max / n_wrong

        results[step] = {
            'avg_incorrect_length': mean_wrong_length,
            'ratio_below_max': share_below_max,
            'total_incorrect': n_wrong
        }

        print(f"  Step {step}: 错误输出数={n_wrong}, "
              f"错误平均长度={mean_wrong_length:.2f}, "
              f"未到达最大长度比例={share_below_max:.2%}")

    return results

def _variation_band(values, scale=1.2):
    """Estimate a half-width for the shaded band drawn around a curve.

    The band is a moving average of the absolute point-to-point change of
    ``values``, multiplied by ``scale``. It is a visual cue for local
    variability, not a true standard deviation.

    Args:
        values: Sequence of y-values in plotting order.
        scale: Multiplier applied to the smoothed absolute differences.

    Returns:
        A NumPy array with one half-width per point, or ``None`` when the
        series is too short (smoothing window of 1) for a band to be drawn.
    """
    n_points = len(values)
    window = min(5, n_points // 10) if n_points > 10 else 1
    if window <= 1:
        return None
    # Prepending the first value makes the first difference 0 and keeps the
    # result the same length as ``values``.
    deltas = np.abs(np.diff(np.concatenate([[values[0]], values])))
    return np.convolve(deltas, np.ones(window) / window, mode='same') * scale


def plot_combined_results(all_results, output_path, dpi=600, max_length=8192):
    """Draw one dual-axis panel per experiment, side by side, and save a PDF.

    Each panel shows the average length of incorrect responses on the left
    axis (with a dashed reference line at ``max_length``) and the percentage
    of incorrect responses shorter than the maximum length on the right axis.
    All panels share the same y-ranges so they can be compared directly.

    Args:
        all_results: Mapping of experiment name to the per-step dict produced
            by ``analyze_training_data`` (keys ``avg_incorrect_length`` and
            ``ratio_below_max`` are read).
        output_path: Destination path of the PDF file.
        dpi: Resolution passed to ``savefig``.
        max_length: Position of the reference line; the left y-range is
            extended to include it. Defaults to 8192.

    Raises:
        ValueError: If ``all_results`` holds no data point at all.

    Side effects:
        Sets the global rcParams ``font.family`` and ``axes.linewidth``,
        writes ``output_path`` and prints a confirmation message.
    """
    # Display titles for known experiment names; unknown names are used as-is.
    exp_titles = {
        "baseline-grpo-dapo-math-minibsz32": "GRPO",
        "baseline-gspo-dapo-math-minibsz32": "GSPO"
    }

    # Colour scheme
    COLOR_INCORRECT = '#00468B'  # dark blue (average incorrect length)
    COLOR_RATIO = '#AE1029'      # crimson (ratio)
    COLOR_MAX = '#9B59B6'        # purple (max-length line)
    SPINE_COLOR = '#666666'

    # === Step 1: extract every series once and validate the input ===
    series = []
    for exp_name, results in all_results.items():
        steps = sorted(results.keys())
        lengths = [results[s]['avg_incorrect_length'] for s in steps]
        ratios = [results[s]['ratio_below_max'] * 100 for s in steps]
        series.append((exp_name, steps, lengths, ratios))

    all_incorrect_lengths = [v for _, _, lengths, _ in series for v in lengths]
    if not all_incorrect_lengths:
        # Fail before any figure is created or global state is touched.
        raise ValueError("plot_combined_results: all_results contains no data points to plot")

    # Shared left-axis range; the upper bound always includes the max-length line.
    len_min = min(all_incorrect_lengths)
    len_max = max(max(all_incorrect_lengths), max_length)
    len_span = len_max - len_min
    len_ylim = (max(0, len_min - len_span * 0.05), len_max + len_span * 0.15)

    # The ratio axis is fixed to 0-100 (%).
    ratio_ylim = (0, 100)

    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['axes.linewidth'] = 1.8

    num_cols = len(series)
    fig, axes = plt.subplots(1, num_cols, figsize=(7.0 * num_cols, 6.5), squeeze=False)

    # === Step 2: draw the panels ===
    legend_handles = []
    for i, (exp_name, steps, avg_incorrect_lengths, ratios) in enumerate(series):
        ax1 = axes[0, i]

        # --- 1. Average length (left axis) ---
        ax1.set_title(exp_titles.get(exp_name, exp_name), fontsize=28, fontweight='bold', pad=20)
        ax1.set_xlabel('Training Step', fontsize=20, fontweight='bold', labelpad=10)

        if i == 0:
            ax1.set_ylabel('Average Incorrect Length', color=COLOR_INCORRECT,
                           fontsize=20, fontweight='bold', labelpad=10)
        else:
            # Only the first panel carries left tick labels.
            ax1.tick_params(left=True, labelleft=False, length=6, width=1.5, labelsize=16)

        # Frame styling
        for side in ('left', 'bottom', 'top', 'right'):
            ax1.spines[side].set_linewidth(1.8)
            ax1.spines[side].set_color(SPINE_COLOR)
        ax1.spines['top'].set_visible(True)
        ax1.spines['right'].set_visible(True)
        ax1.set_facecolor('#FAFAFA')

        # 1.1 Average incorrect length with its variability band
        avg_incorrect_array = np.array(avg_incorrect_lengths)
        length_band = _variation_band(avg_incorrect_lengths)
        if length_band is not None:
            ax1.fill_between(steps, avg_incorrect_array - length_band,
                             avg_incorrect_array + length_band,
                             color=COLOR_INCORRECT, alpha=0.2, zorder=1)

        ln1 = ax1.plot(steps, avg_incorrect_lengths, color=COLOR_INCORRECT, linewidth=3.0,
                       label='Avg Incorrect Length', linestyle='-', alpha=1.0, zorder=4)

        # 1.2 Reference line at the maximum length
        ln_max = ax1.axhline(y=max_length, color=COLOR_MAX, linestyle='--', linewidth=2.5,
                             label=f'Max Length ({max_length})', alpha=0.8, zorder=3)

        ax1.tick_params(axis='both', labelcolor='#333333', labelsize=16, length=6, width=1.5)
        ax1.grid(True, axis='both', alpha=0.3, color='#CCCCCC', linewidth=0.8,
                 linestyle='-', zorder=0)
        ax1.set_ylim(len_ylim)
        ax1.set_axisbelow(True)
        ax1.yaxis.set_major_formatter(FuncFormatter(thousands_formatter))
        ax1.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))
        ax1.xaxis.set_major_locator(plt.MaxNLocator(nbins=6, integer=True))

        # --- 2. Ratio (right axis) ---
        ax2 = ax1.twinx()
        ax2.spines['top'].set_linewidth(1.8)
        ax2.spines['left'].set_visible(False)

        if i == num_cols - 1:  # only the last panel shows the ratio ticks
            ax2.set_ylabel('Ratio Below Max Length (%)', color=COLOR_RATIO,
                           fontsize=20, fontweight='bold', labelpad=8)
            ax2.spines['right'].set_linewidth(1.8)
            ax2.spines['right'].set_color(SPINE_COLOR)
            ax2.tick_params(axis='y', labelcolor=COLOR_RATIO, labelsize=16, length=6, width=1.5)
        else:
            ax2.set_yticklabels([])
            ax2.tick_params(axis='y', right=False, left=False, labelleft=False, labelright=False)
            ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_color(SPINE_COLOR)

        # Variability band, clipped to the valid percentage range
        ratios_array = np.array(ratios)
        ratio_band = _variation_band(ratios)
        if ratio_band is not None:
            ax2.fill_between(steps, np.maximum(0, ratios_array - ratio_band),
                             np.minimum(100, ratios_array + ratio_band),
                             color=COLOR_RATIO, alpha=0.2, zorder=6)

        ln_ratio = ax2.plot(steps, ratios, color=COLOR_RATIO, linewidth=3.0,
                            label='Ratio Below Max Length', linestyle='-', alpha=1.0, zorder=7)

        ax2.set_ylim(ratio_ylim)
        ax2.yaxis.set_major_formatter(FuncFormatter(percent_formatter))
        ax2.yaxis.set_major_locator(plt.MaxNLocator(nbins=5, integer=False))

        # The legend is shared, so the handles of the first panel suffice.
        if i == 0:
            legend_handles = ln1 + [ln_max] + ln_ratio

    # Figure-level legend placed below the panels
    legend_labels = [handle.get_label() for handle in legend_handles]
    legend = fig.legend(legend_handles, legend_labels, loc='lower center',
                        bbox_to_anchor=(0.5, -0.15),
                        ncol=3, frameon=True, framealpha=0.95,
                        edgecolor='#888888', fancybox=True, shadow=False,
                        prop={'weight': 'bold', 'size': 24})
    legend.get_frame().set_linewidth(1.5)

    # Layout
    fig.subplots_adjust(wspace=0.2, top=0.85, bottom=0.10, left=0.08, right=0.92)

    # Save through the figure object and always release it, even if saving fails.
    try:
        fig.savefig(output_path, format='pdf', dpi=dpi, bbox_inches='tight',
                    metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'})
    finally:
        plt.close(fig)
    print(f"\n高质量PDF图表已保存至: {output_path}")

def main():
    """Analyse every configured experiment and write the combined PDF plot.

    For each experiment the per-step statistics are computed and printed;
    experiments that yield no results are reported and skipped. The plot is
    written under ``base_path`` when at least one experiment produced data.
    """
    # Root directory holding one sub-directory per experiment
    base_path = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct"

    # Experiments to analyse; the commented-out names are currently disabled.
    experiment_names = (
        "baseline-grpo-dapo-math-minibsz32",
        "baseline-gspo-dapo-math-minibsz32",
        # "skip-right-skip-limits10-grpo-dapo-math",
        # "skip-right-skip-limits10-gspo-dapo-math",
    )
    experiments = {
        name: os.path.join(base_path, f"{name}/training_data")
        for name in experiment_names
    }

    max_length = 8192
    print(f"最大长度阈值: {max_length}\n")

    # Collect the statistics of every experiment that has data
    all_results = {}
    for exp_name, data_dir in experiments.items():
        print(f"\n{'='*60}")
        print(f"分析实验: {exp_name}")
        print(f"数据路径: {data_dir}")
        print(f"{'='*60}")

        results = analyze_training_data(data_dir, max_length)
        if not results:
            print(f"警告: {exp_name} 分析失败或没有找到数据")
            continue

        all_results[exp_name] = results

        # Per-step summary
        print(f"\n=== {exp_name} 统计摘要 ===")
        for step in sorted(results):
            stats = results[step]
            print(f"Step {step}: "
                  f"错误平均长度={stats['avg_incorrect_length']:.2f}, "
                  f"未到达最大长度比例={stats['ratio_below_max']:.2%}, "
                  f"错误输出总数={stats['total_incorrect']}")

    # Nothing to draw when every experiment came back empty
    if not all_results:
        print("\n错误: 没有可绘制的数据")
        return

    output_path = os.path.join(base_path, 'incorrect_length_analysis_combined.pdf')
    plot_combined_results(all_results, output_path, dpi=600)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
