"""
按难度对比多个模型在AIME24数据集上的准确率和响应长度

使用方法:
1. 修改 main() 函数中的 models_config
2. 运行脚本: python acc_by_difficulty.py
"""

import csv
import json
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from collections import defaultdict
from pathlib import Path


def detect_repetition_with_hash(text, window_size=10, max_repetitions_limit=6):
    """
    Detect degenerate repetition by counting repeated word n-grams.

    The text is tokenized on whitespace, and every whitespace token is further
    split on underscores. Each contiguous window of ``window_size`` tokens is
    counted, and the function returns early as soon as any single window has
    been seen ``max_repetitions_limit`` times.

    Args:
        text: Generated text to inspect.
        window_size: Number of tokens per n-gram window.
        max_repetitions_limit: Occurrence count (inclusive) at which a window
            is considered a repetition.

    Returns:
        -1 if some window occurs at least ``max_repetitions_limit`` times,
        otherwise 0.
    """
    words = [piece for segment in text.split() for piece in segment.split('_')]

    # With this few tokens there is at most one window, so nothing can recur.
    if len(words) <= window_size:
        return 0

    # Key the counter on the window tuple itself, not on hash(window):
    # distinct windows whose hashes collide would otherwise be merged.
    window_counts = {}
    for start in range(len(words) - window_size + 1):
        window = tuple(words[start:start + window_size])
        count = window_counts.get(window, 0) + 1
        window_counts[window] = count
        if count >= max_repetitions_limit:
            return -1

    return 0


# Global matplotlib styling: sans-serif fonts, ASCII hyphen for minus signs,
# and a larger default font size (增大默认字体大小) for the exported figures.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    'font.size': 16,
})

# Tokenizer used to measure response length in tokens.
# TOKENIZER_PATH = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base"
TOKENIZER_PATH = "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B"
# JSONL file holding the reference AIME24 problems used for prompt matching.
AIME24_REFERENCE_PATH = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/aime24_nofigures.jsonl"
# Token-counting callable; stays None until process_aime_data first needs lengths.
token_len_fn = None

# Difficulty buckets for AIME24, keyed by the 0-based position of each problem
# in aime24_nofigures.jsonl (positions 0-29).
DIFFICULTY_CLASSES = {
    "Easy": [9, 0, 7, 12],
    "Medium": [24, 11, 8, 15, 26, 6, 19, 18, 23, 22],
    "Hard": [10, 14, 17, 27, 16, 25, 4, 5, 1, 20, 28, 13],
    "Extremely Hard": [29, 2, 21, 3]
}

# Reverse index: problem position -> difficulty label.
QUESTION_TO_DIFFICULTY = {
    position: level
    for level, positions in DIFFICULTY_CLASSES.items()
    for position in positions
}

# Display order of difficulty buckets; summaries append an "Overall" column.
DIFFICULTY_ORDER = ["Easy", "Medium", "Hard", "Extremely Hard"]
SUMMARY_ORDER = DIFFICULTY_ORDER + ["Overall"]

# Series colours matched by model-name keywords (see resolve_series_style),
# plus a palette cycled through for any other model.
COLOR_BASELINE = "#00468B"
COLOR_GSPO_LENGTH = "#9B59B6"
COLOR_GSPO = "#2E8B57"
FALLBACK_SERIES_COLORS = [
    "#C97B84",
    "#6C91BF",
    "#C9A227",
    "#5C4B8A",
    "#2C7A7B",
]


def load_aime24_reference():
    """Load the AIME24 reference file and build the lookup tables used for matching.

    Returns:
        A 4-tuple ``(reference_problems, idx_to_id, problem_to_idx, problems_list)``:
        the raw records; a map from 0-based position to ``record['id'] - 60``;
        a map from exact problem text to position (a later duplicate text
        overrides an earlier one); and ``(problem_text, position)`` pairs in
        file order for fuzzy matching.
    """
    reference_problems = read_jsonl(AIME24_REFERENCE_PATH)

    idx_to_id = {}
    problem_to_idx = {}
    problems_list = []
    for position, record in enumerate(reference_problems):
        # NOTE(review): the "- 60" offset presumably rebases the ids stored in
        # the reference file to the numbering used in the plots; confirm.
        idx_to_id[position] = record['id'] - 60
        text = record['problem']
        problem_to_idx[text] = position
        problems_list.append((text, position))

    return reference_problems, idx_to_id, problem_to_idx, problems_list


