import json
import os
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer
from math_verify import parse, verify
import pandas as pd
from tqdm import tqdm
import signal
import sys
import seaborn as sns

# Global matplotlib styling: a font fallback chain, ASCII minus signs on axes,
# and white figure/axes backgrounds.
plt.rcParams.update({
    'font.family': ['Arial', 'DejaVu Sans', 'sans-serif'],
    'axes.unicode_minus': False,
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
})

# Seaborn theme applied on top: white grid background and the "husl" palette.
sns.set_style("whitegrid")
sns.set_palette("husl")

def timeout(timeout_seconds: int = 10):
    """Build a decorator that aborts the wrapped call after ``timeout_seconds``.

    On POSIX systems the limit is enforced with ``SIGALRM``: for the duration
    of each call a handler raising ``TimeoutError("verify timed out!")`` is
    installed and ``signal.alarm(timeout_seconds)`` is armed. Afterwards the
    alarm is cancelled and the previous handler restored. ``signal.alarm``
    works in whole seconds, and a value of 0 schedules no alarm. On other
    platforms the decorator returns the function unchanged (no time limit).

    Notes:
        * ``signal.signal`` can only be called from the main thread.
        * Arming the alarm replaces any alarm that was already pending, and
          the ``finally`` block cancels it.
    """
    import functools

    if os.name != "posix":
        # No SIGALRM available: fall back to a no-op decorator.
        def passthrough(func):
            return func
        return passthrough

    def handler(signum, frame):
        raise TimeoutError("verify timed out!")

    def decorator(func):
        # wraps() keeps the wrapped function's __name__/__doc__ for debugging.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            old_handler = signal.getsignal(signal.SIGALRM)
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Cancel the alarm first so it cannot fire after the handler
                # has been swapped back.
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)
        return wrapper
    return decorator

@timeout(timeout_seconds=10)
def labeling_responses(responses: list[str], golden_answer: str):
    """Grade each response against the gold answer with math_verify.

    Args:
        responses: Model outputs to grade.
        golden_answer: Reference answer. It is wrapped in ``$...$`` before
            parsing. Non-string values (e.g. a number loaded from JSON) are
            converted with ``str`` first.

    Returns:
        One boolean per response (``verify(gold, parsed_response)``). If any
        step raises, the error is printed and every response is labeled
        False. This includes the ``TimeoutError`` raised by the ``timeout``
        decorator's signal handler while this body is running.
    """
    try:
        # The gold answer is the same for every response, so parse it once.
        # str() prevents a TypeError (and a silent all-False result) when the
        # answer is numeric.
        gold = parse("$" + str(golden_answer) + "$")
        return [verify(gold, parse(response)) for response in responses]
    except Exception as e:
        print(f"验证过程中出错: {e}")
        return [False] * len(responses)

def truncate_text_by_tokens(text: str, tokenizer, max_tokens: int) -> str:
    """Truncate ``text`` to at most ``max_tokens`` tokens.

    Args:
        text: Text to truncate.
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            and ``decode(ids, skip_special_tokens=True)``, e.g. a Hugging Face
            tokenizer.
        max_tokens: Token budget. Negative values are treated as 0.

    Returns:
        ``text`` itself when it already fits in the budget; otherwise the
        decoded first ``max_tokens`` tokens. If tokenization or decoding
        raises, the error is printed and ``text`` is returned unchanged.
    """
    try:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        # A negative limit would otherwise slice from the end (tokens[:-k]).
        limit = max(max_tokens, 0)
        if len(tokens) <= limit:
            # Within budget: return the original string. decode(encode(text))
            # is not guaranteed to reproduce it exactly (for instance,
            # skip_special_tokens drops special-token text).
            return text
        return tokenizer.decode(tokens[:limit], skip_special_tokens=True)
    except Exception as e:
        print(f"截断文本时出错: {e}")
        return text

