import json
import os
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer
from math_verify import parse, verify
import pandas as pd
from tqdm import tqdm
import signal
import sys
import seaborn as sns
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import torch
from functools import partial
import time
from typing import List, Dict, Tuple, Optional

# Global matplotlib appearance: font fallback chain, ASCII minus sign,
# white figure/axes backgrounds.
plt.rcParams.update({
    'font.family': ['Arial', 'DejaVu Sans', 'sans-serif'],
    'axes.unicode_minus': False,
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
})

# Seaborn theme and default colour palette.
# NOTE(review): seaborn's set_style also writes font/facecolor rcParams and may
# override some of the values set just above — confirm the intended precedence.
sns.set_style("whitegrid")
sns.set_palette("husl")

# Datasets that get their own subplot grid and a dedicated averaged plot.
TARGET_DATASOURCES = ['aime', 'aime25', 'amc', 'math', 'olympiad_bench']

# Torch device (CUDA when available); selected and reported at import time.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {DEVICE}")

def timeout(timeout_seconds: int = 10):
    """Return a decorator that aborts the wrapped call after ``timeout_seconds``.

    On POSIX the limit is enforced with ``SIGALRM``. For the duration of each
    call a handler raising ``TimeoutError("verify timed out!")`` is installed.
    Afterwards the alarm is cancelled and the previous handler is restored.

    ``signal.signal`` may only be called from the main thread. When the wrapped
    function is invoked from any other thread it therefore runs without a time
    limit instead of failing with ``ValueError``. On non-POSIX platforms (no
    ``SIGALRM``) the function is returned unchanged.

    Args:
        timeout_seconds: whole seconds allowed per call (``signal.alarm`` takes an int).
    """
    import functools
    import threading

    if os.name != "posix":
        # SIGALRM is unavailable (e.g. Windows): no time limit is enforced.
        def passthrough(func):
            return func
        return passthrough

    def decorator(func):
        def handler(signum, frame):
            raise TimeoutError("verify timed out!")

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Signal handlers can only be installed from the main thread.
            if threading.current_thread() is not threading.main_thread():
                return func(*args, **kwargs)
            old_handler = signal.getsignal(signal.SIGALRM)
            if old_handler is None:
                # Handler was not installed from Python; SIG_DFL is the closest
                # restorable value (signal.signal rejects None).
                old_handler = signal.SIG_DFL
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)
        return wrapper
    return decorator

@timeout(timeout_seconds=10)
def _label_single_response(response: str, golden_answer) -> bool:
    """Verify one response against its gold answer (may raise, incl. TimeoutError)."""
    predict_answer = parse(response)
    # The gold answer is wrapped in $...$ so it is parsed as LaTeX math; the
    # f-string also tolerates non-string answers (e.g. ints from JSON).
    golden_parsed = parse(f"${golden_answer}$")
    return verify(golden_parsed, predict_answer)


def labeling_responses_batch(responses: List[str], golden_answers: List[str]) -> List[bool]:
    """Verify each response against the gold answer at the same position.

    Each pair is parsed and verified independently, under its own 10-second
    time limit. A parse/verify error or a timeout therefore marks only that
    response as incorrect, not the whole batch.

    Args:
        responses: model responses (already truncated by the caller).
        golden_answers: gold answers aligned with ``responses``.

    Returns:
        One verdict per (response, answer) pair, in order; pairs are formed with
        ``zip`` so the result has the length of the shorter input.
    """
    labels = []
    for response, golden_answer in zip(responses, golden_answers):
        try:
            labels.append(_label_single_response(response, golden_answer))
        except Exception as e:
            print(f"验证单个回答时出错: {e}")
            labels.append(False)
    return labels

def truncate_text_by_tokens_batch(texts: List[str], tokenizer, max_tokens: int) -> List[str]:
    """Cut each text to at most ``max_tokens`` tokens of ``tokenizer``.

    Texts that already fit are returned unchanged. Only over-long texts are cut
    to their first ``max_tokens`` tokens and decoded back with
    ``skip_special_tokens=True``.

    Args:
        texts: texts to truncate.
        tokenizer: object with ``encode(text, add_special_tokens=False)`` and
            ``decode(tokens, skip_special_tokens=True)``.
        max_tokens: token budget per text.

    Returns:
        A list aligned with ``texts``. If any text fails to tokenize, the input
        list is returned as-is.
    """
    try:
        truncated_texts = []
        for text in texts:
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) <= max_tokens:
                # Skip the decode round trip: it is expensive for long texts and
                # can alter them (special tokens dropped, whitespace cleanup).
                truncated_texts.append(text)
            else:
                truncated_texts.append(
                    tokenizer.decode(tokens[:max_tokens], skip_special_tokens=True)
                )
        return truncated_texts
    except Exception as e:
        print(f"批量截断文本时出错: {e}")
        # NOTE(review): this fallback means the batch is scored untruncated,
        # i.e. at a larger budget than max_tokens.
        return texts