def build_token_len_fn(tokenizer_name_or_path):
    """Build a callable that returns the number of tokens in a string.

    Prefers ``transformers.AutoTokenizer``. If that fails for any reason, it
    falls back to the standalone ``tokenizers`` library, reading
    ``<tokenizer_name_or_path>/tokenizer.json``. Both backends count tokens
    WITHOUT special tokens, so lengths do not depend on which backend loaded.

    Args:
        tokenizer_name_or_path: Local tokenizer directory (or a model id for
            the transformers backend).

    Returns:
        ``(token_len_fn, backend_name)``.

    Raises:
        RuntimeError: If neither backend can load the tokenizer. The message
            includes both underlying errors.
    """
    try:
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path, trust_remote_code=True)
    except Exception as e_transformers:  # deliberate best-effort; fall back below
        # The `as` name is unbound when the except block ends, so keep a copy.
        transformers_error = e_transformers
    else:
        def _token_len_fn(text):
            return len(hf_tokenizer.encode(text, add_special_tokens=False))

        return _token_len_fn, "transformers.AutoTokenizer"

    try:
        from tokenizers import Tokenizer

        tokenizer_json = Path(tokenizer_name_or_path) / "tokenizer.json"
        if not tokenizer_json.exists():
            raise FileNotFoundError(
                f"tokenizer.json not found under {tokenizer_name_or_path}")

        raw_tokenizer = Tokenizer.from_file(str(tokenizer_json))
    except Exception as e_tokenizers:
        raise RuntimeError(
            "Failed to load tokenizer with both transformers and tokenizers.\n"
            f"transformers error: {transformers_error}\n"
            f"tokenizers error: {e_tokenizers}"
        ) from e_tokenizers

    def _token_len_fn(text):
        # Tokenizer.encode adds special tokens by default; disable it to
        # match the transformers path above.
        return len(raw_tokenizer.encode(text, add_special_tokens=False).ids)

    return _token_len_fn, "tokenizers.Tokenizer"


def read_jsonl(file_path):
    """Parse a UTF-8 JSON Lines file into a list of records, ignoring blank lines."""
    records = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records


def extract_problem_from_prompt(prompt):
    """Extract the bare problem statement from a chat-formatted prompt.

    The steps are:

    1. Take the text after the first ``<|im_start|>user`` marker, up to any
       later occurrence of that marker.
    2. Cut it at the trailing step-by-step instruction.
    3. Cut it at one specific embedded Asymptote figure.
    4. Strip surrounding whitespace.

    If the user marker is absent (e.g. an empty prompt), the whole prompt is
    used instead of raising ``IndexError``.
    """
    user_marker = "<|im_start|>user\n"
    instruction = "Let's think step by step and output the final answer within \\boxed{}."
    # One problem embeds this [asy] figure after its statement. The reference
    # file (aime24_nofigures) presumably omits it, so drop it for exact matching.
    asy_figure = "\n[asy] import graph; unitsize(0.1cm); pair A = (0,0);pair B = (70,0);pair C = (70,16);pair D = (0,16);pair E = (3,16);pair F = (90,16);pair G = (90,33);pair H = (3,33); dot(A^^B^^C^^D^^E^^F^^G^^H); label(\"$A$\", A, S);label(\"$B$\", B, S);label(\"$C$\", C, N);label(\"$D$\", D, N);label(\"$E$\", E, S);label(\"$F$\", F, S);label(\"$G$\", G, N);label(\"$H$\", H, N); draw(E--D--A--B--C--E--H--G--F--C); [/asy]"

    segments = prompt.split(user_marker)
    body = segments[1] if len(segments) > 1 else prompt
    body = body.split(instruction)[0]
    body = body.split(asy_figure)[0]
    return body.strip()


def match_problem_to_reference(prompt, problems_list, problem_to_idx, threshold=0.7):
    """Map a result-file prompt to the position of its problem in the reference file.

    An exact match on the extracted problem text is tried first. Otherwise the
    reference problem with the highest ``difflib.SequenceMatcher`` ratio wins,
    provided that ratio is strictly greater than ``threshold``.

    Args:
        prompt: Chat-formatted prompt taken from a result record.
        problems_list: ``(problem_text, position)`` pairs from the reference.
        problem_to_idx: Exact problem text -> reference position.
        threshold: Exclusive minimum similarity accepted for a fuzzy match.

    Returns:
        The reference position, or ``None`` when no candidate is similar
        enough. Callers record ``None`` as an unmatched question.
    """
    from difflib import SequenceMatcher

    problem_text = extract_problem_from_prompt(prompt)

    if problem_text in problem_to_idx:
        return problem_to_idx[problem_text]

    best_match_idx = None
    best_ratio = 0.0
    for ref_problem, idx in problems_list:
        ratio = SequenceMatcher(None, problem_text, ref_problem).ratio()
        if ratio > best_ratio:
            best_ratio, best_match_idx = ratio, idx

    if best_ratio > threshold:
        return best_match_idx

    # No sufficiently similar problem. Return None so batch runs keep going
    # (and report it) instead of stopping in an interactive debugger.
    return None


def _rollout_text(item):
    """Return a rollout's generated text ('generated_text', else 'output', else '')."""
    return item.get('generated_text', '') or item.get('output', '')


def _compute_question_metric(rollouts, calculate_length, calculate_repetition):
    """Reduce one question's rollouts to mean token length, repetition rate, or accuracy."""
    if calculate_length:
        total = sum(token_len_fn(_rollout_text(item)) for item in rollouts)
        return total / len(rollouts)
    if calculate_repetition:
        # detect_repetition_with_hash returns -1 when a rollout is repetitive.
        flagged = sum(1 for item in rollouts
                      if detect_repetition_with_hash(_rollout_text(item)) == -1)
        return flagged / len(rollouts)
    return sum(item.get('correctness', False) for item in rollouts) / len(rollouts)