def calculate_accuracy_for_truncation(file_path: str, tokenizer, max_tokens: int) -> dict:
    """Score one JSONL results file with every generation truncated to ``max_tokens``.

    Each line is parsed as JSON and read with ``.get`` for ``generated_text``,
    ``answer`` and ``data_source`` (default ``'unknown'``). Lines that are
    not valid JSON, or that have an empty generation or a missing/empty
    answer, are skipped. Lines whose processing raises are reported and
    skipped. Accuracy and length statistics cover exactly the same samples.

    Args:
        file_path: Path to the JSONL file.
        tokenizer: Tokenizer used both to measure and to truncate generations.
        max_tokens: Token budget applied to each generation before grading.

    Returns:
        ``{}`` if the file does not exist or cannot be read. Otherwise a dict:
            - ``macro_avg``: unweighted mean of per-datasource accuracies
              (0.0 if nothing was scored).
            - ``datasource_accuracies``: ``{datasource: accuracy}``.
            - ``datasource_stats``: ``{datasource: {'correct', 'total'}}``.
            - ``length_stats``: ``{'avg_original_length': float,
              'datasource_avg_lengths': {datasource: {'avg_original': float}}}``.
              Lengths are token counts of the *untruncated* generations.
    """
    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        return {}

    # Per-datasource correct/total counts.
    datasource_stats = {}
    # Untruncated generation lengths, overall and per datasource.
    length_stats = {
        'original_lengths': [],
        'datasource_lengths': {}
    }

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc=f"处理 {os.path.basename(file_path)} (max_tokens={max_tokens})"):
                try:
                    data = json.loads(line.strip())
                    generated_text = data.get('generated_text', '')
                    answer = data.get('answer', '')
                    datasource = data.get('data_source', 'unknown')

                    # `not answer` is avoided here: it would also drop a
                    # legitimate numeric gold answer of 0.
                    if not generated_text or answer is None or answer == '':
                        continue
                    # Gold answers may be numbers in the JSON; grading
                    # concatenates the answer into a string.
                    answer = str(answer)

                    # Length of the untruncated generation (independent of max_tokens).
                    original_length = len(tokenizer.encode(generated_text, add_special_tokens=False))
                    truncated_text = truncate_text_by_tokens(generated_text, tokenizer, max_tokens)
                    labels = labeling_responses([truncated_text], answer)

                    # Register the sample only after it has been fully scored.
                    # This way a datasource never appears with total == 0 (it
                    # would enter the macro average as 0.0), and the length
                    # stats cover exactly the counted samples.
                    stats = datasource_stats.setdefault(datasource, {'correct': 0, 'total': 0})
                    ds_lengths = length_stats['datasource_lengths'].setdefault(
                        datasource, {'original_lengths': []}
                    )
                    stats['total'] += 1
                    if labels and labels[0]:
                        stats['correct'] += 1
                    length_stats['original_lengths'].append(original_length)
                    ds_lengths['original_lengths'].append(original_length)

                except json.JSONDecodeError:
                    continue
                except Exception as e:
                    print(f"处理单行数据时出错: {e}")
                    continue

    except Exception as e:
        print(f"读取文件时出错: {e}")
        return {}

    # Per-datasource accuracy.
    datasource_accuracies = {}
    for datasource, stats in datasource_stats.items():
        accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0.0
        datasource_accuracies[datasource] = accuracy
        print(f"  {datasource}: 准确率 {accuracy:.4f} ({stats['correct']}/{stats['total']})")

    # Macro average: unweighted mean over datasources, so every datasource
    # weighs the same regardless of its sample count.
    if datasource_accuracies:
        macro_avg = sum(datasource_accuracies.values()) / len(datasource_accuracies)
        print(f"文件: {os.path.basename(file_path)}, 截断长度: {max_tokens}, Macro Average: {macro_avg:.4f}")
    else:
        macro_avg = 0.0
        print(f"文件: {os.path.basename(file_path)}, 截断长度: {max_tokens}, 无有效数据")

    # Overall mean untruncated length.
    avg_original_length = np.mean(length_stats['original_lengths']) if length_stats['original_lengths'] else 0
    print(f"  平均原始长度: {avg_original_length:.1f} tokens")

    # Per-datasource mean untruncated length.
    datasource_avg_lengths = {}
    for datasource, lengths in length_stats['datasource_lengths'].items():
        if lengths['original_lengths']:
            datasource_avg_lengths[datasource] = {
                'avg_original': np.mean(lengths['original_lengths'])
            }
            print(f"  {datasource} - 平均原始长度: {datasource_avg_lengths[datasource]['avg_original']:.1f}")

    return {
        'macro_avg': macro_avg,
        'datasource_accuracies': datasource_accuracies,
        'datasource_stats': datasource_stats,
        'length_stats': {
            'avg_original_length': avg_original_length,
            'datasource_avg_lengths': datasource_avg_lengths
        }
    }