def process_single_file_batch(file_path: str, tokenizer, max_tokens: int, batch_size: int = 32) -> Dict:
    """Score one JSONL result file at a single truncation length.

    Each line is parsed as JSON and must provide ``generated_text`` and
    ``answer``. Lines where either is missing or empty, and lines that are not
    valid JSON, are skipped. ``data_source`` defaults to ``'unknown'``.

    Records are buffered and passed to ``process_batch`` ``batch_size`` at a
    time. That step truncates each response to ``max_tokens`` tokens, records
    the truncated length and verifies the answer.

    Args:
        file_path: path of the JSONL file to read.
        tokenizer: tokenizer used for truncation and length measurement.
        max_tokens: token budget each response is truncated to.
        batch_size: number of records per ``process_batch`` call.

    Returns:
        ``{}`` if the file does not exist or reading it fails. Otherwise a dict with:

        - ``macro_avg``: unweighted mean of the per-datasource accuracies;
        - ``datasource_accuracies``;
        - ``datasource_stats``: raw correct/total counts;
        - ``length_stats``: average truncated length, overall and per datasource.

        Datasources for which no sample was scored are left out of
        ``datasource_accuracies`` and of the macro average.
    """
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        return {}

    # {datasource: {'correct': int, 'total': int}}
    datasource_stats = {}
    length_stats = {
        # Token lengths measured after truncating to the current max_tokens.
        'truncated_lengths': [],
        'datasource_lengths': {}
    }

    # Buffers for the batch being accumulated. flush_batch empties them in
    # place, so the closure always refers to the same list objects.
    batch_data = []
    batch_responses = []
    batch_answers = []
    batch_datasources = []

    def flush_batch():
        """Score the buffered records, then empty the buffers even on failure."""
        try:
            process_batch(
                batch_data,
                batch_responses,
                batch_answers,
                batch_datasources,
                tokenizer,
                max_tokens,
                datasource_stats,
                length_stats,
            )
        except Exception as e:
            print(f"处理批次时出错: {e}")
        finally:
            # Without this, a batch whose processing raised would stay buffered
            # and be re-submitted (and partly re-counted) on the next line.
            for buffer in (batch_data, batch_responses, batch_answers, batch_datasources):
                buffer.clear()

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc=f"处理 {os.path.basename(file_path)} (max_tokens={max_tokens})"):
                try:
                    data = json.loads(line.strip())
                    generated_text = data.get('generated_text', '')
                    answer = data.get('answer', '')
                    datasource = data.get('data_source', 'unknown')

                    if not generated_text or not answer:
                        continue

                    datasource_stats.setdefault(datasource, {'correct': 0, 'total': 0})

                    batch_data.append({
                        'generated_text': generated_text,
                        'answer': answer,
                        'datasource': datasource
                    })
                    batch_responses.append(generated_text)
                    batch_answers.append(answer)
                    batch_datasources.append(datasource)

                    if len(batch_data) >= batch_size:
                        flush_batch()

                except json.JSONDecodeError:
                    continue
                except Exception as e:
                    print(f"处理单行数据时出错: {e}")
                    continue

        # Score whatever remains after the last full batch.
        if batch_data:
            flush_batch()

    except Exception as e:
        print(f"读取文件时出错: {e}")
        return {}

    # Per-datasource accuracy. total == 0 means no sample of that datasource
    # was scored (every batch containing it failed). Its accuracy is unknown
    # rather than 0%, so it must not drag down the macro average.
    datasource_accuracies = {}
    for datasource, stats in datasource_stats.items():
        if stats['total'] == 0:
            print(f"  {datasource}: 没有成功评分的样本，已从 Macro Average 中排除")
            continue
        accuracy = stats['correct'] / stats['total']
        datasource_accuracies[datasource] = accuracy
        print(f"  {datasource}: 准确率 {accuracy:.4f} ({stats['correct']}/{stats['total']})")

    # Macro average: unweighted mean over datasources.
    if datasource_accuracies:
        macro_avg = sum(datasource_accuracies.values()) / len(datasource_accuracies)
        print(f"文件: {os.path.basename(file_path)}, 截断长度: {max_tokens}, Macro Average: {macro_avg:.4f}")
    else:
        macro_avg = 0.0
        print(f"文件: {os.path.basename(file_path)}, 截断长度: {max_tokens}, 无有效数据")

    # Average post-truncation length over all scored records.
    avg_truncated_length = np.mean(length_stats['truncated_lengths']) if length_stats['truncated_lengths'] else 0
    print(f"  平均截断后长度: {avg_truncated_length:.1f} tokens")

    # Average post-truncation length per datasource.
    datasource_avg_lengths = {}
    for datasource, lengths in length_stats['datasource_lengths'].items():
        if lengths['truncated_lengths']:
            datasource_avg_lengths[datasource] = {
                'avg_truncated': np.mean(lengths['truncated_lengths'])
            }
            print(f"  {datasource} - 平均截断后长度: {datasource_avg_lengths[datasource]['avg_truncated']:.1f}")

    return {
        'macro_avg': macro_avg,
        'datasource_accuracies': datasource_accuracies,
        'datasource_stats': datasource_stats,
        'length_stats': {
            # Average post-truncation length at this max_tokens.
            'avg_truncated_length': avg_truncated_length,
            'datasource_avg_lengths': datasource_avg_lengths
        }
    }