def process_aime_data(jsonl_file, calculate_length=False, calculate_repetition=False,
                      num_questions=30, rollouts_per_question=32):
    """Compute one per-question metric over the AIME records of a result file.

    Only records whose ``data_source`` is ``'aime'`` are used.

    Assumed layout (TODO confirm against the generator): ``num_rounds``
    consecutive rounds. Each round holds ``rollouts_per_question`` contiguous
    rollouts for question 0, then for question 1, and so on up to
    ``num_questions``. A warning is printed if rollouts pooled for one
    question carry different prompts.

    A question's rollouts from every round are pooled and reduced to:
      * mean response length in tokens when ``calculate_length`` is set;
      * otherwise the fraction flagged by ``detect_repetition_with_hash`` when
        ``calculate_repetition`` is set;
      * otherwise mean ``correctness`` (accuracy).

    Each question is identified by matching its first rollout's prompt
    against the AIME24 reference file. Unmatched questions are reported and
    skipped.

    Returns:
        ``{reference_position: metric}``. Empty if the file holds fewer AIME
        records than one full round.
    """
    global token_len_fn

    _, idx_to_id, problem_to_idx, problems_list = load_aime24_reference()

    all_data = read_jsonl(jsonl_file)
    aime_data = [item for item in all_data if item.get('data_source') == 'aime']

    if calculate_length and token_len_fn is None:
        print("正在加载 tokenizer...")
        token_len_fn, tokenizer_backend = build_token_len_fn(TOKENIZER_PATH)
        print(f"tokenizer backend: {tokenizer_backend}")

    round_size = num_questions * rollouts_per_question
    num_rounds = len(aime_data) // round_size
    if num_rounds == 0:
        # Not even one complete round. Bail out instead of dividing by zero.
        print(f"⚠️  AIME 数据不足一轮: {len(aime_data)} 条 < {round_size} 条，跳过该文件")
        return {}
    leftover = len(aime_data) % round_size
    if leftover:
        print(f"⚠️  AIME 数据条数 {len(aime_data)} 不是 {round_size} 的整数倍，末尾 {leftover} 条将被忽略")

    print(
        f"总数据条数: {len(aime_data)}, 共 {num_rounds} 轮, 每轮 {round_size} 条, 共 {num_questions} 个不同的问题")

    question_metrics = {}
    matched_count = 0
    unmatched_info = []

    for question_idx in range(num_questions):
        # Gather this question's slice from EVERY round (the old loop ignored
        # round_idx and re-read the first slice num_rounds times).
        offset = question_idx * rollouts_per_question
        all_rollouts = []
        for round_idx in range(num_rounds):
            question_start = round_idx * round_size + offset
            all_rollouts.extend(
                aime_data[question_start:question_start + rollouts_per_question])

        prompt = all_rollouts[0].get('prompt', '')
        # Sanity-check the assumed layout: pooled rollouts should share a prompt.
        mismatched = sum(1 for item in all_rollouts if item.get('prompt', '') != prompt)
        if mismatched:
            print(f"⚠️  问题 {question_idx} 有 {mismatched} 条 rollout 的 prompt 与首条不一致，请检查数据布局")

        question_id = match_problem_to_reference(prompt, problems_list, problem_to_idx)

        if question_id is None:
            problem_text = extract_problem_from_prompt(prompt)
            unmatched_info.append((question_idx, idx_to_id.get(
                question_idx, question_idx), problem_text[:150]))
            print(
                f"⚠️  问题 {question_idx} 无法匹配 | Prompt前150字符: {problem_text[:150]}")
            continue

        matched_count += 1
        print(
            f"✓ 问题 {question_idx} -> 匹配到参考文件索引 {question_id} (ID: {idx_to_id[question_id]})")
        if question_id in question_metrics:
            print(f"⚠️  参考文件索引 {question_id} 被多个问题匹配，之前的结果将被覆盖")

        question_metrics[question_id] = _compute_question_metric(
            all_rollouts, calculate_length, calculate_repetition)

    unmatched_count = len(unmatched_info)
    print(f"\n{'='*80}")
    print(f"匹配结果: {matched_count}/{matched_count + unmatched_count} 个问题成功匹配")
    print(f"{'='*80}")

    if unmatched_count > 0:
        print(f"\n⚠️  警告: 有 {unmatched_count} 个问题未能匹配")
        print("未匹配的问题:")
        for idx, qid, snippet in unmatched_info:
            print(f"  - 问题 {idx} (期望ID {qid}): {snippet}...")

    return question_metrics