def analyze_models_truncation_accuracy(files=None, truncation_lengths=None,
                                       tokenizer_path="/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base"):
    """Evaluate each results file at several truncation lengths, then save and plot.

    Args:
        files: JSONL result files to evaluate. Defaults to the hard-coded list
            below; edit it (or pass your own list) to change the compared
            models. Files that do not exist are skipped.
        truncation_lengths: Token budgets to evaluate. Defaults to
            ``[4096, 8192, 12288, 16384, 32768]``.
        tokenizer_path: Tokenizer used to count and truncate tokens. If it
            cannot be loaded, the ``"gpt2"`` tokenizer is used instead.

    Returns:
        ``(results, datasource_results, length_results)``. Each is keyed by
        model name (file basename with ``'.jsonl'`` removed) and then by
        truncation length, holding respectively the macro-average accuracy,
        the ``{datasource: accuracy}`` dict, and the length statistics.
    """
    if files is None:
        # Every entry (commented or not) ends with a comma, so uncommenting a
        # line can never silently concatenate two adjacent path literals.
        files = [
            # Earlier experiments, kept for reference:
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/Qwen3-8B-baseline_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/grpo-qwen3-8b-deepscaler-BASELINE_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/grpo-qwen3-8b-deepscaler-ori-length-l1-wolength-prompt_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/grpo-qwen3-8b-deepscaler-added1k-l1-truncated-wolength-prompt_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/grpo-qwen3-8b-deepscaler-added1k-l1-truncated-wolength-prompt-filtered7192_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polaris-1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-polaris-1k-temp0.6_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-16k-analysis-step100-polaris-1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-17k-temp1.4-step100-polaris-1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-valid-temp0.6_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-valid_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-16k-analysis-step100-valid_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-17k-temp1.4-step100-valid_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polaris-1k-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-polaris-1k-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-16k-analysis-step100-polaris-1k-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-17k-temp1.4-step100-polaris-1k-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-valid-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-valid-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-16k-analysis-step100-valid-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-polrias30k-17k-temp1.4-step100-valid-temp1.0_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-grpo-16k-60step-polaris1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-grpo-8k-70step-polaris1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-grpo-16k-60step-valid_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-grpo-8k-70step-valid_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/Qwen3-4B-Base-deepscaler1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-40step-deepscaler1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-overlong-filter-70step-deepscaler1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-overlong-filter-150step-deepscaler1k_32768_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/grpo-qwen3-4b-deepscaler-added1k-l1-truncated-wolength-prompt-filtered7192_32000_test.jsonl",
            # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/Qwen3-8B-baseline-all_32000_test.jsonl",

            # Current comparison:
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/Qwen3-4B-Base_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/DAPO-Qwen3-4B-Base-deepscaler-40k-BASELINE-8k-step330-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/DAPO-Qwen3-4B-Base-deepscaler-40k-BASELINE-16k-step290-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/DAPO-Qwen3-4B-Base-deepscaler-40k-ori-length-step150-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/DAPO-Qwen3-4B-Base-deepscaler-40k-add1k-step200-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/DAPO-Qwen3-4B-Base-deepscaler-40k-8k-step150-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/DAPO-Qwen3-4B-Base-deepscaler-40k-8k-max8k-step230-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-40step-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-overlong-filter-70step-valid_32768_test.jsonl",
            "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-overlong-filter-150step-valid_32768_test.jsonl",
        ]

    if truncation_lengths is None:
        truncation_lengths = [4096, 8192, 12288, 16384, 32768]
    truncation_lengths = list(truncation_lengths)

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    except Exception as e:
        # Narrowed from a bare `except:` so Ctrl-C is not swallowed here.
        # NOTE: the GPT-2 fallback counts tokens differently, so its
        # truncation budgets are not comparable with the Qwen tokenizer's.
        print("无法加载Qwen tokenizer，使用默认tokenizer")
        print(f"  ({e})")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    results = {}             # model -> {max_tokens: macro average accuracy}
    datasource_results = {}  # model -> {max_tokens: {datasource: accuracy}}
    length_results = {}      # model -> {max_tokens: length statistics}

    for file_path in files:
        if not os.path.exists(file_path):
            print(f"跳过不存在的文件: {file_path}")
            continue

        model_name = os.path.basename(file_path).replace('.jsonl', '')
        results[model_name] = {}
        datasource_results[model_name] = {}
        length_results[model_name] = {}

        print(f"\n分析模型: {model_name}")
        print("=" * 50)

        for max_tokens in truncation_lengths:
            result = calculate_accuracy_for_truncation(file_path, tokenizer, max_tokens)
            if result:
                results[model_name][max_tokens] = result['macro_avg']
                datasource_results[model_name][max_tokens] = result['datasource_accuracies']
                length_results[model_name][max_tokens] = result['length_stats']

    # Save the numbers before plotting, so a plotting failure cannot discard
    # the (expensive) evaluation results.
    save_results_to_csv(results, truncation_lengths, datasource_results, length_results)
    plot_truncation_accuracy(results, truncation_lengths, datasource_results, length_results)

    return results, datasource_results, length_results