def process_batch(
    batch_data: List[Dict],
    batch_responses: List[str],
    batch_answers: List[str],
    batch_datasources: List[str],
    tokenizer,
    max_tokens: int,
    datasource_stats: Dict,
    length_stats: Dict,
):
    """Truncate, measure and verify one batch, accumulating into shared stats.

    All fallible work (truncation, re-tokenization, verification) finishes
    before anything is recorded. A failing batch is therefore dropped as a
    whole and reported, instead of leaving length statistics and accuracy
    counts out of sync.

    Args:
        batch_data: per-record dicts for the batch. Not read here; the
            parallel ``batch_*`` lists carry the same information.
        batch_responses: model responses, aligned with ``batch_answers`` and
            ``batch_datasources``.
        batch_answers: gold answers.
        batch_datasources: datasource name of each record.
        tokenizer: tokenizer used for truncation and length measurement.
        max_tokens: token budget for truncation.
        datasource_stats: ``{datasource: {'correct': int, 'total': int}}``,
            updated in place; missing datasources are created.
        length_stats: dict with ``'truncated_lengths'`` (list) and
            ``'datasource_lengths'`` (``{datasource: {'truncated_lengths': list}}``),
            updated in place.
    """
    try:
        truncated_texts = truncate_text_by_tokens_batch(batch_responses, tokenizer, max_tokens)
        # Measure the length of the text that is actually verified.
        truncated_lens = [
            len(tokenizer.encode(text, add_special_tokens=False))
            for text in truncated_texts
        ]
        # Verification runs on the truncated text.
        labels = labeling_responses_batch(truncated_texts, batch_answers)
    except Exception as e:
        print(f"处理批次时出错: {e}")
        return

    per_datasource_lengths = length_stats['datasource_lengths']
    for truncated_len, datasource in zip(truncated_lens, batch_datasources):
        length_stats['truncated_lengths'].append(truncated_len)
        per_datasource_lengths.setdefault(
            datasource, {'truncated_lengths': []}
        )['truncated_lengths'].append(truncated_len)

    for label, datasource in zip(labels, batch_datasources):
        counts = datasource_stats.setdefault(datasource, {'correct': 0, 'total': 0})
        if label:
            counts['correct'] += 1
        counts['total'] += 1