def sort_questions_by_difficulty(question_metrics):
    """Order questions by difficulty bucket, then by metric value (descending).

    Args:
        question_metrics: ``{question_position: value}``.

    Returns:
        ``(question_position, value, difficulty)`` tuples grouped in
        ``DIFFICULTY_ORDER``; within a bucket, larger values come first (ties
        keep input order). Positions without a difficulty label are omitted.
    """
    grouped = defaultdict(list)
    for qid, value in question_metrics.items():
        grouped[QUESTION_TO_DIFFICULTY.get(qid, "Unknown")].append((qid, value))

    ordered = []
    for difficulty in DIFFICULTY_ORDER:
        bucket = grouped.get(difficulty)
        if not bucket:
            continue
        ranked = sorted(bucket, key=lambda pair: pair[1], reverse=True)
        ordered += [(qid, value, difficulty) for qid, value in ranked]
    return ordered


def resolve_series_style(model_name, series_idx):
    """Return matplotlib line kwargs for one plotted series.

    The colour is chosen by case-insensitive substring rules on
    ``model_name``, where the first matching rule wins:

    1. baseline / qwen3-4b-base
    2. lie / skip-right / add1k
    3. gspo

    Any other name cycles through ``FALLBACK_SERIES_COLORS`` by
    ``series_idx``. Line style, width and alpha are fixed.
    """
    label = model_name.lower()
    color_rules = (
        (("qwen3-4b-base", "baseline"), COLOR_BASELINE),
        (("lie", "skip-right", "add1k"), COLOR_GSPO_LENGTH),
        (("gspo",), COLOR_GSPO),
    )
    color = next(
        (rule_color for keywords, rule_color in color_rules
         if any(keyword in label for keyword in keywords)),
        None,
    )
    if color is None:
        color = FALLBACK_SERIES_COLORS[series_idx % len(FALLBACK_SERIES_COLORS)]

    return dict(color=color, linestyle="-", linewidth=3.2, alpha=0.98)


def _difficulty_spans(difficulties):
    """Map each present difficulty to its (first, last) x position, in DIFFICULTY_ORDER."""
    spans = {}
    start = 0
    for difficulty in DIFFICULTY_ORDER:
        count = difficulties.count(difficulty)
        if count:
            spans[difficulty] = (start, start + count - 1)
            start += count
    return spans


def _aligned_series(sorted_questions, question_indices):
    """Return one model's values in x-axis order, NaN (a gap) where it has no value."""
    lookup = {entry[0]: entry[1] for entry in sorted_questions}
    return [lookup.get(qidx, np.nan) for qidx in question_indices]


def _shade_difficulty_spans(ax, spans, bg_colors, label_y=None):
    """Shade every difficulty span on ``ax``; optionally write its name at height ``label_y``."""
    for difficulty, (start_idx, end_idx) in spans.items():
        ax.axvspan(start_idx - 0.5, end_idx + 0.5,
                   alpha=0.3, color=bg_colors[difficulty])
        if label_y is not None:
            ax.text((start_idx + end_idx) / 2, label_y, difficulty, ha='center', va='center', fontsize=20, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='white', edgecolor='gray', alpha=0.9, linewidth=2))


def _plot_model_series(ax, metric_data, question_indices, x_positions):
    """Draw one line per model in ``metric_data`` on ``ax``."""
    for series_idx, (model_name, sorted_questions) in enumerate(metric_data.items()):
        ax.plot(x_positions, _aligned_series(sorted_questions, question_indices),
                label=model_name, zorder=3, **resolve_series_style(model_name, series_idx))


def _style_rate_axis(ax):
    """Apply the shared [0, 1] rate-axis formatting: limits, ticks, grid, background."""
    ax.set_ylim(-0.05, 1.18)
    ax.set_yticks([0.0, 0.25, 0.50, 0.75, 1.00])
    ax.tick_params(axis='y', labelsize=18)
    ax.grid(True, alpha=0.3, linestyle='--', zorder=1, linewidth=1.2)
    ax.set_facecolor('#FAFAFA')


def _format_k_tick(x, pos):
    """Tick formatter: values >= 1000 as e.g. '1.5k', smaller values as integers."""
    if x >= 1000:
        return f'{x/1000:.1f}k'
    return f'{x:.0f}'