def plot_truncation_accuracy(results: dict, truncation_lengths: list, datasource_results: dict = None,
                             length_results: dict = None,
                             output_path: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/truncation_accuracy_plot_beautified.png"):
    """Plot macro-average accuracy against truncation length, one line per model.

    Args:
        results: ``{model_name: {truncation_length: macro_avg_accuracy}}``.
            Models with an empty dict are skipped. A truncation length
            missing for a model is drawn as a gap (NaN), not as 0.
        truncation_lengths: x-axis positions, in plotting order.
        datasource_results: If truthy, ``plot_datasource_breakdown`` is also
            called with it.
        length_results: If truthy, ``plot_length_statistics`` is also called
            with it.
        output_path: PNG destination. Its directory is created if needed.
    """
    fig, ax = plt.subplots(figsize=(14, 10))

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E', '#7209B7', '#F77F00', '#FCBF49']
    line_styles = ['-', '--', '-.', ':']
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']

    for i, (model_name, accuracies) in enumerate(results.items()):
        if not accuracies:
            continue
        # NaN leaves a gap instead of drawing a fake 0.0 accuracy point.
        acc_values = [accuracies.get(length, np.nan) for length in truncation_lengths]
        color = colors[i % len(colors)]
        ax.plot(truncation_lengths, acc_values,
                marker=markers[i % len(markers)],
                linewidth=3,
                markersize=10,
                color=color,
                linestyle=line_styles[i % len(line_styles)],
                label=model_name,
                markerfacecolor='white',
                markeredgewidth=2,
                markeredgecolor=color,
                alpha=0.9)

    # Axes, titles and grid.
    ax.set_xlabel('Token Truncation Length', fontsize=16, fontweight='bold', labelpad=15)
    ax.set_ylabel('Macro Average Accuracy', fontsize=16, fontweight='bold', labelpad=15)
    ax.set_title('Model Macro Average Accuracy Comparison at Different Token Truncation Lengths',
                 fontsize=18, fontweight='bold', pad=25)
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.set_axisbelow(True)

    ax.set_xlim(min(truncation_lengths) - 500, max(truncation_lengths) + 1000)
    ax.set_ylim(0, 1.05)
    ax.set_xticks(truncation_lengths)
    ax.set_xticklabels([f'{x:,}' for x in truncation_lengths], fontsize=12)
    y_ticks = np.arange(0, 1.1, 0.2)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{y:.1f}' for y in y_ticks], fontsize=12)

    # Legend placed outside the axes, to the right.
    legend = ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left',
                       fontsize=12, frameon=True, fancybox=True, shadow=True,
                       title='Models', title_fontsize=14)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)
    legend.get_title().set_fontweight('bold')

    # Value labels. A label is skipped when one was already placed within
    # 1000 tokens horizontally and 0.05 accuracy vertically, to limit overlap.
    added_labels = set()
    for accuracies in results.values():
        if not accuracies:
            continue
        for length in truncation_lengths:
            if length not in accuracies:
                continue
            accuracy = accuracies[length]
            overlaps = any(abs(length - x) < 1000 and abs(accuracy - y) < 0.05
                           for x, y in added_labels)
            if overlaps:
                continue
            # Put labels below high points so they stay inside the axes.
            offset_y = -15 if accuracy > 0.8 else 15
            ax.annotate(f'{accuracy:.3f}',
                        (length, accuracy),
                        textcoords="offset points",
                        xytext=(0, offset_y),
                        ha='center',
                        fontsize=10,
                        fontweight='bold',
                        bbox=dict(boxstyle="round,pad=0.3",
                                  facecolor='white',
                                  edgecolor='gray',
                                  alpha=0.8),
                        zorder=10)
            added_labels.add((length, accuracy))

    # Reference lines at 0.5 and 0.7 accuracy.
    ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, linewidth=1)
    ax.axhline(y=0.7, color='gray', linestyle=':', alpha=0.5, linewidth=1)

    fig.tight_layout()

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    print(f"\n美化折线图已保存到: {output_path}")

    plt.show()
    # Release the figure; under non-interactive backends show() is a no-op and
    # figures would otherwise stay open.
    plt.close(fig)

    if datasource_results:
        plot_datasource_breakdown(datasource_results, truncation_lengths)

    if length_results:
        plot_length_statistics(length_results, truncation_lengths)