def process_file_worker(args):
    """Worker entry point: score one file at one truncation length.

    Args:
        args: ``(file_path, tokenizer_path, max_tokens, batch_size)``, passed as
            one tuple so the function can be used with ``executor.map``.

    Returns:
        Whatever ``process_single_file_batch`` returns for that file.
    """
    file_path, tokenizer_path, max_tokens, batch_size = args

    # The tokenizer is loaded inside the worker instead of being sent from the parent.
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    except Exception as e:
        # NOTE(review): a different tokenizer changes what "max_tokens" means,
        # so results produced after this fallback are not directly comparable.
        print(f"无法加载tokenizer ({tokenizer_path}: {e})，使用默认tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    return process_single_file_batch(file_path, tokenizer, max_tokens, batch_size)

def analyze_models_truncation_accuracy_parallel(files: List[str], truncation_lengths: List[int],
                                              num_processes: Optional[int] = None, batch_size: int = 32,
                                              file_labels: Optional[Dict[str, str]] = None,
                                              tokenizer_path: str = "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base"):
    """Score every file at every truncation length using a process pool.

    Args:
        files: JSONL result files, one per model.
        truncation_lengths: token budgets to evaluate.
        num_processes: worker count. Defaults to ``min(cpu_count, len(files))``
            and is clamped to at least 1.
        batch_size: records per verification batch inside each worker.
        file_labels: optional ``{file_path: display_name}``. Files without an
            entry are named after their basename with ``.jsonl`` removed.
        tokenizer_path: tokenizer each worker loads for truncation.

    Returns:
        ``(results, datasource_results, length_results)``, each keyed by model
        name and then by truncation length. They hold, respectively, the
        macro-average accuracy, the per-datasource accuracies and the length
        statistics. Files that produced no result are omitted. All three are
        empty when ``files`` or ``truncation_lengths`` is empty.
    """
    if not files or not truncation_lengths:
        # Nothing to do; also avoids ProcessPoolExecutor(max_workers=0).
        print("没有需要处理的文件或截断长度")
        return {}, {}, {}

    if num_processes is None:
        num_processes = min(mp.cpu_count(), len(files))
    num_processes = max(1, num_processes)

    print(f"使用 {num_processes} 个进程并行处理")
    print(f"批量大小: {batch_size}")

    results = {}
    datasource_results = {}
    length_results = {}

    # One pool is reused for all truncation lengths instead of re-spawning
    # worker processes for each length.
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        for max_tokens in truncation_lengths:
            print(f"\n处理截断长度: {max_tokens}")
            print("=" * 50)

            args_list = [(file_path, tokenizer_path, max_tokens, batch_size) for file_path in files]
            file_results = list(executor.map(process_file_worker, args_list))

            for file_path, result in zip(files, file_results):
                if not result:
                    continue

                # A manually assigned label takes precedence over the file name.
                if file_labels and file_path in file_labels:
                    model_name = file_labels[file_path]
                else:
                    model_name = os.path.basename(file_path).replace('.jsonl', '')

                results.setdefault(model_name, {})[max_tokens] = result['macro_avg']
                datasource_results.setdefault(model_name, {})[max_tokens] = result['datasource_accuracies']
                length_results.setdefault(model_name, {})[max_tokens] = result['length_stats']

    return results, datasource_results, length_results

def plot_truncation_accuracy_optimized(results: dict, truncation_lengths: list,
                                      datasource_results: dict = None, length_results: dict = None,
                                      output_path: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/truncation_accuracy_plot_optimized.png"):
    """Plot macro-average accuracy against token budget, one line per model.

    A budget with no result for a model is drawn as a gap (NaN) rather than as
    0.0. A fake zero would look like a collapse, and it would also fall outside
    the y-range computed from the real data.

    Args:
        results: ``{model_name: {max_tokens: macro_avg_accuracy}}``.
        truncation_lengths: budgets shown on the x axis.
        datasource_results: accepted for call compatibility; not used.
        length_results: accepted for call compatibility; not used.
        output_path: where the PNG is written.
    """
    if not truncation_lengths:
        print("没有截断长度可绘制")
        return

    fig, ax = plt.subplots(figsize=(14, 10))

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E', '#7209B7', '#F77F00', '#FCBF49']
    line_styles = ['-', '--', '-.', ':']
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']

    for i, (model_name, accuracies) in enumerate(results.items()):
        if not accuracies:
            continue

        acc_values = [accuracies.get(length, np.nan) for length in truncation_lengths]
        color = colors[i % len(colors)]
        ax.plot(truncation_lengths, acc_values,
                marker=markers[i % len(markers)],
                linewidth=3,
                markersize=10,
                color=color,
                linestyle=line_styles[i % len(line_styles)],
                label=model_name,
                markerfacecolor='white',
                markeredgewidth=2,
                markeredgecolor=color,
                alpha=0.9)

    ax.set_xlabel('Test-time compute (tokens)', fontsize=16, fontweight='bold', labelpad=15)
    ax.set_ylabel('Macro Average Accuracy', fontsize=16, fontweight='bold', labelpad=15)
    ax.set_title('Model Macro Average Accuracy Comparison at Different Test-time Compute (tokens)',
                 fontsize=18, fontweight='bold', pad=25)

    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.set_axisbelow(True)

    # Y range: pad by 20% of the data span (0.05 when all values are equal),
    # clamped to [0, 1]. Only values that actually exist are considered.
    all_accs = [accuracies[l] for accuracies in results.values() for l in truncation_lengths if l in accuracies]
    if all_accs:
        ymin, ymax = min(all_accs), max(all_accs)
        padding = (ymax - ymin) * 0.2 if ymax > ymin else 0.05
        ax.set_ylim(max(0, ymin - padding), min(1.0, ymax + padding))
    else:
        ax.set_ylim(0, 1.05)

    ax.set_xlim(min(truncation_lengths) * 0.9, max(truncation_lengths) * 1.1)

    ax.set_xticks(truncation_lengths)
    ax.set_xticklabels([f'{x:,}' for x in truncation_lengths], fontsize=12)

    # Ten evenly spaced y ticks across the final range.
    current_ylim = ax.get_ylim()
    ax.set_yticks(np.linspace(current_ylim[0], current_ylim[1], 10))
    ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.2f'))

    legend = ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left',
                       fontsize=12, frameon=True, fancybox=True, shadow=True,
                       title='Models', title_fontsize=14)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)
    legend.get_title().set_fontweight('bold')

    # Reference lines at 0.5 and 0.7 accuracy.
    ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, linewidth=1)
    ax.axhline(y=0.7, color='gray', linestyle=':', alpha=0.5, linewidth=1)

    plt.tight_layout()

    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    print(f"\n优化折线图已保存到: {output_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_specific_datasources_grid(datasource_results: dict, truncation_lengths: list,
                                   output_path: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/specific_datasets_grid.png"):
    """Draw one accuracy-vs-budget subplot per dataset in TARGET_DATASOURCES.

    Target datasets with no result for any model or budget are skipped. If none
    remain, a warning is printed and nothing is drawn. Subplots are laid out
    with at most three per row. A single legend built from every subplot goes
    on the first one, so models absent from the first dataset are still listed.

    Args:
        datasource_results: ``{model_name: {max_tokens: {datasource: accuracy}}}``.
        truncation_lengths: budgets to consider, in x-axis order.
        output_path: where the PNG is written.
    """
    import math

    # Keep only target datasets that appear in at least one model/budget result.
    available_ds = [
        ds for ds in TARGET_DATASOURCES
        if any(ds in length_res
               for model_res in datasource_results.values()
               for length_res in model_res.values())
    ]

    if not available_ds:
        print(f"\n警告：在结果中未找到任何指定的数据集 {TARGET_DATASOURCES}")
        return

    num_ds = len(available_ds)
    cols = min(3, num_ds)
    rows = math.ceil(num_ds / cols)

    fig, axes = plt.subplots(rows, cols, figsize=(20, 6 * rows), squeeze=False)
    axes_flat = axes.flatten()

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E', '#7209B7', '#F77F00', '#FCBF49']
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']

    for idx, ds in enumerate(available_ds):
        ax = axes_flat[idx]
        ds_accs = []  # all accuracies drawn in this subplot, for the y-range

        for i, (model_name, model_res) in enumerate(datasource_results.items()):
            lengths_present = [
                length for length in truncation_lengths
                if length in model_res and ds in model_res[length]
            ]
            acc_values = [model_res[length][ds] for length in lengths_present]
            ds_accs.extend(acc_values)

            if acc_values:
                ax.plot(lengths_present, acc_values,
                        marker=markers[i % len(markers)],
                        label=model_name,
                        color=colors[i % len(colors)],
                        linewidth=2,
                        markersize=8)

        ax.set_title(f'Dataset: {ds}', fontsize=16, fontweight='bold')
        ax.set_xlabel('Test-time compute (tokens)', fontsize=12)
        ax.set_ylabel('Accuracy', fontsize=12)

        # Per-subplot y range: 20% padding (0.05 if flat), clamped to [0, 1].
        if ds_accs:
            ymin, ymax = min(ds_accs), max(ds_accs)
            padding = (ymax - ymin) * 0.2 if ymax > ymin else 0.05
            ax.set_ylim(max(0, ymin - padding), min(1.0, ymax + padding))
        else:
            ax.set_ylim(0, 1.05)

        ax.set_xticks(truncation_lengths)
        ax.grid(True, alpha=0.3, linestyle='--')

    # One legend covering every model drawn in any subplot (first handle per label).
    legend_entries = {}
    for ax in axes_flat[:num_ds]:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            legend_entries.setdefault(label, handle)
    if legend_entries:
        axes_flat[0].legend(list(legend_entries.values()), list(legend_entries.keys()),
                            fontsize=10, loc='best')

    # Hide unused grid cells.
    for j in range(num_ds, len(axes_flat)):
        axes_flat[j].axis('off')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"特定数据集网格图已保存到: {output_path}")
    # Release the figure; nothing displays it afterwards.
    plt.close(fig)