def plot_combined_comparison(accuracy_data, length_data, repetition_data=None, output_path=None):
    """Draw stacked per-question accuracy, response-length and (optional) repetition panels.

    The first model in ``accuracy_data`` fixes the x-axis order. Difficulty
    buckets are shaded, named on the top panel, and separated by dashed
    lines. Questions a model lacks are drawn as gaps rather than as 0.

    Args:
        accuracy_data / length_data / repetition_data: ``{model_name:
            [(question_position, value, difficulty), ...]}`` as produced by
            ``sort_questions_by_difficulty``. The repetition panel is added
            only when ``repetition_data`` is non-empty.
        output_path: Save the figure here (300 dpi); if falsy, show it.
    """
    if not accuracy_data:
        print("没有数据可绘制")
        return

    # Only the position -> displayed-id map is needed from the reference file.
    _, idx_to_id, _, _ = load_aime24_reference()

    first_model_questions = next(iter(accuracy_data.values()))
    question_indices = [q[0] for q in first_model_questions]  # positions 0-29
    difficulties = [q[2] for q in first_model_questions]
    question_ids = [idx_to_id[idx] for idx in question_indices]

    spans = _difficulty_spans(difficulties)
    bg_colors = {"Easy": "#C8E6C9", "Medium": "#FFF9C4",
                 "Hard": "#FFCCBC", "Extremely Hard": "#E1BEE7"}
    x_positions = np.arange(len(question_ids))

    if repetition_data:
        fig, (ax1, ax2, ax3) = plt.subplots(
            3, 1, figsize=(24, 18), sharex=True)
        axes = [ax1, ax2, ax3]
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(24, 14), sharex=True)
        axes = [ax1, ax2]
        ax3 = None

    # Accuracy panel (carries the difficulty name labels).
    _shade_difficulty_spans(ax1, spans, bg_colors, label_y=1.10)
    _plot_model_series(ax1, accuracy_data, question_indices, x_positions)
    ax1.set_ylabel('Accuracy', fontsize=22, fontweight='bold', labelpad=12)
    _style_rate_axis(ax1)

    # Response-length panel.
    _shade_difficulty_spans(ax2, spans, bg_colors)
    _plot_model_series(ax2, length_data, question_indices, x_positions)
    ax2.set_ylabel('Response Length', fontsize=22,
                   fontweight='bold', labelpad=12)
    length_values = [q[1] for sorted_q in length_data.values() for q in sorted_q]
    if length_values:  # max() of an empty sequence would raise
        ax2.set_ylim(0, max(length_values) * 1.15)
    ax2.yaxis.set_major_formatter(ticker.FuncFormatter(_format_k_tick))
    ax2.tick_params(axis='y', labelsize=18)
    ax2.grid(True, alpha=0.3, linestyle='--', zorder=1, linewidth=1.2)
    ax2.set_facecolor('#FAFAFA')

    bottom_ax = ax2
    if ax3 is not None:
        # Repetition-rate panel.
        _shade_difficulty_spans(ax3, spans, bg_colors)
        _plot_model_series(ax3, repetition_data, question_indices, x_positions)
        ax3.set_ylabel('Repetition Rate', fontsize=22,
                       fontweight='bold', labelpad=12)
        _style_rate_axis(ax3)
        bottom_ax = ax3

    # x is shared, so question ids (as stored in the reference file) go on
    # the bottom panel only.
    bottom_ax.set_xlabel('Sorted AIME24 Question IDs',
                         fontsize=22, fontweight='bold', labelpad=12)
    bottom_ax.set_xticks(x_positions)
    bottom_ax.set_xticklabels(question_ids, rotation=45, ha='right', fontsize=16)

    # Dashed separators at the start of every bucket except the first.
    for start_idx, _ in spans.values():
        if start_idx > 0:
            for ax in axes:
                ax.axvline(x=start_idx - 0.5, color='#424242',
                           linestyle='--', linewidth=2.5, alpha=0.7, zorder=2)

    add_shared_legend(fig, ax1, len(accuracy_data))
    fig.tight_layout(rect=(0, 0.08, 1, 1))
    fig.patch.set_facecolor('white')

    if output_path:
        fig.savefig(output_path, dpi=300,
                    bbox_inches='tight', facecolor='white')
        print(f"\n图表已保存到: {output_path}")
    else:
        plt.show()

    # Close this specific figure, not whatever pyplot considers "current".
    plt.close(fig)