def plot_datasource_breakdown(datasource_results: dict, truncation_lengths: list,
                              output_path: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/datasource_breakdown_plot.png"):
    """Draw one subplot per datasource: accuracy vs. truncation length for each model.

    Args:
        datasource_results: ``{model_name: {truncation_length: {datasource: accuracy}}}``.
        truncation_lengths: x-axis positions, in plotting order.
        output_path: PNG destination. Its directory is created if needed.

    A (model, length, datasource) combination with no value is drawn as a gap
    (NaN) rather than as 0.
    """
    all_datasources = {
        datasource
        for model_results in datasource_results.values()
        for truncation_results in model_results.values()
        for datasource in truncation_results
    }

    if not all_datasources:
        print("没有找到datasource数据")
        return

    n_datasources = len(all_datasources)
    # squeeze=False always yields a 2-D array of Axes, so flatten() works for
    # any count. Previously, n == 1 wrapped a 1-D array in a list and
    # axes[0].plot raised AttributeError.
    fig, axes = plt.subplots(2, (n_datasources + 1) // 2, figsize=(16, 12), squeeze=False)
    axes = axes.flatten()

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E', '#7209B7', '#F77F00', '#FCBF49']
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']

    for ax, datasource in zip(axes, sorted(all_datasources)):
        for j, (model_name, model_results) in enumerate(datasource_results.items()):
            accuracies = [model_results.get(length, {}).get(datasource, np.nan)
                          for length in truncation_lengths]
            ax.plot(truncation_lengths, accuracies,
                    marker=markers[j % len(markers)],
                    linewidth=2,
                    markersize=8,
                    color=colors[j % len(colors)],
                    label=model_name,
                    alpha=0.8)

        ax.set_title(f'Datasource: {datasource}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Token Truncation Length', fontsize=12)
        ax.set_ylabel('Accuracy', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=10)
        ax.set_ylim(0, 1.05)

    # Hide the unused subplot when the datasource count is odd.
    for ax in axes[n_datasources:]:
        ax.set_visible(False)

    fig.suptitle('Accuracy by Datasource and Truncation Length', fontsize=16, fontweight='bold')
    fig.tight_layout()

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    print(f"\nDatasource分组图表已保存到: {output_path}")

    plt.show()
    plt.close(fig)

def plot_length_statistics(length_results: dict, truncation_lengths: list,
                           output_path: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/length_statistics_plot.png"):
    """Plot each model's average original (untruncated) length per truncation setting.

    Args:
        length_results: ``{model_name: {truncation_length: {'avg_original_length': float, ...}}}``.
            A missing truncation length is drawn as a gap (NaN), not as 0.
        truncation_lengths: x-axis positions, in plotting order.
        output_path: PNG destination. Its directory is created if needed.

    NOTE(review): if ``avg_original_length`` is measured before truncation,
    each line is expected to be flat. Confirm against the producer of
    ``length_results``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E', '#7209B7', '#F77F00', '#FCBF49']
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']

    for i, (model_name, model_lengths) in enumerate(length_results.items()):
        original_lengths = [
            model_lengths[length]['avg_original_length'] if length in model_lengths else np.nan
            for length in truncation_lengths
        ]
        ax.plot(truncation_lengths, original_lengths,
                marker=markers[i % len(markers)],
                linewidth=3,
                markersize=10,
                color=colors[i % len(colors)],
                label=model_name,
                alpha=0.9)

    ax.set_xlabel('Token Truncation Length', fontsize=14, fontweight='bold')
    ax.set_ylabel('Average Original Length (tokens)', fontsize=14, fontweight='bold')
    ax.set_title('Average Original Text Length by Truncation Limit', fontsize=16, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)
    ax.set_xticks(truncation_lengths)
    ax.set_xticklabels([f'{x:,}' for x in truncation_lengths], fontsize=10)

    fig.tight_layout()

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    print(f"\n长度统计图表已保存到: {output_path}")

    plt.show()
    plt.close(fig)

def save_results_to_csv(results: dict, truncation_lengths: list, datasource_results: dict = None,
                        length_results: dict = None,
                        output_dir: str = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep"):
    """Print result tables and write them as CSV files into ``output_dir``.

    Files written (``output_dir`` is created if needed):
        * ``truncation_accuracy_results.csv``: one row per model with a
          ``Macro_Avg_<length>`` column per truncation length.
        * ``datasource_breakdown_results.csv`` (if ``datasource_results``):
          long format with columns Model / Truncation_Length / Datasource / Accuracy.
        * ``length_statistics_results.csv`` (if ``length_results``): one row
          per model with an ``Avg_Original_Length_<length>`` column per length.
        * ``detailed_length_statistics_results.csv`` (if per-datasource lengths
          exist): Model / Truncation_Length / Datasource / Avg_Original_Length.

    Combinations that were not evaluated are stored as NaN (an empty CSV
    cell), so they cannot be mistaken for a measured 0.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Macro-average table.
    macro_data = []
    for model_name, accuracies in results.items():
        row = {'Model': model_name}
        for length in truncation_lengths:
            row[f'Macro_Avg_{length}'] = accuracies.get(length, np.nan)
        macro_data.append(row)
    macro_df = pd.DataFrame(macro_data)

    print("\nMacro Average准确率结果表格:")
    print("=" * 80)
    print(macro_df.to_string(index=False, float_format='%.4f'))

    # Per-datasource tables (printed only).
    if datasource_results:
        print("\n按Datasource分组的详细结果:")
        print("=" * 80)

        all_datasources = {
            datasource
            for model_results in datasource_results.values()
            for truncation_results in model_results.values()
            for datasource in truncation_results
        }

        for datasource in sorted(all_datasources):
            print(f"\nDatasource: {datasource}")
            print("-" * 40)

            datasource_data = []
            for model_name, model_results in datasource_results.items():
                row = {'Model': model_name}
                for length in truncation_lengths:
                    row[f'Accuracy_{length}'] = model_results.get(length, {}).get(datasource, np.nan)
                datasource_data.append(row)

            print(pd.DataFrame(datasource_data).to_string(index=False, float_format='%.4f'))

    csv_path = os.path.join(output_dir, "truncation_accuracy_results.csv")
    macro_df.to_csv(csv_path, index=False)
    print(f"\nMacro Average结果已保存到: {csv_path}")

    if datasource_results:
        detailed_csv_path = os.path.join(output_dir, "datasource_breakdown_results.csv")
        detailed_data = [
            {
                'Model': model_name,
                'Truncation_Length': length,
                'Datasource': datasource,
                'Accuracy': accuracy,
            }
            for model_name, model_results in datasource_results.items()
            for length in truncation_lengths
            if length in model_results
            for datasource, accuracy in model_results[length].items()
        ]
        pd.DataFrame(detailed_data).to_csv(detailed_csv_path, index=False)
        print(f"详细Datasource结果已保存到: {detailed_csv_path}")

    if length_results:
        print("\n长度统计结果:")
        print("=" * 80)

        length_data = []
        for model_name, model_lengths in length_results.items():
            row = {'Model': model_name}
            for length in truncation_lengths:
                row[f'Avg_Original_Length_{length}'] = (
                    model_lengths.get(length, {}).get('avg_original_length', np.nan)
                )
            length_data.append(row)

        length_df = pd.DataFrame(length_data)
        print(length_df.to_string(index=False, float_format='%.1f'))

        length_csv_path = os.path.join(output_dir, "length_statistics_results.csv")
        length_df.to_csv(length_csv_path, index=False)
        print(f"\n长度统计结果已保存到: {length_csv_path}")

        # Per-datasource lengths. Truncation_Length is included so the rows of
        # different truncation settings are distinguishable; without it they
        # were indistinguishable duplicates.
        detailed_length_data = [
            {
                'Model': model_name,
                'Truncation_Length': length,
                'Datasource': datasource,
                'Avg_Original_Length': lengths['avg_original'],
            }
            for model_name, model_lengths in length_results.items()
            for length in truncation_lengths
            if length in model_lengths and 'datasource_avg_lengths' in model_lengths[length]
            for datasource, lengths in model_lengths[length]['datasource_avg_lengths'].items()
        ]

        if detailed_length_data:
            detailed_length_df = pd.DataFrame(detailed_length_data)
            detailed_length_csv_path = os.path.join(output_dir, "detailed_length_statistics_results.csv")
            detailed_length_df.to_csv(detailed_length_csv_path, index=False)
            print(f"详细长度统计结果已保存到: {detailed_length_csv_path}")

if __name__ == "__main__":
    print("开始分析不同模型在不同Token截断长度下的准确率...")
    print("=" * 60)

    try:
        analyze_models_truncation_accuracy()
    except KeyboardInterrupt:
        print("\n用户中断了程序")
        # 128 + SIGINT (2): the conventional exit status for an interrupted run.
        sys.exit(130)
    except Exception as e:
        print(f"\n程序执行过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        # Report failure to the shell / job scheduler instead of exiting with 0.
        sys.exit(1)
    else:
        print("\n分析完成！")