def _selected_macro_average(length_res: dict) -> Optional[float]:
    """Return the mean accuracy over the TARGET_DATASOURCES present in ``length_res``.

    ``length_res`` maps datasource name -> accuracy for one model at one budget.
    Returns None when none of the target datasets are present.
    """
    ds_accs = [length_res[ds] for ds in TARGET_DATASOURCES if ds in length_res]
    if not ds_accs:
        return None
    return sum(ds_accs) / len(ds_accs)


def plot_selected_macro_average(datasource_results: dict, truncation_lengths: list,
                                output_path: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/selected_macro_avg_tts.png"):
    """Plot, per model, the mean accuracy over TARGET_DATASOURCES against token budget.

    The mean at a budget is taken over the target datasets present at that
    budget.
    NOTE(review): if a model lacks some target dataset, its average covers
    fewer datasets and is not strictly comparable to other models.

    Args:
        datasource_results: ``{model_name: {max_tokens: {datasource: accuracy}}}``.
        truncation_lengths: budgets to consider, in x-axis order.
        output_path: where the PNG is written.
    """
    if not truncation_lengths:
        print("没有截断长度可绘制")
        return

    # {model_name: (lengths, averages)}: computed once and reused for both
    # plotting and the y-range.
    series = {}
    for model_name, model_res in datasource_results.items():
        valid_lengths, avg_accs = [], []
        for length in truncation_lengths:
            if length not in model_res:
                continue
            avg_acc = _selected_macro_average(model_res[length])
            if avg_acc is not None:
                valid_lengths.append(length)
                avg_accs.append(avg_acc)
        series[model_name] = (valid_lengths, avg_accs)

    fig, ax = plt.subplots(figsize=(14, 10))

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E', '#7209B7', '#F77F00', '#FCBF49']
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']

    for i, (model_name, (valid_lengths, avg_accs)) in enumerate(series.items()):
        if avg_accs:
            ax.plot(valid_lengths, avg_accs,
                    marker=markers[i % len(markers)],
                    linewidth=4,
                    markersize=12,
                    label=model_name,
                    color=colors[i % len(colors)],
                    markerfacecolor='white',
                    markeredgewidth=2)

    title_ds = ", ".join(TARGET_DATASOURCES)
    ax.set_title(f'Macro Average Accuracy for Selected Datasets\n({title_ds})',
                 fontsize=18, fontweight='bold', pad=20)
    ax.set_xlabel('Test-time compute (tokens)', fontsize=16, fontweight='bold')
    ax.set_ylabel('Macro Average Accuracy', fontsize=16, fontweight='bold')

    # Y range: 30% padding of the data span (0.05 if flat), clamped to [0, 1].
    all_selected_accs = [acc for _, accs in series.values() for acc in accs]
    if all_selected_accs:
        ymin, ymax = min(all_selected_accs), max(all_selected_accs)
        padding = (ymax - ymin) * 0.3 if ymax > ymin else 0.05
        ax.set_ylim(max(0, ymin - padding), min(1.0, ymax + padding))
    else:
        ax.set_ylim(0, 1.05)

    ax.set_xlim(min(truncation_lengths) * 0.9, max(truncation_lengths) * 1.1)
    ax.set_xticks(truncation_lengths)
    ax.grid(True, alpha=0.3, linestyle='-')

    # Eight evenly spaced y ticks across the final range.
    current_ylim = ax.get_ylim()
    ax.set_yticks(np.linspace(current_ylim[0], current_ylim[1], 8))
    ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.2f'))

    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left',
              fontsize=12, frameon=True, shadow=True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"选定数据集平均 TTS 图已保存到: {output_path}")
    # Release the figure; nothing displays it afterwards.
    plt.close(fig)