def plot_repetition_comparison(repetition_data, output_path=None):
    """Plot per-question repetition rate for several models on a single axis.

    The first model in ``repetition_data`` fixes the x-axis order. Difficulty
    buckets are shaded, labelled, and separated by dashed lines. Questions a
    model lacks are drawn as gaps (NaN) rather than as 0.

    Args:
        repetition_data: ``{model_name: [(question_position, value, difficulty), ...]}``.
        output_path: Save the figure here (300 dpi); if falsy, show it.
    """
    if not repetition_data:
        print("没有重复率数据可绘制")
        return

    # Only the position -> displayed-id map is needed from the reference file.
    _, idx_to_id, _, _ = load_aime24_reference()

    first_model_questions = next(iter(repetition_data.values()))
    question_indices = [q[0] for q in first_model_questions]  # positions 0-29
    difficulties = [q[2] for q in first_model_questions]
    question_ids = [idx_to_id[idx] for idx in question_indices]

    # (first, last) x position of each difficulty bucket, in DIFFICULTY_ORDER.
    difficulty_boundaries = {}
    current_pos = 0
    for difficulty in DIFFICULTY_ORDER:
        count = difficulties.count(difficulty)
        if count > 0:
            difficulty_boundaries[difficulty] = (
                current_pos, current_pos + count - 1)
            current_pos += count

    bg_colors = {"Easy": "#C8E6C9", "Medium": "#FFF9C4",
                 "Hard": "#FFCCBC", "Extremely Hard": "#E1BEE7"}
    x_positions = np.arange(len(question_ids))

    fig, ax = plt.subplots(1, 1, figsize=(24, 8))

    # Difficulty background bands and their labels.
    for difficulty, (start_idx, end_idx) in difficulty_boundaries.items():
        ax.axvspan(start_idx - 0.5, end_idx + 0.5,
                   alpha=0.3, color=bg_colors[difficulty])
        mid_pos = (start_idx + end_idx) / 2
        ax.text(mid_pos, 1.10, difficulty, ha='center', va='center', fontsize=20, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='white', edgecolor='gray', alpha=0.9, linewidth=2))

    for idx, (model_name, sorted_questions) in enumerate(repetition_data.items()):
        lookup = {q[0]: q[1] for q in sorted_questions}
        # NaN leaves a gap for questions this model has no value for.
        values = [lookup.get(qidx, np.nan) for qidx in question_indices]
        style = resolve_series_style(model_name, idx)
        ax.plot(x_positions, values,
                label=model_name, zorder=3, **style)

    ax.set_xlabel('Sorted AIME24 Question IDs', fontsize=22,
                  fontweight='bold', labelpad=12)
    ax.set_ylabel('Repetition Rate', fontsize=22,
                  fontweight='bold', labelpad=12)
    ax.set_title('Model Comparison: Repetition Rate by Difficulty',
                 fontsize=26, fontweight='bold', pad=20)
    ax.set_ylim(-0.05, 1.18)
    ax.set_yticks([0.0, 0.25, 0.50, 0.75, 1.00])
    ax.tick_params(axis='y', labelsize=18)
    ax.grid(True, alpha=0.3, linestyle='--', zorder=1, linewidth=1.2)
    ax.set_facecolor('#FAFAFA')

    # x tick labels are the ids stored in the reference file.
    ax.set_xticks(x_positions)
    ax.set_xticklabels(question_ids, rotation=45, ha='right', fontsize=16)

    # Dashed separators at the start of every bucket except the first.
    for start_idx, _ in difficulty_boundaries.values():
        if start_idx > 0:
            ax.axvline(x=start_idx - 0.5, color='#424242',
                       linestyle='--', linewidth=2.5, alpha=0.7, zorder=2)

    add_shared_legend(fig, ax, len(repetition_data))
    fig.tight_layout(rect=(0, 0.08, 1, 1))
    fig.patch.set_facecolor('white')

    if output_path:
        fig.savefig(output_path, dpi=300,
                    bbox_inches='tight', facecolor='white')
        print(f"\n图表已保存到: {output_path}")
    else:
        plt.show()

    # Close this specific figure, not whatever pyplot considers "current".
    plt.close(fig)


def add_shared_legend(fig, ax, num_series):
    """Attach one borderless figure legend, centred below the axes, built from ``ax``.

    Nothing is added when ``ax`` has no labelled artists. ``num_series`` sets
    the number of legend columns (never fewer than one). Labels are drawn
    bold, size 13, in a dark colour.
    """
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        legend = fig.legend(
            handles, labels,
            loc="lower center", bbox_to_anchor=(0.5, -0.02),
            ncol=max(1, num_series), frameon=False,
            prop={"size": 13, "weight": "bold"},
        )
        for legend_text in legend.get_texts():
            legend_text.set_color("#141821")


def summarize_metric_by_difficulty(sorted_questions):
    """Average one model's per-question metric within each difficulty bucket.

    Args:
        sorted_questions: ``(question_position, value, difficulty)`` tuples.

    Returns:
        A dict keyed by every label in ``SUMMARY_ORDER``:
        * each difficulty maps to the float mean over its questions that are
          present, or ``None`` if it has none;
        * "Overall" averages all given questions;
        * empty input maps every key to ``None``.
    """
    if not sorted_questions:
        return dict.fromkeys(SUMMARY_ORDER)

    value_by_qid = {qid: value for qid, value, _ in sorted_questions}

    def _mean_or_none(values):
        return float(np.mean(values)) if values else None

    summary = {
        difficulty: _mean_or_none(
            [value_by_qid[qid] for qid in DIFFICULTY_CLASSES[difficulty] if qid in value_by_qid])
        for difficulty in DIFFICULTY_ORDER
    }
    summary["Overall"] = _mean_or_none(list(value_by_qid.values()))
    return summary


def format_length_brief(value):
    """Format a length compactly: '-' for None, 'x.xk' when |value| >= 1000, else a whole number."""
    if value is None:
        return "-"
    return f"{value / 1000:.1f}k" if abs(value) >= 1000 else f"{value:.0f}"


def format_summary_cell(acc_value, length_value):
    """Render a summary-table cell as ``"<acc> / <len>"`` (acc to 4 decimals, '-' when missing)."""
    acc_part = f"{acc_value:.4f}" if acc_value is not None else "-"
    return " / ".join((acc_part, format_length_brief(length_value)))


