import json
from transformers import AutoTokenizer
import statistics
import numpy as np
from collections import defaultdict
import os
import math

def _print_length_stats(title, lengths):
    """Print mean/median/min/max/std-dev for a non-empty list of token lengths."""
    # statistics.stdev raises StatisticsError for fewer than two values, so a
    # single-sample group reports a spread of 0 instead of aborting the report.
    spread = statistics.stdev(lengths) if len(lengths) > 1 else 0.0
    print(f"{title}:")
    print(f"  Mean length: {statistics.mean(lengths):.2f} tokens")
    print(f"  Median length: {statistics.median(lengths):.2f} tokens")
    print(f"  Min length: {min(lengths)} tokens")
    print(f"  Max length: {max(lengths)} tokens")
    print(f"  Std deviation: {spread:.2f} tokens")
    print()


def analyze_generation_lengths(jsonl_file, model_name="microsoft/DialoGPT-medium", analyze_by_source=False):
    """
    Analyze generation lengths of correct and incorrect cases in a JSONL file.

    Each non-blank line is parsed as a JSON object; its 'generated_text' is
    tokenized and the token count is bucketed by the 'correctness' field.
    A human-readable report is printed to stdout.

    Args:
        jsonl_file: Path to the JSONL file.
        model_name: Name of the model whose tokenizer is used for counting.
        analyze_by_source: Whether to additionally group by 'data_source'.

    Returns:
        dict with keys:
            'correct_lengths': token lengths of records with correctness True
            'incorrect_lengths': token lengths of records with correctness False
            'all_data': (token_length, correctness) for every record whose
                correctness is not None
            'source_data': per-source statistics, or None when
                analyze_by_source is False
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    correct_lengths = []
    incorrect_lengths = []
    all_data = []  # every labelled data point, used for accuracy analysis

    # {source: {'lengths', 'correct', 'incorrect', 'total_correct', 'total_samples'}}
    source_data = {}

    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            try:
                data = json.loads(line)

                generated_text = data.get('generated_text', '')
                correctness = data.get('correctness', None)
                data_source = data.get('data_source', None)

                token_length = len(tokenizer.encode(generated_text))

                if correctness is not None:
                    all_data.append((token_length, correctness))

                # Only the literal booleans are classified; anything else is
                # reported so malformed labels do not go unnoticed.
                if correctness is True:
                    correct_lengths.append(token_length)
                elif correctness is False:
                    incorrect_lengths.append(token_length)
                else:
                    print(f"Warning: Line {line_num} has undefined correctness: {correctness}")

                if analyze_by_source and data_source is not None and correctness is not None:
                    bucket = source_data.setdefault(data_source, {
                        'lengths': [],
                        'correct': [],
                        'incorrect': [],
                        'total_correct': 0,
                        'total_samples': 0,
                    })
                    bucket['lengths'].append(token_length)
                    bucket['total_samples'] += 1

                    if correctness is True:
                        bucket['correct'].append(token_length)
                        bucket['total_correct'] += 1
                    elif correctness is False:
                        bucket['incorrect'].append(token_length)

            except json.JSONDecodeError:
                print(f"Error: Invalid JSON on line {line_num}")
            except Exception as e:
                # Best effort: one bad record must not abort the whole file.
                print(f"Error processing line {line_num}: {e}")

    # Overall report
    print("=" * 50)
    print("Generation Length Analysis")
    print("=" * 50)

    print(f"Total samples: {len(correct_lengths) + len(incorrect_lengths)}")
    print(f"Correct samples: {len(correct_lengths)}")
    print(f"Incorrect samples: {len(incorrect_lengths)}")
    print()

    if correct_lengths:
        _print_length_stats("Correct Cases", correct_lengths)

    if incorrect_lengths:
        _print_length_stats("Incorrect Cases", incorrect_lengths)

    # Comparison of the two groups
    if correct_lengths and incorrect_lengths:
        mean_diff = statistics.mean(incorrect_lengths) - statistics.mean(correct_lengths)
        print("Comparison:")
        print(f"  Average length difference (Incorrect - Correct): {mean_diff:.2f} tokens")

        if mean_diff > 0:
            print("  → Incorrect cases tend to be longer")
        elif mean_diff < 0:
            print("  → Correct cases tend to be longer")
        else:
            print("  → Similar average lengths")
        print()

    # Per-source report
    if analyze_by_source and source_data:
        def accuracy_fraction(stats):
            if stats['total_samples'] > 0:
                return stats['total_correct'] / stats['total_samples']
            return 0

        def average_length(stats):
            return statistics.mean(stats['lengths']) if stats['lengths'] else 0

        print("=" * 50)
        print("Analysis by Data Source")
        print("=" * 50)

        for source, stats in sorted(source_data.items()):
            print(f"Data Source: {source}")
            print(f"  Total samples: {stats['total_samples']}")
            print(f"  Correct samples: {stats['total_correct']}")
            print(f"  Accuracy: {accuracy_fraction(stats) * 100:.2f}%")
            print(f"  Average length: {average_length(stats):.2f} tokens")

            if stats['correct']:
                print(f"  Correct cases avg length: {statistics.mean(stats['correct']):.2f} tokens")
            if stats['incorrect']:
                print(f"  Incorrect cases avg length: {statistics.mean(stats['incorrect']):.2f} tokens")

            if stats['correct'] and stats['incorrect']:
                source_diff = statistics.mean(stats['incorrect']) - statistics.mean(stats['correct'])
                print(f"  Length difference (Incorrect - Correct): {source_diff:.2f} tokens")

            print()

        # Ranking, best accuracy first
        print("Sources ranked by accuracy:")
        ranked = sorted(source_data.items(), key=lambda kv: accuracy_fraction(kv[1]), reverse=True)

        for rank, (source, stats) in enumerate(ranked, 1):
            print(f"  {rank}. {source}: {accuracy_fraction(stats) * 100:.2f}% accuracy, "
                  f"{average_length(stats):.2f} avg tokens ({stats['total_samples']} samples)")

    return {
        'correct_lengths': correct_lengths,
        'incorrect_lengths': incorrect_lengths,
        'all_data': all_data,
        'source_data': source_data if analyze_by_source else None
    }


def calculate_accuracy_by_length(all_data, bin_size=10):
    """
    Compute accuracy per fixed-width token-length bin.

    Bins start at the smallest observed length; each bin is identified by
    ``bin_start + bin_size // 2``.

    Args:
        all_data: Sequence of (token_length, correctness) pairs.
        bin_size: Width of each length bin.

    Returns:
        Tuple of three parallel lists, ordered by bin centre:
            bin centres, accuracy per bin, sample count per bin.
        Three empty lists when all_data is empty.
    """
    if not all_data:
        return [], [], []

    smallest = min(item[0] for item in all_data)

    # bin centre -> correctness values that fell into that bin
    buckets = {}
    for token_length, correctness in all_data:
        which_bin = (token_length - smallest) // bin_size
        centre = smallest + which_bin * bin_size + bin_size // 2
        buckets.setdefault(centre, []).append(correctness)

    ordered_centres = sorted(buckets)
    counts = [len(buckets[centre]) for centre in ordered_centres]
    accuracies = [
        sum(buckets[centre]) / size
        for centre, size in zip(ordered_centres, counts)
    ]

    return ordered_centres, accuracies, counts

def plot_length_distribution_subplots(results_list, file_names, save_name="comparison"):
    """
    Draw one length-distribution histogram per file in a grid of subplots.

    The figure is written to
    ``plots/{save_name}_length_distribution_subplots.png`` (the directory is
    created if needed) and then shown.

    Args:
        results_list: List of result dicts; each must provide
            'correct_lengths' and 'incorrect_lengths'.
        file_names: Display names, parallel to results_list.
        save_name: Prefix of the saved image file.

    Returns:
        None. Nothing is drawn when results_list is empty or matplotlib
        cannot be imported.
    """
    n_files = len(results_list)
    if n_files == 0:
        # A subplot grid needs at least one column.
        print("No results to plot.")
        return

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available. Skipping plots.")
        return

    # Choose the grid: one row up to 3 files, then fixed 2x3 / 3x3 grids,
    # then rows of at most 4 columns.
    if n_files <= 3:
        rows, cols = 1, n_files
        figsize = (6 * n_files, 6)
    elif n_files <= 6:
        rows, cols = 2, 3
        figsize = (18, 12)
    elif n_files <= 9:
        rows, cols = 3, 3
        figsize = (18, 18)
    else:
        cols = 4
        rows = math.ceil(n_files / cols)
        figsize = (6 * cols, 6 * rows)

    # squeeze=False always yields a 2-D array, so a single flatten() covers
    # the one-axes, one-row and multi-row cases alike.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for ax, results, file_name in zip(axes, results_list, file_names):
        correct_lengths = results['correct_lengths']
        incorrect_lengths = results['incorrect_lengths']

        ax.hist(correct_lengths, bins=30, alpha=0.6, color='green',
                label=f'Correct (n={len(correct_lengths)})')
        ax.hist(incorrect_lengths, bins=30, alpha=0.6, color='red',
                label=f'Incorrect (n={len(incorrect_lengths)})')

        ax.set_xlabel('Token Length')
        ax.set_ylabel('Frequency')
        ax.set_title(f'{file_name}\nLength Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide grid cells that have no file assigned.
    for ax in axes[n_files:]:
        ax.set_visible(False)

    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)

    plt.tight_layout()
    plt.savefig(f'plots/{save_name}_length_distribution_subplots.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_accuracy_vs_length_combined(results_list, file_names, save_name="comparison"):
    """
    Plot three stacked subplots of cumulative statistics versus token length:
    cumulative accuracy, cumulative correct count and cumulative error count.

    For every distinct token length L in a file, the statistics cover all
    samples whose length is <= L. One line per file is drawn on each subplot.
    The figure is written to
    ``plots/{save_name}_cumulative_analysis_combined.png`` and then shown.
    Plotting is best effort: failures are printed, not raised.

    Args:
        results_list: List of result dicts; each must provide 'all_data' as
            a list of (token_length, correctness) pairs.
        file_names: Display names, parallel to results_list.
        save_name: Prefix of the saved image file.

    Returns:
        None.
    """
    if not results_list:
        print("No results to plot.")
        return

    try:
        import matplotlib.pyplot as plt

        n_files = len(results_list)

        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 16))

        # Qualitative palettes up to 20 files, a continuous map beyond that.
        if n_files > 20:
            colors = plt.cm.viridis(np.linspace(0, 1, n_files))
        elif n_files > 10:
            colors = plt.cm.tab20(np.linspace(0, 1, n_files))
        else:
            colors = plt.cm.tab10(np.linspace(0, 1, n_files))

        # Marker styles are reused cyclically when there are more files than styles.
        marker_styles = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h', '+', 'x', '|', '_']
        markers = [marker_styles[i % len(marker_styles)] for i in range(n_files)]

        for i, (results, file_name) in enumerate(zip(results_list, file_names)):
            all_data = results['all_data']

            if not all_data:
                print(f"No data available for {file_name}")
                continue

            # One pass over the length-sorted samples keeps running totals,
            # instead of rescanning every sample for every distinct length.
            sorted_data = sorted(all_data, key=lambda x: x[0])
            last_index = len(sorted_data) - 1

            unique_lengths = []
            cumulative_accuracies = []
            cumulative_correct_counts = []
            cumulative_error_counts = []

            seen = 0
            correct_so_far = 0
            for idx, (token_length, correctness) in enumerate(sorted_data):
                seen += 1
                correct_so_far += correctness
                # Emit a point once the last sample of this length is counted.
                if idx == last_index or sorted_data[idx + 1][0] != token_length:
                    unique_lengths.append(token_length)
                    cumulative_accuracies.append(correct_so_far / seen)
                    cumulative_correct_counts.append(correct_so_far)
                    cumulative_error_counts.append(seen - correct_so_far)

            line_style = dict(color=colors[i], linewidth=2, marker=markers[i],
                              markersize=4, label=file_name, alpha=0.8)

            ax1.plot(unique_lengths, cumulative_accuracies, **line_style)
            ax2.plot(unique_lengths, cumulative_correct_counts, **line_style)
            ax3.plot(unique_lengths, cumulative_error_counts, **line_style)

            # The last cumulative point covers every sample of the file.
            print(f"{file_name} - Overall accuracy: {cumulative_accuracies[-1]:.3f}, "
                  f"Total correct: {cumulative_correct_counts[-1]}, "
                  f"Total errors: {cumulative_error_counts[-1]}")

        # Subplot 1: cumulative accuracy
        ax1.set_xlabel('Token Length Threshold')
        ax1.set_ylabel('Cumulative Accuracy (≤ token length)')
        ax1.set_title('Cumulative Accuracy vs Token Length Threshold - Comparison')
        ax1.grid(True, alpha=0.3)
        ax1.set_ylim(0, 1)

        # Subplot 2: cumulative correct count
        ax2.set_xlabel('Token Length Threshold')
        ax2.set_ylabel('Cumulative Correct Count (≤ token length)')
        ax2.set_title('Cumulative Correct Count vs Token Length Threshold - Comparison')
        ax2.grid(True, alpha=0.3)

        # Subplot 3: cumulative error count
        ax3.set_xlabel('Token Length Threshold')
        ax3.set_ylabel('Cumulative Error Count (≤ token length)')
        ax3.set_title('Cumulative Error Count vs Token Length Threshold - Comparison')
        ax3.grid(True, alpha=0.3)

        # With many files the legend is moved outside the axes.
        for ax in (ax1, ax2, ax3):
            if n_files > 10:
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                ax.legend()

        # savefig does not create missing directories.
        os.makedirs('plots', exist_ok=True)

        plt.tight_layout()
        plt.savefig(f'plots/{save_name}_cumulative_analysis_combined.png', dpi=300, bbox_inches='tight')
        plt.show()

    except ImportError:
        print("matplotlib not available. Skipping accuracy vs length plot.")
    except Exception as e:
        print(f"Error creating accuracy vs length plot: {e}")


def plot_correct_error_counts_combined(results_list, file_names, save_name="comparison"):
    """
    Plot cumulative correct and cumulative error counts versus token length.

    For every distinct token length L in a file, the counts cover all samples
    whose length is <= L. One line per file is drawn on each of two subplots.
    The figure is written to
    ``plots/{save_name}_correct_error_counts_combined.png`` and then shown.
    Plotting is best effort: failures are printed, not raised.

    Args:
        results_list: List of result dicts; each must provide 'all_data' as
            a list of (token_length, correctness) pairs.
        file_names: Display names, parallel to results_list.
        save_name: Prefix of the saved image file.

    Returns:
        None.
    """
    if not results_list:
        print("No results to plot.")
        return

    try:
        import matplotlib.pyplot as plt

        n_files = len(results_list)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))

        # Qualitative palettes up to 20 files, a continuous map beyond that.
        if n_files > 20:
            colors = plt.cm.viridis(np.linspace(0, 1, n_files))
        elif n_files > 10:
            colors = plt.cm.tab20(np.linspace(0, 1, n_files))
        else:
            colors = plt.cm.tab10(np.linspace(0, 1, n_files))

        # Marker styles are reused cyclically when there are more files than styles.
        marker_styles = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h', '+', 'x', '|', '_']
        markers = [marker_styles[i % len(marker_styles)] for i in range(n_files)]

        for i, (results, file_name) in enumerate(zip(results_list, file_names)):
            all_data = results['all_data']

            if not all_data:
                print(f"No data available for {file_name}")
                continue

            # Running totals over the length-sorted samples; a point is
            # emitted after the last sample of each distinct length.
            ordered = sorted(all_data, key=lambda x: x[0])
            last_index = len(ordered) - 1

            unique_lengths = []
            cumulative_correct_counts = []
            cumulative_error_counts = []

            seen = 0
            correct_so_far = 0
            for idx, (token_length, correctness) in enumerate(ordered):
                seen += 1
                correct_so_far += correctness
                if idx == last_index or ordered[idx + 1][0] != token_length:
                    unique_lengths.append(token_length)
                    cumulative_correct_counts.append(correct_so_far)
                    cumulative_error_counts.append(seen - correct_so_far)

            line_style = dict(color=colors[i], linewidth=2, marker=markers[i],
                              markersize=4, label=file_name, alpha=0.8)

            ax1.plot(unique_lengths, cumulative_correct_counts, **line_style)
            ax2.plot(unique_lengths, cumulative_error_counts, **line_style)

        # Subplot 1: cumulative correct count
        ax1.set_xlabel('Token Length Threshold')
        ax1.set_ylabel('Cumulative Correct Count (≤ token length)')
        ax1.set_title('Cumulative Correct Count vs Token Length Threshold - Comparison')
        ax1.grid(True, alpha=0.3)

        # Subplot 2: cumulative error count
        ax2.set_xlabel('Token Length Threshold')
        ax2.set_ylabel('Cumulative Error Count (≤ token length)')
        ax2.set_title('Cumulative Error Count vs Token Length Threshold - Comparison')
        ax2.grid(True, alpha=0.3)

        # With many files the legend is moved outside the axes.
        for ax in (ax1, ax2):
            if n_files > 10:
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                ax.legend()

        # savefig does not create missing directories.
        os.makedirs('plots', exist_ok=True)

        plt.tight_layout()
        plt.savefig(f'plots/{save_name}_correct_error_counts_combined.png', dpi=300, bbox_inches='tight')
        plt.show()

    except ImportError:
        print("matplotlib not available. Skipping correct/error counts plot.")
    except Exception as e:
        print(f"Error creating correct/error counts plot: {e}")

def plot_accuracy_by_exponential_intervals(results_list, file_names, base=2, save_name="comparison",
                                           first_boundary=256):
    """
    Plot per-interval accuracy and sample count over exponentially growing
    token-length intervals.

    The intervals are [0, first_boundary), [first_boundary, first_boundary*base),
    [first_boundary*base, first_boundary*base**2), ... followed by a closing
    interval that runs from the last boundary through the largest observed
    length, so every sample belongs to exactly one interval.
    With the defaults this gives [0-256), [256-512), [512-1024), ...

    The figure is written to
    ``plots/{save_name}_accuracy_exponential_intervals.png`` and then shown.
    Plotting is best effort: failures are printed, not raised.

    Args:
        results_list: List of result dicts; each must provide 'all_data' as
            a list of (token_length, correctness) pairs.
        file_names: Display names, parallel to results_list.
        base: Growth factor between consecutive boundaries; must be > 1.
        save_name: Prefix of the saved image file.
        first_boundary: Upper bound of the first interval; must be > 0.

    Returns:
        None.

    Raises:
        ValueError: If base <= 1 or first_boundary <= 0 (the boundary
            sequence would never grow past the maximum length).
    """
    if base <= 1:
        raise ValueError("base must be greater than 1")
    if first_boundary <= 0:
        raise ValueError("first_boundary must be positive")

    if not results_list:
        print("No results to plot.")
        return

    try:
        import matplotlib.pyplot as plt

        n_files = len(results_list)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12))

        # Qualitative palettes up to 20 files, a continuous map beyond that.
        if n_files > 20:
            colors = plt.cm.viridis(np.linspace(0, 1, n_files))
        elif n_files > 10:
            colors = plt.cm.tab20(np.linspace(0, 1, n_files))
        else:
            colors = plt.cm.tab10(np.linspace(0, 1, n_files))

        # Marker styles are reused cyclically when there are more files than styles.
        marker_styles = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h', '+', 'x', '|', '_']
        markers = [marker_styles[i % len(marker_styles)] for i in range(n_files)]

        # Largest token length across all files (0 when there is no data).
        max_length = max(
            (item[0] for results in results_list for item in (results['all_data'] or [])),
            default=0,
        )

        # Boundaries: 0, first_boundary, first_boundary*base, ... up to max_length.
        interval_boundaries = [0, first_boundary]
        while interval_boundaries[-1] * base <= max_length:
            interval_boundaries.append(interval_boundaries[-1] * base)

        # Half-open intervals between consecutive boundaries.
        intervals = list(zip(interval_boundaries, interval_boundaries[1:]))
        interval_labels = [f"[{start}-{end})" for start, end in intervals]

        # Closing interval. '<=' matters: a sample whose length equals the
        # last boundary is outside every half-open interval above.
        if interval_boundaries[-1] <= max_length:
            intervals.append((interval_boundaries[-1], max_length + 1))
            interval_labels.append(f"[{interval_boundaries[-1]}-{max_length}]")

        print(f"Using exponential intervals (base={base}): {intervals}")

        x_positions = range(len(interval_labels))

        for i, (results, file_name) in enumerate(zip(results_list, file_names)):
            all_data = results['all_data']

            if not all_data:
                print(f"No data available for {file_name}")
                continue

            interval_accuracies = []
            interval_counts = []

            print(f"\n{file_name} - Exponential Interval Statistics:")
            print("-" * 60)

            for (start, end), label in zip(intervals, interval_labels):
                samples_in_interval = [correctness for token_length, correctness in all_data
                                       if start <= token_length < end]

                if not samples_in_interval:
                    # Empty intervals are plotted as 0 to keep the x axis aligned.
                    interval_accuracies.append(0)
                    interval_counts.append(0)
                    print(f"Length {label:12s}: No samples")
                    continue

                count = len(samples_in_interval)
                correct_count = sum(samples_in_interval)
                error_count = count - correct_count
                accuracy = correct_count / count

                interval_accuracies.append(accuracy)
                interval_counts.append(count)

                print(f"Length {label:12s}: Acc={accuracy:.3f}, "
                      f"Total={count:4d}, Correct={correct_count:4d}, Error={error_count:4d}")

            line_style = dict(color=colors[i], linewidth=2, marker=markers[i],
                              markersize=6, label=f"{file_name}", alpha=0.8)

            ax1.plot(x_positions, interval_accuracies, **line_style)
            ax2.plot(x_positions, interval_counts, **line_style)

        # Subplot 1: accuracy
        ax1.set_xlabel('Token Length Intervals (Exponential)')
        ax1.set_ylabel('Accuracy')
        ax1.set_title(f'Accuracy by Exponential Token Length Intervals (base={base})')
        ax1.set_xticks(range(len(interval_labels)))
        ax1.set_xticklabels(interval_labels, rotation=45)
        ax1.grid(True, alpha=0.3)
        ax1.set_ylim(0, 1)

        # Subplot 2: sample count
        ax2.set_xlabel('Token Length Intervals (Exponential)')
        ax2.set_ylabel('Sample Count')
        ax2.set_title(f'Sample Count by Exponential Token Length Intervals (base={base})')
        ax2.set_xticks(range(len(interval_labels)))
        ax2.set_xticklabels(interval_labels, rotation=45)
        ax2.grid(True, alpha=0.3)

        # With many files the legend is moved outside the axes.
        for ax in (ax1, ax2):
            if n_files > 10:
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                ax.legend()

        # savefig does not create missing directories.
        os.makedirs('plots', exist_ok=True)

        plt.tight_layout()
        plt.savefig(f'plots/{save_name}_accuracy_exponential_intervals.png', dpi=300, bbox_inches='tight')
        plt.show()

    except ImportError:
        print("matplotlib not available. Skipping exponential interval plot.")
    except Exception as e:
        print(f"Error creating exponential interval plot: {e}")


def _collect_question_correctness(results_list, file_names):
    """Build {question_index: {file_name: correctness}} from every result set."""
    all_questions = {}

    for results, file_name in zip(results_list, file_names):
        if 'all_data' not in results or not results['all_data']:
            print(f"No data available for {file_name}")
            continue

        if 'detailed_results' in results:
            # Prefer explicit per-item indices when the result set has them.
            for i, item in enumerate(results['detailed_results']):
                index = item.get('index', i)
                all_questions.setdefault(index, {})[file_name] = item.get('correct', False)
        else:
            # NOTE(review): the position inside 'all_data' stands in for the
            # question id; this presumably requires every file to list the
            # same questions in the same order - confirm against the data.
            for i, (_, correctness) in enumerate(results['all_data']):
                all_questions.setdefault(i, {})[file_name] = correctness

    return all_questions


def _question_grid_layout(total_questions):
    """Return (questions_per_row, n_rows, fig_width, fig_height, symbol_size,
    x_fontsize, y_fontsize, title_fontsize) for the given question count."""
    if total_questions <= 100:
        # Up to 100 questions: rows of 50, one or two rows.
        return 50, (2 if total_questions > 50 else 1), 35, 8, 10, 12, 14, 16
    if total_questions <= 500:
        # 101-500 questions: four rows.
        return math.ceil(total_questions / 4), 4, 40, 6, 8, 10, 12, 14
    if total_questions <= 1000:
        # 501-1000 questions: five rows.
        return math.ceil(total_questions / 5), 5, 45, 5, 6, 9, 11, 13
    # More than 1000 questions: ten rows; symbol_size 0 disables the per-cell
    # marks because the cells are too dense to read.
    return math.ceil(total_questions / 10), 10, 50, 4, 0, 8, 10, 12


def plot_model_comparison_all_questions(results_list, file_names, save_name="comparison"):
    """
    Draw a models-by-questions correctness heat map (large-font version).

    Each row of the heat map is one entry of file_names and each column one
    question; green cells are correct, red cells incorrect, and cells with
    no data stay empty. Questions are split over several stacked subplots
    depending on how many there are. The figure is written to
    ``plots/{save_name}_all_questions_comparison_large_font.png`` and shown.
    Plotting is best effort: failures are printed with a traceback, not raised.

    Args:
        results_list: List of result dicts; each should provide 'all_data'
            and may provide 'detailed_results' (items with 'index'/'correct').
        file_names: Display names, parallel to results_list.
        save_name: Prefix of the saved image file.

    Returns:
        None.
    """
    try:
        import matplotlib.pyplot as plt

        print(f"Processing all questions in dataset...")

        all_questions = _collect_question_correctness(results_list, file_names)

        if not all_questions:
            print("No data to plot")
            return

        question_indices = sorted(all_questions.keys())
        total_questions = len(question_indices)
        n_models = len(file_names)

        print(f"Total questions: {total_questions}")
        print(f"Models: {n_models}")

        (questions_per_row, n_rows, fig_width, fig_height, symbol_size,
         x_fontsize, y_fontsize, title_fontsize) = _question_grid_layout(total_questions)

        # squeeze=False keeps axes indexable even for a single row.
        fig, axes = plt.subplots(n_rows, 1, figsize=(fig_width, fig_height * n_rows), squeeze=False)
        axes = axes.flatten()

        # Long model names are truncated so the y axis stays readable.
        y_labels = [name[:27] + '...' if len(name) > 30 else name for name in file_names]

        for row_idx, ax in enumerate(axes):
            start_idx = row_idx * questions_per_row
            end_idx = min((row_idx + 1) * questions_per_row, total_questions)
            row_questions = question_indices[start_idx:end_idx]

            if not row_questions:
                # Nothing left for this row; do not leave an empty frame.
                ax.set_visible(False)
                continue

            # Matrix of 1 (correct) / 0 (incorrect) / NaN (no data).
            n_questions_row = len(row_questions)
            correctness_matrix = np.full((n_models, n_questions_row), np.nan)

            for q_idx, question_index in enumerate(row_questions):
                question_data = all_questions[question_index]
                for m_idx, model_name in enumerate(file_names):
                    if model_name in question_data:
                        correctness_matrix[m_idx, q_idx] = 1 if question_data[model_name] else 0

            ax.imshow(correctness_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)

            # Thin out x tick labels so they do not overlap.
            if n_questions_row <= 50:
                step = max(1, n_questions_row // 8)
            elif n_questions_row <= 200:
                step = max(1, n_questions_row // 12)
            else:
                step = max(1, n_questions_row // 15)

            tick_positions = range(0, n_questions_row, step)
            ax.set_xticks(tick_positions)
            ax.set_xticklabels([f'Q{row_questions[i]}' for i in tick_positions],
                               rotation=45, ha='right', fontsize=x_fontsize, weight='bold')
            ax.set_yticks(range(n_models))
            ax.set_yticklabels(y_labels, fontsize=y_fontsize, weight='bold')

            # Minor ticks on cell borders draw white separators between cells.
            ax.set_xticks(np.arange(-0.5, n_questions_row, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, n_models, 1), minor=True)
            ax.grid(which="minor", color="white", linestyle='-', linewidth=1)

            # Per-cell check/cross marks, skipped for very dense layouts.
            if symbol_size > 0:
                for m_idx in range(n_models):
                    for q_idx in range(n_questions_row):
                        value = correctness_matrix[m_idx, q_idx]
                        if np.isnan(value):
                            continue
                        symbol = '✓' if value == 1 else '✗'
                        color = 'white' if value == 0 else 'darkgreen'
                        ax.text(q_idx, m_idx, symbol, ha='center', va='center',
                                color=color, fontsize=symbol_size, weight='bold')

            if n_rows > 1:
                ax.set_title(f'Questions {row_questions[0]}-{row_questions[-1]} ({len(row_questions)} questions)',
                             fontsize=title_fontsize, pad=15, weight='bold')
            else:
                ax.set_title(f'Model Comparison - All {total_questions} Questions',
                             fontsize=title_fontsize + 2, pad=15, weight='bold')

            if row_idx == n_rows - 1:  # x axis label only on the bottom row
                ax.set_xlabel('Question Index', fontsize=x_fontsize + 2, weight='bold')
            ax.set_ylabel('Models', fontsize=y_fontsize + 2, weight='bold')

        # Overall summary above all rows.
        fig.suptitle(f'Model Performance Comparison\n'
                     f'Dataset: {total_questions} questions, {n_models} models\n'
                     f'Green=Correct, Red=Incorrect',
                     fontsize=title_fontsize + 4, y=0.98, weight='bold')

        plt.tight_layout()
        plt.subplots_adjust(top=0.90)  # leave room for the suptitle

        # savefig does not create missing directories.
        os.makedirs('plots', exist_ok=True)

        plt.savefig(f'plots/{save_name}_all_questions_comparison_large_font.png',
                    dpi=400, bbox_inches='tight', facecolor='white')
        plt.show()

        print(f"\nVisualization complete! Saved as: plots/{save_name}_all_questions_comparison_large_font.png")
        print(f"Layout: {n_rows} rows, ~{questions_per_row} questions per row")
        print(f"Font sizes - X-axis: {x_fontsize}, Y-axis: {y_fontsize}, Title: {title_fontsize}")

    except Exception as e:
        print(f"Error creating model comparison plot: {e}")
        import traceback
        traceback.print_exc()


# 简化的使用函数
def quick_full_comparison_large_font(results_list, file_names, save_name="full_dataset_large"):
    """
    Shortcut for drawing the large-font, full-dataset model comparison.

    This is a thin wrapper: the three arguments are forwarded positionally,
    unchanged, to ``plot_model_comparison_all_questions``.

    Args:
        results_list: per-model results, passed through as the first argument
        file_names: display names, passed through as the second argument
        save_name: name prefix for the saved figure, passed through as the
            third argument
    """
    forwarded_args = (results_list, file_names, save_name)
    plot_model_comparison_all_questions(*forwarded_args)




def analyze_multiple_files(jsonl_files, models, file_names, save_name="comparison"):
    """
    Analyze several JSONL files and generate the comparison charts.

    Every file is passed through ``analyze_generation_lengths`` with
    ``analyze_by_source=True``; the collected results are then handed to the
    plotting helpers one after another.

    Args:
        jsonl_files: list of JSONL file paths
        models: list of model names, one per file (forwarded as the second
            argument of ``analyze_generation_lengths``)
        file_names: list of display names, one per file
        save_name: name prefix used when saving figures

    Returns:
        The list of per-file analysis results, in input order.

    Raises:
        ValueError: if the three input lists do not have the same length.
    """
    # The three lists are walked in lockstep below, so their lengths must agree.
    same_length = len(jsonl_files) == len(models) == len(file_names)
    if not same_length:
        raise ValueError("jsonl_files, models, and file_names must have the same length")

    # Create the plots directory if it does not exist yet.
    os.makedirs('plots', exist_ok=True)

    # Analyze every file, printing a banner with its display name first.
    all_results = []
    for path, tokenizer_model, label in zip(jsonl_files, models, file_names):
        print(f"\n{'='*60}")
        print(f"Analyzing: {label}")
        print(f"{'='*60}")

        all_results.append(
            analyze_generation_lengths(path, tokenizer_model, analyze_by_source=True)
        )

    # Length distributions (one subplot per file).
    plot_length_distribution_subplots(all_results, file_names, save_name=save_name)

    # Accuracy vs. length for all files on shared axes.
    plot_accuracy_vs_length_combined(all_results, file_names, save_name=save_name)

    # NOTE: save_name is not forwarded to this helper.
    plot_accuracy_by_exponential_intervals(all_results, file_names)
    plot_model_comparison_all_questions(all_results, file_names, save_name)

    return all_results


# 使用示例
if __name__ == "__main__":
    # One row per evaluated run:
    #   (results JSONL file, model path passed on for tokenization, display name
    #    used in legends and titles).
    runs = [
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/Qwen2.5-7b-base-16k_test.jsonl",
            "/fs-computility/ai-shen/shared/hf-hub/Qwen2.5-Math-7B-16k-think",
            "Qwen2.5 Base",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/Qwen-2.5-Instruct_test.jsonl",
            "/fs-computility/ai-shen/shared/hf-hub/models--Qwen--Qwen2.5-7B-Instruct",
            "Qwen2.5 Instruct",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/Qwen3-1.7B-own_test.jsonl",
            "/fs-computility/ai-shen/shared/hf-hub/Qwen3-1.7B",
            "Qwen3 1.7B",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/Qwen3-1.7B-Base-qwen-template_test.jsonl",
            "/fs-computility/ai-shen/shared/hf-hub/Qwen3-1.7B-Base",
            "Qwen3-1.7B-Base",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/grpo_n8_16k_test.jsonl",
            "/fs-computility/xuxingcheng/shared/lmreason/futingwang/models/luffy/luffy-math-test/ON_POLICY_TEST/best/actor",
            "GRPO n8",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/grpo_n16_16k_test.jsonl",
            "/fs-computility/xuxingcheng/shared/lmreason/futingwang/models/luffy/luffy-math-test/ON_POLICY_TEST/best/actor",
            "GRPO n16",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/e3-1.7B_test.jsonl",
            "/fs-computility/xuxingcheng/shared/lmreason/futingwang/models/e3-1.7B",
            "e3-1.7B",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/RISE_1k_math_self_distilled_test.jsonl",
            "/fs-computility/xuxingcheng/shared/lmreason/futingwang/models/RISE/math_self-distilled_1k",
            "RISE_1k_math_self_distilled",
        ),
        (
            "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/LUFFY-Qwen-Math-7B-Zero_test.jsonl",
            "/fs-computility/ai-shen/shared/hf-hub/LUFFY-Qwen-Math-7B-Zero",
            "LUFFY-Qwen-Math-7B-Zero",
        ),
    ]

    # Result files currently disabled (kept for reference; no model path or
    # display name is paired with them here):
    #   "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/Qwen3-1.7B_test.jsonl"
    #   "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/Qwen3-1.7B-Base_test.jsonl"
    #   "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/grpo-16k-n8-test_test.jsonl"
    #   "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/grpo-16k-n8-tp1.0_test.jsonl"
    #   "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/grpo-16k-n8-tp0.6-8k_test.jsonl"
    #   "/fs-computility/xuxingcheng/wangfuting/rl/LUFFY/old_version/results/grpo-16k-n8-tp1.0-8k_test.jsonl"

    # Split the table back into the three parallel lists the analysis expects.
    jsonl_files, models, file_names = (list(column) for column in zip(*runs))

    analyze_multiple_files(jsonl_files, models, file_names)