def _save_macro_results(results: dict, truncation_lengths: list, output_dir: str):
    """Print and save the model x budget macro-average table."""
    macro_data = []
    for model_name, accuracies in results.items():
        row = {'Model': model_name}
        for length in truncation_lengths:
            # NaN (empty CSV cell) for a missing run, not a fake 0% accuracy.
            row[f'Macro_Avg_{length}'] = accuracies.get(length, np.nan)
        macro_data.append(row)

    macro_df = pd.DataFrame(macro_data)

    print("\nMacro Average准确率结果表格:")
    print("=" * 80)
    print(macro_df.to_string(index=False, float_format='%.4f'))

    csv_path = os.path.join(output_dir, "truncation_accuracy_results_optimized.csv")
    macro_df.to_csv(csv_path, index=False)
    print(f"\n优化Macro Average结果已保存到: {csv_path}")


def _save_datasource_results(datasource_results: dict, truncation_lengths: list, output_dir: str):
    """Print one table per datasource and save all accuracies in long format."""
    print("\n按Datasource分组的详细结果:")
    print("=" * 80)

    all_datasources = set()
    for model_results in datasource_results.values():
        for truncation_results in model_results.values():
            all_datasources.update(truncation_results.keys())

    for datasource in sorted(all_datasources):
        print(f"\nDatasource: {datasource}")
        print("-" * 40)

        datasource_data = []
        for model_name, model_results in datasource_results.items():
            row = {'Model': model_name}
            for length in truncation_lengths:
                row[f'Accuracy_{length}'] = model_results.get(length, {}).get(datasource, np.nan)
            datasource_data.append(row)

        datasource_df = pd.DataFrame(datasource_data)
        print(datasource_df.to_string(index=False, float_format='%.4f'))

    # Long format: one row per (model, budget, datasource) that has a result.
    detailed_data = []
    for model_name, model_results in datasource_results.items():
        for length in truncation_lengths:
            if length not in model_results:
                continue
            for datasource, accuracy in model_results[length].items():
                detailed_data.append({
                    'Model': model_name,
                    'Truncation_Length': length,
                    'Datasource': datasource,
                    'Accuracy': accuracy
                })

    detailed_csv_path = os.path.join(output_dir, "datasource_breakdown_results_optimized.csv")
    pd.DataFrame(detailed_data).to_csv(detailed_csv_path, index=False)
    print(f"详细Datasource结果已保存到: {detailed_csv_path}")