def build_ascii_table(headers, rows):
    """Render headers and rows as a boxed, left-aligned plain-text table.

    Every cell is converted with ``str``. Each column is as wide as its widest
    header or cell. Rows longer than ``headers`` raise ``IndexError``.
    """
    text_rows = [list(map(str, row)) for row in rows]
    col_widths = [len(str(h)) for h in headers]
    for row in text_rows:
        for col, cell in enumerate(row):
            if len(cell) > col_widths[col]:
                col_widths[col] = len(cell)

    def _render(cells):
        padded = (cell.ljust(col_widths[col]) for col, cell in enumerate(cells))
        return "| " + " | ".join(padded) + " |"

    rule = "+-" + "-+-".join("-" * w for w in col_widths) + "-+"
    lines = [rule, _render([str(h) for h in headers]), rule]
    lines.extend(_render(row) for row in text_rows)
    lines.append(rule)
    return "\n".join(lines)


def build_markdown_table(headers, rows):
    """Render string headers and rows (cells converted with ``str``) as a Markdown table."""
    def _line(cells):
        return "| " + " | ".join(cells) + " |"

    output = [_line(headers), _line(["---"] * len(headers))]
    output += [_line(str(cell) for cell in row) for row in rows]
    return "\n".join(output)


def export_summary_tables(models_accuracy_data, models_length_data, markdown_path=None, csv_path=None):
    """Print, and optionally save, a model x difficulty table of accuracy and length.

    Models appear in first-seen order across the accuracy data and then the
    length data. Each displayed cell reads ``acc / len``. When given, the
    table is also written as Markdown to ``markdown_path``, and as CSV to
    ``csv_path`` with separate ``<difficulty>_acc`` / ``<difficulty>_len``
    columns (empty string when missing).
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    model_names = list(dict.fromkeys([*models_accuracy_data, *models_length_data]))

    headers = ["Method", *(f"{difficulty} (Acc / Len)" for difficulty in SUMMARY_ORDER)]
    table_rows = []
    csv_rows = []

    for model_name in model_names:
        acc_by_level = summarize_metric_by_difficulty(
            models_accuracy_data.get(model_name, []))
        len_by_level = summarize_metric_by_difficulty(
            models_length_data.get(model_name, []))

        display_row = [model_name]
        record = {"Method": model_name}
        for difficulty in SUMMARY_ORDER:
            acc = acc_by_level.get(difficulty)
            length = len_by_level.get(difficulty)
            display_row.append(format_summary_cell(acc, length))
            record[f"{difficulty}_acc"] = f"{acc:.4f}" if acc is not None else ""
            record[f"{difficulty}_len"] = f"{length:.1f}" if length is not None else ""

        table_rows.append(display_row)
        csv_rows.append(record)

    banner = "=" * 80
    print("\n" + banner)
    print("最终汇总表（Acc / Len）")
    print(banner)
    print(build_ascii_table(headers, table_rows))
    print("注：单元格格式为 acc / length，长度使用 tokens，超过 1000 时以 k 表示。")
    print(banner)

    if markdown_path:
        with open(markdown_path, "w", encoding="utf-8") as md_file:
            md_file.write("# Final Summary Table\n\n")
            md_file.write("Cell format: `accuracy / response length`.\n\n")
            md_file.write(build_markdown_table(headers, table_rows))
            md_file.write("\n")
        print(f"Markdown 表格已保存到: {markdown_path}")

    if csv_path:
        fieldnames = ["Method"]
        for difficulty in SUMMARY_ORDER:
            fieldnames += [f"{difficulty}_acc", f"{difficulty}_len"]

        with open(csv_path, "w", newline="", encoding="utf-8") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(csv_rows)
        print(f"CSV 表格已保存到: {csv_path}")


def _print_metric_section(title, models_data, value_spec, suffix=""):
    """Print each model's per-difficulty and overall mean of one metric.

    Args:
        title: Section heading.
        models_data: ``{model_name: [(question_position, value, difficulty), ...]}``.
        value_spec: Format spec for the means (e.g. ``".4f"``).
        suffix: Text appended after each mean (e.g. ``" tokens"``).
    """
    banner = "=" * 80
    print("\n" + banner)
    print(title)
    print(banner)
    for difficulty in DIFFICULTY_ORDER:
        print(f"\n{difficulty}:")
        qids = DIFFICULTY_CLASSES[difficulty]
        for model_name, sorted_questions in models_data.items():
            question_metrics = {q[0]: q[1] for q in sorted_questions}
            values = [question_metrics[qid] for qid in qids if qid in question_metrics]
            if values:  # skip buckets this model has no questions for
                print(f"  {model_name:<25}: {np.mean(values):{value_spec}}{suffix}")

    print("\n总体:")
    for model_name, sorted_questions in models_data.items():
        overall = np.mean([q[1] for q in sorted_questions])
        print(f"  {model_name:<25}: {overall:{value_spec}}{suffix}")
    print(banner)


def print_statistics(models_accuracy_data, models_length_data, models_repetition_data=None):
    """Print accuracy, length and (optionally) repetition-rate comparisons by difficulty.

    Sections whose data is empty or ``None`` are skipped.
    """
    if models_accuracy_data:
        _print_metric_section("模型准确率对比", models_accuracy_data, ".4f")
    if models_length_data:
        _print_metric_section("模型平均长度对比", models_length_data, ".1f", " tokens")
    if models_repetition_data:
        _print_metric_section("模型重复率对比", models_repetition_data, ".4f")


def main():
    """Evaluate every configured model on AIME24 and write the comparison outputs.

    For each ``(result_jsonl, display_name)`` in ``models_config`` that exists
    on disk, this computes per-question accuracy, mean response length and
    repetition rate. It then:
      * prints per-difficulty statistics;
      * exports the Markdown and CSV summary tables;
      * saves the combined figure under ``output_dir``.
    """
    # Result files to compare: (path to *_test.jsonl, label used in tables/legends).
    models_config = [
        #    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/baseline-dapo-math-260steps-valid_32768_test.jsonl", "Baseline"),
        #    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/dapo-math-stage1-380steps-valid_32768_test.jsonl", "Stage1"),
        #    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/dapo-math-stage1-380steps-stage2-wo-resume-grpo-60steps-valid_32768_test.jsonl", "Stage2-60"),
        #  ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_nov/baseline-8k-minibsz32-step610-valid_32768_test.jsonl", "Baseline"),
        #  ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/baseline-8k-stage2-140step-valid_32768_test.jsonl", "Baseline-stage2"),
        #    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/stage1-additive-length-penalty-390step-stage2-grpo-80-valid_32768_test.jsonl", "additive-stage2")
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_nov/stage1-minibsz32-270steps-valid_32768_test.jsonl","original-stage1"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/stage1-additive-length-penalty-390step-valid_32768_test.jsonl", "additive-stage1"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_nov/stage1-minibsz32-270steps-stage2-minibsz32-grpo-110steps-valid_32768_test.jsonl", "original-stage2"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/stage1-additive-length-penalty-390step-stage2-grpo-80-valid_32768_test.jsonl", "additive-stage2")
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/baseline-8k-minibsz32-610step-aime128_16384_test.jsonl","baseline-8k-minibsz32-610step-aime128"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/stage1-additive-length-penalty-390step-aime128_16384_test.jsonl","additive-stage1-aime128")
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-step500-valid-all_32768_test.jsonl", "GSPO"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-skip-right-step600-valid-all_32768_test.jsonl", "GSPO + LIE"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-step500-valid-all_32768_test.jsonl", "GSPO"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-skip-right-step600-valid-all_32768_test.jsonl", "GSPO + LIE"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/gspo_length-valid-690step_32768_test.jsonl", "GSPO + length")
        ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/LIE-redo2_32768_test.jsonl", "redo2"),
        ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/LIE-redo1-step530_32768_test.jsonl", "redo1")
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/qwen3-4b-polaris-baseline-gspo-step550-valid-all_32768_test.jsonl", "GSPO"),
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/qwen3-4b-polaris-add1k-gspo-step660-valid-all_32768_test.jsonl", "GSPO + LIE")
        # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/stage1-skip-right-dapo-math-step400-valid-all_32768_test.jsonl", "Skip-right-400steps")

    ]

    models_accuracy_data = {}
    models_length_data = {}
    models_repetition_data = {}

    # One entry per metric:
    # (log heading, calculate_length, calculate_repetition, destination, summary template).
    metric_plan = (
        ("\n--- 计算准确率 ---", False, False, models_accuracy_data,
         "{name} - 总体准确率: {mean:.4f}"),
        ("\n--- 计算平均长度 ---", True, False, models_length_data,
         "{name} - 平均长度: {mean:.1f} tokens"),
        ("\n--- 计算重复率 ---", False, True, models_repetition_data,
         "{name} - 总体重复率: {mean:.4f}"),
    )

    for jsonl_file, model_name in models_config:
        print(f"\n{'='*80}\n处理模型: {model_name}\n{'='*80}")

        if not Path(jsonl_file).exists():
            print("错误: 文件不存在，跳过")
            continue

        # Compute every metric first, then record the non-empty ones.
        per_metric = []
        for heading, use_length, use_repetition, _, _ in metric_plan:
            print(heading)
            per_metric.append(process_aime_data(
                jsonl_file, calculate_length=use_length, calculate_repetition=use_repetition))

        for question_values, (_, _, _, destination, template) in zip(per_metric, metric_plan):
            if not question_values:
                continue
            destination[model_name] = sort_questions_by_difficulty(question_values)
            print(template.format(
                name=model_name, mean=np.mean(list(question_values.values()))))

    if not models_accuracy_data:
        print("\n错误: 没有成功处理任何模型数据")
        return

    print_statistics(models_accuracy_data, models_length_data,
                     models_repetition_data)

    output_dir = Path(
        "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/plots"
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\n生成最终汇总表（不同方法 x 不同难度的 acc 和长度）...")
    export_summary_tables(
        models_accuracy_data,
        models_length_data,
        output_dir / "aime-gspo_summary_table.md",
        output_dir / "aime-gspo_summary_table.csv",
    )

    combined_png = output_dir / "aime-gspo_combined_comparison.png"
    print("\n生成组合对比图（包含准确率、长度和重复率）...")
    plot_combined_comparison(
        models_accuracy_data, models_length_data,
        models_repetition_data, str(combined_png))

    print("\n处理完成！")


if __name__ == "__main__":
    main()