def _save_length_results(length_results: dict, truncation_lengths: list, output_dir: str):
    """Print and save average post-truncation lengths, overall and per datasource."""
    print("\n长度统计结果:")
    print("=" * 80)

    length_data = []
    for model_name, model_lengths in length_results.items():
        row = {'Model': model_name}
        for length in truncation_lengths:
            # Average length after truncating to this budget; NaN if the run is missing.
            if length in model_lengths:
                row[f'Avg_Truncated_Length_{length}'] = model_lengths[length]['avg_truncated_length']
            else:
                row[f'Avg_Truncated_Length_{length}'] = np.nan
        length_data.append(row)

    length_df = pd.DataFrame(length_data)
    print(length_df.to_string(index=False, float_format='%.1f'))

    length_csv_path = os.path.join(output_dir, "length_statistics_truncated_results_optimized.csv")
    length_df.to_csv(length_csv_path, index=False)
    print(f"\n长度统计结果已保存到: {length_csv_path}")

    detailed_length_data = []
    for model_name, model_lengths in length_results.items():
        for length in truncation_lengths:
            if length not in model_lengths:
                continue
            for datasource, lengths in model_lengths[length].get('datasource_avg_lengths', {}).items():
                detailed_length_data.append({
                    'Model': model_name,
                    'Truncation_Length': length,
                    'Datasource': datasource,
                    'Avg_Truncated_Length': lengths['avg_truncated']
                })

    if detailed_length_data:
        detailed_length_csv_path = os.path.join(
            output_dir, "detailed_length_statistics_truncated_results_optimized.csv"
        )
        pd.DataFrame(detailed_length_data).to_csv(detailed_length_csv_path, index=False)
        print(f"详细长度统计结果已保存到: {detailed_length_csv_path}")


def save_results_to_csv_optimized(results: dict, truncation_lengths: list,
                                 datasource_results: dict = None, length_results: dict = None,
                                 output_dir: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec"):
    """Print result tables and write them as CSV files under ``output_dir``.

    Files written:

    * ``truncation_accuracy_results_optimized.csv``: one row per model, with a
      ``Macro_Avg_<length>`` column per budget.
    * ``datasource_breakdown_results_optimized.csv``, only if
      ``datasource_results``: columns Model, Truncation_Length, Datasource,
      Accuracy.
    * ``length_statistics_truncated_results_optimized.csv``, only if
      ``length_results``: average post-truncation length per model and budget.
    * ``detailed_length_statistics_truncated_results_optimized.csv``: the same
      per datasource, written when any such entry exists.

    Missing model/budget combinations are written as NaN (empty cell), not 0,
    so a missing run cannot be mistaken for 0% accuracy or a 0-token length.
    ``output_dir`` is created if needed.
    """
    os.makedirs(output_dir, exist_ok=True)

    _save_macro_results(results, truncation_lengths, output_dir)

    if datasource_results:
        _save_datasource_results(datasource_results, truncation_lengths, output_dir)

    if length_results:
        _save_length_results(length_results, truncation_lengths, output_dir)

if __name__ == "__main__":
    print("开始并行分析不同模型在不同Token截断长度下的准确率...")
    print("=" * 60)

    # Directory containing the per-model evaluation outputs (JSONL).
    results_dir = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec"

    # (file name, legend label) pairs. Keeping each path next to its label
    # means a file cannot silently lose its label through a typo in a second
    # list. List order sets the processing order and hence plot colours/markers.
    model_runs = [
        ("grpo-step430-valid-all_32768_test.jsonl", "GRPO (Step 430)"),
        ("gspo-step500-valid-all_32768_test.jsonl", "GSPO (Step 500)"),
        ("baseline-dapo-math-step350-valid-all_32768_test.jsonl", "GRPO High clip ratio (Step 350)"),
        ("grpo-skip-right-step320-valid-all_32768_test.jsonl", "GRPO + ours (Step 320)"),
        ("gspo-skip-right-step600-valid-all_32768_test.jsonl", "GSPO + ours (Step 600)"),
        ("gspo-stage2-dapo-math-add1k-max12k-step100-valid-all_32768_test.jsonl", "GSPO Stage 2 (Step 100)"),
        ("skip-right-limits10-dapo-math-step330-valid-all_32768_test.jsonl", "GRPO High clip ratio + ours (Step 330)"),
    ]
    # Disabled base-model run (lives in results_nov, not results_dec):
    # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_nov/Qwen3-4B-Base-valid-all_32768_test.jsonl" -> "Qwen3-4B-Base"

    files = [os.path.join(results_dir, name) for name, _ in model_runs]
    file_labels = {os.path.join(results_dir, name): label for name, label in model_runs}

    # Token budgets to evaluate.
    truncation_lengths = [4096, 8192, 12288, 16384, 24576, 32768]

    # Parallelism: fixed at 4 workers (min(mp.cpu_count(), len(files)) is the
    # automatic alternative). batch_size = records per verification batch.
    num_processes = 4
    batch_size = 64

    print(f"并行处理配置:")
    print(f"  进程数: {num_processes}")
    print(f"  批量大小: {batch_size}")
    print(f"  截断长度: {truncation_lengths}")
    print(f"  文件数: {len(files)}")

    try:
        start_time = time.time()
        results, datasource_results, length_results = analyze_models_truncation_accuracy_parallel(
            files, truncation_lengths, num_processes, batch_size, file_labels
        )
        end_time = time.time()

        print(f"\n并行处理完成！总耗时: {end_time - start_time:.2f} 秒")

        plot_truncation_accuracy_optimized(results, truncation_lengths, datasource_results, length_results)

        plot_specific_datasources_grid(datasource_results, truncation_lengths)
        plot_selected_macro_average(datasource_results, truncation_lengths)

        save_results_to_csv_optimized(results, truncation_lengths, datasource_results, length_results)

        print("\n分析完成！")

    except KeyboardInterrupt:
        print("\n用户中断了程序")
        # 130 = 128 + SIGINT, the conventional status for an interrupted run.
        sys.exit(130)
    except Exception as e:
        print(f"\n程序执行过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        # Non-zero status so callers/schedulers see the failure.
        sys.exit(1)
