import json
import sys
import argparse
import os
import glob
import re
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import as_completed
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy
from collections import Counter
try:
    from transformers import AutoTokenizer
except ImportError:
    AutoTokenizer = None
try:
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
except ImportError:
    SmoothingFunction = None
    sentence_bleu = None


_WORKER_TOKENIZER = None
_WORKER_TOKENIZER_PATH = None

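# Optional whitelist of data_source values to analyze (matched case-insensitively).
# Currently disabled: the constant below is commented out.
# The original request spelled one entry "olympaid"; the common spelling
# "olympiad" is used here so the whitelist does not end up matching nothing.
# ALLOWED_DATA_SOURCES = {"aime","olympiad_bench", "amc", "math"}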


def detect_repetition_with_hash(text, window_size=10, return_text=False, max_repetitions_limit=6):
    """
    Measure how much of ``text`` is covered by heavily repeated word windows.

    The text is split on whitespace and then on underscores. Every sliding
    window of ``window_size`` consecutive tokens is counted; a window position
    is "repeated" when the content of its window occurs at least
    ``max_repetitions_limit`` times in the text.

    Windows are counted by their token tuple (not by ``hash(tuple)``), so two
    different windows can never be merged by a hash collision.

    Args:
        text: Input text.
        window_size: Number of tokens per sliding window.
        return_text: When True, also return the most repeated window as text.
        max_repetitions_limit: Minimum occurrence count for a window to be
            treated as repeated.

    Returns:
        If return_text is False: (repetition_ratio, repetition_count)
        If return_text is True: (repetition_ratio, repetition_count, repeated_text)

        repetition_ratio: repeated window positions / total window positions.
        repetition_count: number of repeated window positions.
        repeated_text: space-joined tokens of the window that first reached
            the highest occurrence count ("" when there are too few tokens).
    """
    # Split text by both space and underscore
    words = []
    for segment in text.split():
        words.extend(segment.split('_'))

    if len(words) <= window_size:
        if return_text:
            return 0.0, 0, ""
        return 0.0, 0

    total_windows = len(words) - window_size + 1

    window_counts = {}
    max_repetitions = 0
    max_repeated_window = None
    for i in range(total_windows):
        window = tuple(words[i:i + window_size])
        count = window_counts.get(window, 0) + 1
        window_counts[window] = count

        # Strict ">" keeps the window that reached the running maximum first.
        if count > max_repetitions:
            max_repetitions = count
            max_repeated_window = window

    # A window occurring c times accounts for c window positions, so summing
    # the qualifying counts equals re-scanning every position.
    repetition_count = sum(
        count for count in window_counts.values()
        if count >= max_repetitions_limit
    )
    repetition_ratio = repetition_count / total_windows if total_windows > 0 else 0.0

    if return_text:
        repeated_text = " ".join(
            max_repeated_window) if max_repeated_window else ""
        return repetition_ratio, repetition_count, repeated_text

    return repetition_ratio, repetition_count


def compute_distinct_ngrams(text, n=3):
    """
    Count the distinct n-grams of a text as a measure of exploration.

    Tokens are obtained by splitting on whitespace and then on underscores;
    each n-gram is represented in the set by its ``hash()`` value.

    Args:
        text: Input text.
        n: n-gram size (e.g. 3 for 3-grams).

    Returns:
        (distinct_count, total_count, distinct_ratio)
        distinct_count: number of distinct n-grams.
        total_count: total number of n-grams.
        distinct_ratio: distinct_count / total_count.
        A text with fewer than ``n`` tokens yields (1, 1, 1.0).
    """
    tokens = [piece for chunk in text.split() for piece in chunk.split('_')]

    if len(tokens) < n:
        return 1, 1, 1.0

    window_total = len(tokens) - n + 1
    seen_hashes = {
        hash(tuple(tokens[start:start + n]))
        for start in range(window_total)
    }

    unique_total = len(seen_hashes)
    if window_total > 0:
        ratio = unique_total / window_total
    else:
        ratio = 0.0

    return unique_total, window_total, ratio


def compute_corpus_ngram_hashes(texts, n=3):
    """
    Gather n-gram statistics over a whole group of texts (one step, or all
    steps together) for the corpus-level distinct ratio.

    Distinct n-grams are represented by their ``hash()`` value, the same
    representation compute_distinct_ngrams uses.

    Args:
        texts: List of texts.
        n: n-gram size.

    Returns:
        (distinct_hashes, total_ngrams)
        distinct_hashes: set of hashes of the distinct n-grams in the corpus.
        total_ngrams: total number of n-grams in the corpus, duplicates included.
    """
    seen_hashes = set()
    if not texts:
        return seen_hashes, 0

    ngram_total = 0
    for doc in texts:
        doc_ngrams = extract_ngrams_from_text(doc, n)
        ngram_total += len(doc_ngrams)
        seen_hashes.update(hash(gram) for gram in doc_ngrams)

    return seen_hashes, ngram_total


def compute_corpus_ngram_texts(texts, n=3):
    """
    Collect the distinct n-grams of a group of texts as token tuples
    (the textual form, suitable for saving to a file).

    Args:
        texts: List of texts.
        n: n-gram size.

    Returns:
        (distinct_ngrams, total_ngrams)
        distinct_ngrams: set of the distinct n-grams in the corpus, each a tuple.
        total_ngrams: total number of n-grams in the corpus, duplicates included.
    """
    unique_ngrams = set()
    if not texts:
        return unique_ngrams, 0

    ngram_total = 0
    for doc in texts:
        doc_ngrams = extract_ngrams_from_text(doc, n)
        ngram_total += len(doc_ngrams)
        unique_ngrams.update(doc_ngrams)

    return unique_ngrams, ngram_total


def extract_ngrams_from_text(text, n=3):
    """
    Extract every n-gram of a text, in order of appearance.

    Tokens are obtained by splitting on whitespace and then on underscores.

    Args:
        text: Input text.
        n: n-gram size.

    Returns:
        list of tuples: all n-grams (empty when the text has fewer than n tokens).
    """
    tokens = [piece for chunk in text.split() for piece in chunk.split('_')]

    if len(tokens) < n:
        return []

    return [
        tuple(tokens[start:start + n])
        for start in range(len(tokens) - n + 1)
    ]


def compute_ngram_distribution(texts, n=3):
    """
    Build the n-gram frequency distribution of a collection of texts.

    Args:
        texts: List of texts.
        n: n-gram size.

    Returns:
        Counter: occurrence count of every n-gram across all texts.
    """
    frequencies = Counter()
    for doc in texts:
        frequencies.update(extract_ngrams_from_text(doc, n))
    return frequencies


def compute_distribution_entropy(dist):
    """
    Compute the entropy of a single frequency distribution.

    Args:
        dist: Counter-like mapping of item -> count.

    Returns:
        float: entropy from ``scipy.stats.entropy`` called without a ``base``
        argument; 0.0 for an empty distribution or one whose counts sum to 0.
    """
    if len(dist) == 0:
        return 0.0

    counts = list(dist.values())
    mass = sum(counts)
    if mass == 0:
        return 0.0

    # Higher entropy means a more uniform distribution.
    return entropy(np.array([count / mass for count in counts]))


def compute_distribution_divergence(dist1, dist2, method='js', return_entropies=False):
    """
    Compute a divergence between two frequency distributions.

    Both distributions are aligned over the union of their keys, turned into
    probability vectors, smoothed with a small epsilon and renormalized before
    the divergence is computed. The entropies are computed from the unsmoothed
    non-zero probabilities.

    Args:
        dist1: Counter, first distribution.
        dist2: Counter, second distribution.
        method: 'js' uses ``scipy.spatial.distance.jensenshannon``;
            'kl' uses ``scipy.stats.entropy(p, q)``.
        return_entropies: Whether to also return the entropy of each distribution.

    Returns:
        If return_entropies is False: the divergence value.
        If return_entropies is True: (divergence, entropy1, entropy2).
        When both distributions are empty, or either has a zero total count,
        all returned values are 0.0.

    Raises:
        ValueError: if ``method`` is neither 'js' nor 'kl' (only reached for
            non-empty inputs).
    """
    def _shape_result(value, first_entropy=0.0, second_entropy=0.0):
        # Return either the bare value or the 3-tuple, as requested.
        if return_entropies:
            return value, first_entropy, second_entropy
        return value

    support = set(dist1.keys()) | set(dist2.keys())
    if len(support) == 0:
        return _shape_result(0.0)

    mass1 = sum(dist1.values())
    mass2 = sum(dist2.values())
    if mass1 == 0 or mass2 == 0:
        return _shape_result(0.0)

    # One fixed ordering of the support keeps both vectors aligned.
    ordered_keys = list(support)
    p = np.array([dist1.get(key, 0) / mass1 for key in ordered_keys])
    q = np.array([dist2.get(key, 0) / mass2 for key in ordered_keys])

    # Entropies use the raw probabilities, before smoothing.
    entropy1 = entropy(p[p > 0])
    entropy2 = entropy(q[q > 0])

    # Additive smoothing avoids zero probabilities, then renormalize.
    epsilon = 1e-10
    p = p + epsilon
    q = q + epsilon
    p = p / p.sum()
    q = q / q.sum()

    if method == 'js':
        # NOTE(review): scipy documents `jensenshannon` as returning the
        # Jensen-Shannon *distance* (square root of the divergence) - confirm
        # that callers expect that scale.
        divergence = jensenshannon(p, q)
    elif method == 'kl':
        # KL divergence of p from q.
        divergence = entropy(p, q)
    else:
        raise ValueError(f"Unknown method: {method}")

    return _shape_result(divergence, entropy1, entropy2)


def compare_sample_distributions(responses1, responses2, n=3, method='js', return_entropies=False):
    """
    Compare the n-gram distributions of two sample sets.

    Args:
        responses1: First sample set (list of texts).
        responses2: Second sample set (list of texts).
        n: n-gram size.
        method: Divergence method, forwarded to compute_distribution_divergence.
        return_entropies: Whether to also return the entropy of each distribution.

    Returns:
        If return_entropies is False: the divergence value.
        If return_entropies is True: (divergence, entropy1, entropy2).
    """
    first_distribution = compute_ngram_distribution(responses1, n)
    second_distribution = compute_ngram_distribution(responses2, n)

    if not return_entropies:
        return compute_distribution_divergence(
            first_distribution, second_distribution, method,
            return_entropies=False)

    divergence, first_entropy, second_entropy = compute_distribution_divergence(
        first_distribution, second_distribution, method,
        return_entropies=True)
    return divergence, first_entropy, second_entropy


def _normalize_formula_text(formula):
    """Collapse every whitespace run to one space and trim the ends, so that
    line-break and spacing differences do not make equal formulas look distinct."""
    collapsed = re.sub(r'\s+', ' ', formula)
    return collapsed.strip()


def extract_formulas(text, unique=False):
    """
    Extract the formulas contained in a single response.

    Four delimiter styles are searched one after another: ``\\[...\\]``,
    ``\\(...\\)``, ``$$...$$`` and single ``$...$``. Results are grouped by
    delimiter style in that order, and by position within each style. Each
    formula is whitespace-normalized; empty results are dropped.

    Args:
        text: Response text; a falsy value yields an empty list.
        unique: When True, drop duplicates while keeping first occurrences.

    Returns:
        list of str: the normalized formulas.
    """
    if not text:
        return []

    delimiter_patterns = (
        r'\\\[(.*?)\\\]',
        r'\\\((.*?)\\\)',
        r'\$\$(.*?)\$\$',
        r'(?<!\$)\$(?!\$)(.*?)(?<!\$)\$(?!\$)',
    )

    found = []
    for delimiter_pattern in delimiter_patterns:
        cleaned = (
            _normalize_formula_text(raw)
            for raw in re.findall(delimiter_pattern, text, flags=re.DOTALL)
        )
        found.extend(item for item in cleaned if item)

    if unique:
        # dict preserves insertion order, so first occurrences win.
        return list(dict.fromkeys(found))
    return found


def _split_text_into_internal_units(text, div_len=None, min_chunk_tokens=12, max_chunks=4):
    """
    Cut a single response into internal units for within-response diversity.

    The text (optionally truncated to its first ``div_len`` characters) is
    first split on newlines and on whitespace following sentence-ending
    punctuation; units with fewer than two whitespace tokens are dropped. If
    that leaves fewer than two units, the text is split by token count
    instead: in two halves when it has fewer than ``2 * min_chunk_tokens``
    tokens, otherwise into at most ``max_chunks`` roughly equal chunks.

    Returns:
        list of str: the units ([] for empty text, a single-element list when
        the text has fewer than two tokens).
    """
    if not text:
        return []

    clipped = text if div_len is None else text[:div_len]
    clipped = clipped.strip()
    if not clipped:
        return []

    sentence_units = []
    for piece in re.split(r'(?:\n+|(?<=[。！？.!?])\s+)', clipped):
        if not piece or not piece.strip():
            continue
        compact = re.sub(r'\s+', ' ', piece).strip()
        if len(compact.split()) >= 2:
            sentence_units.append(compact)

    if len(sentence_units) >= 2:
        return sentence_units

    # Fallback: split evenly by whitespace tokens.
    tokens = clipped.split()
    token_total = len(tokens)
    if token_total < 2:
        return [clipped]

    if token_total < min_chunk_tokens * 2:
        half = max(1, token_total // 2)
        head = " ".join(tokens[:half]).strip()
        tail = " ".join(tokens[half:]).strip()
        return [head, tail]

    wanted_chunks = max(2, int(np.ceil(token_total / min_chunk_tokens)))
    chunk_total = min(max_chunks, wanted_chunks)
    stride = int(np.ceil(token_total / chunk_total))
    joined_chunks = (
        " ".join(tokens[low:low + stride]).strip()
        for low in range(0, token_total, stride)
    )
    return [chunk for chunk in joined_chunks if chunk]


def _fallback_sentence_bleu(reference_tokens, candidate_tokens, weights):
    """
    Lightweight BLEU approximation used when nltk is unavailable.

    For every n-gram order with a positive weight (order = position in
    ``weights``, starting at 1) the clipped n-gram precision of the candidate
    against the reference is computed with a tiny additive smoothing. The
    weighted mean of the log precisions is exponentiated and multiplied by a
    brevity penalty.

    Returns:
        float: the score; 0.0 if either token list is empty or no order
        contributed.
    """
    if not reference_tokens or not candidate_tokens:
        return 0.0

    def _count_ngrams(tokens, order):
        last_start = len(tokens) - order
        if last_start < 0:
            return Counter()
        return Counter(
            tuple(tokens[start:start + order])
            for start in range(last_start + 1)
        )

    log_precision_sum = 0.0
    weight_used = 0.0
    for order, weight in enumerate(weights, start=1):
        if weight <= 0:
            continue
        hypothesis_counts = _count_ngrams(candidate_tokens, order)
        if not hypothesis_counts:
            continue
        reference_counts = _count_ngrams(reference_tokens, order)

        # Counter intersection keeps the minimum count per n-gram (clipping).
        matched = sum((hypothesis_counts & reference_counts).values())

        # Tiny smoothing so long texts do not collapse to zero on
        # high-order n-grams.
        precision = (matched + 1e-9) / (sum(hypothesis_counts.values()) + 1e-9)
        log_precision_sum += weight * np.log(max(precision, 1e-12))
        weight_used += weight

    if weight_used <= 0:
        return 0.0

    log_precision_sum /= weight_used
    ref_len = len(reference_tokens)
    cand_len = len(candidate_tokens)
    if cand_len > ref_len:
        brevity_penalty = 1.0
    else:
        brevity_penalty = np.exp(1 - ref_len / max(cand_len, 1))
    return float(brevity_penalty * np.exp(log_precision_sum))


def _compute_bleu_similarity(reference_text, candidate_text, weights):
    """
    BLEU similarity of ``candidate_text`` against ``reference_text``.

    Both texts are split on whitespace and the weights are normalized to sum
    to 1. nltk's ``sentence_bleu`` (with ``SmoothingFunction().method1``) is
    used when it was imported successfully; otherwise the local fallback is
    used.

    Returns:
        float: the score; 0.0 if either text has no tokens or the weights do
        not sum to a positive value.
    """
    ref_tokens = reference_text.split()
    hyp_tokens = candidate_text.split()
    if not (ref_tokens and hyp_tokens):
        return 0.0

    total_weight = sum(weights)
    if total_weight <= 0:
        return 0.0
    scaled_weights = tuple(weight / total_weight for weight in weights)

    nltk_missing = sentence_bleu is None or SmoothingFunction is None
    if nltk_missing:
        return _fallback_sentence_bleu(ref_tokens, hyp_tokens, scaled_weights)

    smoother = SmoothingFunction().method1
    score = sentence_bleu(
        [ref_tokens],
        hyp_tokens,
        weights=scaled_weights,
        smoothing_function=smoother,
    )
    return float(score)


def calculate_internal_textual_diversity(text, div_len=None, weights=(0.02, 0.1, 0.15, 0.25, 0.38)):
    """
    Textual diversity inside a single response (a within-response variant of
    the group-level TD metric).

    Steps:
    1. Split the response into internal units.
    2. For every pair of units, average the BLEU similarity in both directions.
    3. Report 1 - similarity, averaged over all pairs, as the diversity.

    Args:
        text: Response text.
        div_len: Optional character limit forwarded to the unit splitter.
        weights: BLEU n-gram weights.

    Returns:
        (mean_dissimilarity, mean_similarity, unit_count); both means are 0.0
        when fewer than two units are available.
    """
    units = _split_text_into_internal_units(text, div_len=div_len)
    unit_total = len(units)
    if unit_total < 2:
        return 0.0, 0.0, unit_total

    similarities = []
    for position, first in enumerate(units):
        for second in units[position + 1:]:
            forward = _compute_bleu_similarity(first, second, weights)
            backward = _compute_bleu_similarity(second, first, weights)
            similarities.append((forward + backward) / 2.0)

    dissimilarities = [1.0 - value for value in similarities]
    if not dissimilarities:
        return 0.0, 0.0, unit_total

    mean_dissimilarity = float(np.mean(dissimilarities))
    mean_similarity = float(np.mean(similarities))
    return mean_dissimilarity, mean_similarity, unit_total


def calculate_internal_equational_diversity(text, div_len=None):
    """
    Formula diversity inside a single response (a within-response variant of
    the group-level ED metric).

    Defined as: unique formulas / total formulas in the response (optionally
    restricted to the first ``div_len`` characters).

    Returns:
        (diversity, unique_formula_count, total_formula_count);
        (0.0, 0, 0) when the response contains no formulas.
    """
    scoped_text = text[:div_len] if div_len is not None else text
    found = extract_formulas(scoped_text, unique=False)
    total = len(found)
    if total == 0:
        return 0.0, 0, 0

    distinct = len(set(found))
    return float(distinct / total), distinct, total


def save_response_metrics(response_details, output_path):
    """
    Write per-response metric records to ``output_path`` as jsonl.

    Nothing is written (only a warning is printed) when ``response_details``
    is empty. The parent directory is created when needed.
    """
    if not response_details:
        print(f"警告: 没有 response 明细可保存: {output_path}")
        return

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as sink:
        sink.writelines(
            json.dumps(item, ensure_ascii=False) + '\n'
            for item in response_details
        )

    print(f"已保存 {len(response_details)} 条 response 明细到: {output_path}")


def _mean_or_zero(values):
    """Return ``np.mean(values)``, or 0.0 when ``values`` is empty."""
    return np.mean(values) if values else 0.0


def _load_responses_and_scores(file_path, verbose):
    """Read parallel (responses, scores) lists from a jsonl file.

    Returns None (after printing an error) when the file does not exist.
    Blank lines, undecodable lines, lines that are not JSON objects and
    records without a usable response are skipped.
    """
    responses = []
    scores = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank line is not a record and not a parse error.
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    if verbose:
                        print(f"解析JSON失败: {e}")
                    continue

                # Valid JSON that is not an object (list, number, ...) has no
                # fields to read; skip it instead of crashing on ``.get``.
                if not isinstance(data, dict):
                    if verbose:
                        print(f"跳过非字典JSON行: {type(data).__name__}")
                    continue

                # NOTE: data_source filtering (ALLOWED_DATA_SOURCES) is
                # currently disabled; every record is considered.
                resp = data.get('output') or data.get('response')
                if resp is None:
                    continue
                responses.append(resp)
                # Kept parallel to ``responses``; may be None when absent.
                scores.append(data.get('score'))
    except FileNotFoundError:
        print(f"错误: 文件未找到: {file_path}")
        return None
    return responses, scores


def analyze_repetition(
    file_path,
    window_size=10,
    max_repetitions_limit=6,
    verbose=True,
    show_examples=False,
    compute_ngrams=False,
    ngram_sizes=None,
    tokenizer=None,
    compute_sentence_diversity=False,
    sentence_diversity_max_len=None,
    return_response_details=False,
):
    """
    Analyze the responses of one jsonl file: repetition ratio, accuracy and
    optional token-length, distinct n-gram and within-response diversity
    metrics.

    Each line is expected to be a JSON object whose response is stored under
    'output' (preferred) or 'response' and whose score is stored under
    'score'. No data_source filtering is applied.

    Args:
        file_path: Path of the jsonl file.
        window_size: Window size forwarded to detect_repetition_with_hash.
        max_repetitions_limit: Repetition threshold forwarded to
            detect_repetition_with_hash.
        verbose: Print progress and summary lines.
        show_examples: Collect (and print the top 5) responses whose
            repetition ratio exceeds 5%.
        compute_ngrams: Compute per-response distinct n-gram statistics.
        ngram_sizes: n-gram sizes to use; defaults to [3, 5].
        tokenizer: Optional callable tokenizer; called once on the full list
            of responses to obtain token lengths from ``['input_ids']``.
        compute_sentence_diversity: Compute within-response textual and
            equational diversity.
        sentence_diversity_max_len: Optional character limit for the
            diversity metrics.
        return_response_details: Include a per-response record list under
            'response_details'.

    Returns:
        dict with 'total_responses', 'repetition_rate' (mean repetition
        ratio), 'repetition_examples', 'avg_token_length' and 'accuracy'
        (mean of the non-None scores), plus 'distinct_{n}gram_count' /
        'distinct_{n}gram_ratio' when compute_ngrams is set, the
        'internal_*' / 'formula_*' averages when compute_sentence_diversity
        is set, and 'response_details' when requested.
        None when the file does not exist.
    """
    # Avoid a shared mutable default argument.
    ngram_sizes = [3, 5] if ngram_sizes is None else ngram_sizes

    if verbose:
        print(f"正在读取文件: {file_path}")
    loaded = _load_responses_and_scores(file_path, verbose)
    if loaded is None:
        return None
    responses, scores = loaded

    if verbose:
        print(f"  总共读取了 {len(responses)} 个response")

    repetition_ratio_list = []  # repetition ratio of every response
    repetition_examples = []  # examples of heavily repeated responses
    token_lengths = []  # token length of every response
    response_details = []  # per-response detail records
    internal_textual_diversity_list = []
    internal_textual_similarity_list = []
    internal_unit_count_list = []
    internal_equational_diversity_list = []
    formula_unique_count_list = []
    formula_total_count_list = []

    # Token lengths are computed in one batched tokenizer call.
    if tokenizer and responses:
        if verbose:
            print(f"  正在批量计算 {len(responses)} 个 response 的 token 长度...")
        # No padding / truncation: only the lengths are needed.
        batch_encodings = tokenizer(
            responses, add_special_tokens=False, padding=False, truncation=False)
        token_lengths = [len(ids) for ids in batch_encodings['input_ids']]

    ngram_distinct_counts = {n: [] for n in ngram_sizes}
    ngram_distinct_ratios = {n: [] for n in ngram_sizes}

    for i, response in enumerate(responses):
        detection = detect_repetition_with_hash(
            response,
            window_size=window_size,
            return_text=show_examples,
            max_repetitions_limit=max_repetitions_limit
        )
        repetition_ratio, repetition_count = detection[0], detection[1]
        if show_examples and repetition_ratio > 0.05:
            # Only responses with more than 5% repetition become examples.
            repetition_examples.append({
                'index': i,
                'ratio': repetition_ratio,
                'count': repetition_count,
                'repeated_text': detection[2],
                'response_preview': response[:300] + '...' if len(response) > 300 else response
            })
        repetition_ratio_list.append(repetition_ratio)

        if compute_ngrams:
            for n in ngram_sizes:
                distinct_count, _, distinct_ratio = compute_distinct_ngrams(
                    response, n=n)
                ngram_distinct_counts[n].append(distinct_count)
                ngram_distinct_ratios[n].append(distinct_ratio)

        current_internal_textual_diversity = 0.0
        current_internal_textual_similarity = 0.0
        current_internal_unit_count = 0
        current_internal_equational_diversity = 0.0
        current_formula_unique_count = 0
        current_formula_total_count = 0

        if compute_sentence_diversity:
            (current_internal_textual_diversity,
             current_internal_textual_similarity,
             current_internal_unit_count) = calculate_internal_textual_diversity(
                response,
                div_len=sentence_diversity_max_len,
            )
            (current_internal_equational_diversity,
             current_formula_unique_count,
             current_formula_total_count) = calculate_internal_equational_diversity(
                response,
                div_len=sentence_diversity_max_len,
            )

            internal_textual_diversity_list.append(
                current_internal_textual_diversity)
            internal_textual_similarity_list.append(
                current_internal_textual_similarity)
            internal_unit_count_list.append(current_internal_unit_count)
            internal_equational_diversity_list.append(
                current_internal_equational_diversity)
            formula_unique_count_list.append(current_formula_unique_count)
            formula_total_count_list.append(current_formula_total_count)

        if return_response_details:
            detail = {
                'index': i,
                'score': scores[i] if i < len(scores) else None,
                'response': response,
                'repetition_ratio': repetition_ratio,
                'response_preview': response[:300] + '...' if len(response) > 300 else response,
            }
            if token_lengths:
                detail['token_length'] = token_lengths[i]
            if compute_sentence_diversity:
                detail.update({
                    'internal_textual_diversity': current_internal_textual_diversity,
                    'internal_textual_similarity': current_internal_textual_similarity,
                    'internal_unit_count': current_internal_unit_count,
                    'internal_equational_diversity': current_internal_equational_diversity,
                    'formula_unique_count': current_formula_unique_count,
                    'formula_total_count': current_formula_total_count,
                })
            if compute_ngrams:
                for n in ngram_sizes:
                    if i < len(ngram_distinct_counts[n]):
                        detail[f'distinct_{n}gram_count'] = ngram_distinct_counts[n][i]
                    if i < len(ngram_distinct_ratios[n]):
                        detail[f'distinct_{n}gram_ratio'] = ngram_distinct_ratios[n][i]
            response_details.append(detail)

    # Aggregate statistics (each is 0.0 when nothing was collected).
    total_responses = len(repetition_ratio_list)
    avg_repetition_ratio = _mean_or_zero(repetition_ratio_list)
    avg_token_length = _mean_or_zero(token_lengths)
    avg_internal_textual_diversity = _mean_or_zero(
        internal_textual_diversity_list)
    avg_internal_textual_similarity = _mean_or_zero(
        internal_textual_similarity_list)
    avg_internal_unit_count = _mean_or_zero(internal_unit_count_list)
    avg_internal_equational_diversity = _mean_or_zero(
        internal_equational_diversity_list)
    avg_formula_unique_count = _mean_or_zero(formula_unique_count_list)
    avg_formula_total_count = _mean_or_zero(formula_total_count_list)

    # Accuracy: mean over the records that actually carry a score.
    valid_scores = [score for score in scores if score is not None]
    macro_acc = _mean_or_zero(valid_scores)

    if verbose:
        print(
            f"  平均重复比例: {avg_repetition_ratio:.4f} ({avg_repetition_ratio*100:.2f}%)")
        if tokenizer:
            print(f"  平均Token长度: {avg_token_length:.2f}")
        if compute_sentence_diversity:
            print(
                f"  平均句内文本多样性 (intra-TD): {avg_internal_textual_diversity:.4f}")
            print(
                f"  平均句内文本相似度 (intra-BLEU): {avg_internal_textual_similarity:.4f}")
            print(f"  平均句内切分单元数: {avg_internal_unit_count:.2f}")
            print(
                f"  平均句内公式多样性 (intra-ED): {avg_internal_equational_diversity:.4f}")
            print(
                f"  平均句内唯一公式数/总公式数: {avg_formula_unique_count:.2f}/{avg_formula_total_count:.2f}")
        print(f"  Accuracy (Macro Avg): {macro_acc:.4f}")

    # Per-size n-gram averages.
    avg_ngram_distinct_counts = {}
    avg_ngram_distinct_ratios = {}
    if compute_ngrams:
        for n in ngram_sizes:
            avg_count = _mean_or_zero(ngram_distinct_counts[n])
            avg_ratio = _mean_or_zero(ngram_distinct_ratios[n])
            avg_ngram_distinct_counts[n] = avg_count
            avg_ngram_distinct_ratios[n] = avg_ratio
            if verbose:
                print(
                    f"  平均{n}-gram distinct个数: {avg_count:.2f}, 比例: {avg_ratio:.4f}")

    # Show the most repetitive examples (printed regardless of ``verbose``).
    if show_examples and repetition_examples:
        print(f"\n  重复文本示例 (前5个，按重复比例降序):")
        print("  " + "-"*60)
        # Sorted in place, so the returned list is ordered by ratio as well.
        repetition_examples.sort(key=lambda x: x['ratio'], reverse=True)
        for example in repetition_examples[:5]:
            print(
                f"  样例 #{example['index']}: 重复比例={example['ratio']:.2%}, 重复窗口数={example['count']}")
            print(f"  重复片段: \"{example['repeated_text']}\"")
            print("  " + "-"*60)

    result = {
        'total_responses': total_responses,
        # Key name kept for compatibility; the value is the mean ratio.
        'repetition_rate': avg_repetition_ratio,
        'repetition_examples': repetition_examples if show_examples else [],
        'avg_token_length': avg_token_length,
        'accuracy': macro_acc
    }

    if compute_ngrams:
        for n, avg_count in avg_ngram_distinct_counts.items():
            result[f'distinct_{n}gram_count'] = avg_count
            result[f'distinct_{n}gram_ratio'] = avg_ngram_distinct_ratios[n]

    if compute_sentence_diversity:
        result['internal_textual_diversity'] = avg_internal_textual_diversity
        result['internal_textual_similarity'] = avg_internal_textual_similarity
        result['internal_unit_count'] = avg_internal_unit_count
        result['internal_equational_diversity'] = avg_internal_equational_diversity
        result['formula_unique_count'] = avg_formula_unique_count
        result['formula_total_count'] = avg_formula_total_count

    if return_response_details:
        result['response_details'] = response_details

    return result


def read_responses_from_file(file_path):
    """
    Read all responses from a jsonl file.

    A record's response is taken from 'output', falling back to 'response'
    when 'output' is missing or empty; records without a usable response are
    skipped. This is the same selection rule analyze_repetition applies, so
    both readers yield the same responses in the same order for a given file.
    Blank lines, undecodable lines and lines that are not JSON objects are
    skipped.

    Args:
        file_path: Path of the jsonl file.

    Returns:
        list: the responses; an empty list when the file does not exist.
    """
    responses = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Non-object JSON (list, number, string) has no fields; using
                # ``in`` / indexing on it could raise or match a substring.
                if not isinstance(data, dict):
                    continue
                # NOTE: data_source filtering (ALLOWED_DATA_SOURCES) is
                # currently disabled; every record is considered.
                resp = data.get('output') or data.get('response')
                if resp is None:
                    continue
                responses.append(resp)
    except FileNotFoundError:
        return []
    return responses


def _collect_step_files(directory, pattern="*_16384.jsonl", verbose=False):
    """
    Find the files in ``directory`` matching ``pattern`` and pair each with
    the step number parsed from its file name.

    The step is the integer before the first underscore of the file name;
    files whose prefix is not an integer are skipped (with a warning when
    ``verbose`` is set).

    Returns:
        list of (step, file_path, filename) tuples sorted by step.
    """
    collected = []
    for file_path in glob.glob(os.path.join(directory, pattern)):
        filename = os.path.basename(file_path)
        prefix = filename.split('_')[0]
        try:
            step = int(prefix)
        except ValueError:
            if verbose:
                print(f"警告: 无法从文件名 {filename} 提取step编号，跳过")
            continue
        collected.append((step, file_path, filename))

    return sorted(collected, key=lambda entry: entry[0])


def _get_response_metrics_path(response_metrics_dir, step, filename):
    """Build the per-response metrics path for one step file, or return None
    when no metrics directory is configured."""
    if response_metrics_dir:
        stem = os.path.splitext(filename)[0]
        metrics_name = f"{step}_{stem}_response_metrics.jsonl"
        return os.path.join(response_metrics_dir, metrics_name)
    return None


def _get_cached_tokenizer(tokenizer_path):
    """
    Return the tokenizer for ``tokenizer_path``, loading it at most once per
    path through the module-level cache variables.

    Returns None when no path is given.

    Raises:
        ImportError: if a path is given but ``AutoTokenizer`` could not be
            imported.
    """
    global _WORKER_TOKENIZER, _WORKER_TOKENIZER_PATH

    if not tokenizer_path:
        return None
    if AutoTokenizer is None:
        raise ImportError(
            "transformers is required when tokenizer_path is provided")

    # Reload when nothing is cached yet or a different path is requested.
    cache_is_stale = (
        _WORKER_TOKENIZER is None or _WORKER_TOKENIZER_PATH != tokenizer_path
    )
    if cache_is_stale:
        _WORKER_TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_path)
        _WORKER_TOKENIZER_PATH = tokenizer_path
    return _WORKER_TOKENIZER


def _compute_token_lengths(texts, tokenizer, batch_size=256):
    """
    Compute the token length of every text, tokenizing ``batch_size`` texts
    per tokenizer call.

    The tokenizer is called without special tokens, padding or truncation and
    must return a mapping whose 'input_ids' entry holds one id sequence per
    text.

    Returns:
        list of int: one length per text; [] when the tokenizer or the texts
        are missing/empty.
    """
    if not (tokenizer and texts):
        return []

    lengths = []
    for begin in range(0, len(texts), batch_size):
        encoded = tokenizer(
            texts[begin:begin + batch_size],
            add_special_tokens=False,
            padding=False,
            truncation=False,
        )
        lengths += [len(ids) for ids in encoded['input_ids']]
    return lengths


def _apply_token_lengths_after_parallel(step_files, step_results, tokenizer, response_metrics_dir=None):
    """
    Back-fill token-length information after the per-step analysis has run.

    For every step file the responses are re-read and tokenized; the mean
    length is stored as 'avg_token_length' in the matching entry of
    ``step_results`` (modified in place), and, when a per-response metrics
    file exists for the step, a 'token_length' field is added to each of its
    rows and the file is rewritten.

    The metrics-file update is best-effort: an unreadable file, or one whose
    row count differs from the number of responses, is skipped with a warning
    rather than aborting the run.

    Args:
        step_files: Iterable of (step, file_path, filename) tuples.
        step_results: List of per-step result dicts, each with a 'step' key.
        tokenizer: Tokenizer passed to _compute_token_lengths; nothing is done
            when it is falsy.
        response_metrics_dir: Directory of the per-response metrics files, or
            None to skip them.
    """
    if not tokenizer or not step_results:
        return

    step_to_result = {int(item['step']): item for item in step_results}
    for step, file_path, filename in step_files:
        responses = read_responses_from_file(file_path)
        token_lengths = _compute_token_lengths(responses, tokenizer)
        avg_token_length = float(
            np.mean(token_lengths)) if token_lengths else 0.0

        if step in step_to_result:
            step_to_result[step]['avg_token_length'] = avg_token_length

        metrics_path = _get_response_metrics_path(
            response_metrics_dir, step, filename)
        if not metrics_path or not os.path.exists(metrics_path):
            continue

        try:
            with open(metrics_path, 'r', encoding='utf-8') as handle:
                rows = [json.loads(line) for line in handle if line.strip()]
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"警告: 无法读取 response 明细文件，跳过 token_length 回填: {metrics_path} ({exc})")
            continue

        # Rows and lengths are matched by position, so the counts must agree.
        if len(rows) != len(token_lengths):
            print(
                f"警告: response 明细行数({len(rows)})与 token 长度数({len(token_lengths)})不一致，"
                f"跳过 token_length 回填: {metrics_path}")
            continue

        for row, token_length in zip(rows, token_lengths):
            row['token_length'] = token_length

        with open(metrics_path, 'w', encoding='utf-8') as handle:
            for row in rows:
                handle.write(json.dumps(row, ensure_ascii=False) + '\n')


def _build_step_result(
    step,
    filename,
    result,
    compute_sentence_diversity,
    compute_ngrams,
    ngram_sizes,
    responses_by_step,
    global_distinct_counts,
    global_total_ngram_counts,
):
    """
    Assemble the summary row of one step from its analysis result.

    The row always carries the step, file name, repetition rate, response
    count, average token length and accuracy. With compute_sentence_diversity
    the within-response diversity averages are added (0.0 when missing from
    ``result``). With compute_ngrams, for every n in ``ngram_sizes`` the
    per-response averages from ``result`` are copied and step-level / global
    n-gram counts are added, where the step-level ratio divides the number of
    distinct n-grams of the step by the global total n-gram count.

    Returns:
        dict: the summary row.
    """
    summary = {'step': step, 'file': filename}
    for key in ('repetition_rate', 'total_responses', 'avg_token_length', 'accuracy'):
        summary[key] = result[key]

    if compute_sentence_diversity:
        diversity_keys = (
            'internal_textual_diversity',
            'internal_textual_similarity',
            'internal_unit_count',
            'internal_equational_diversity',
            'formula_unique_count',
            'formula_total_count',
        )
        for key in diversity_keys:
            summary[key] = result.get(key, 0.0)

    if not compute_ngrams:
        return summary

    step_texts = responses_by_step.get(step, [])
    for n in ngram_sizes:
        for suffix in ('count', 'ratio'):
            key = f'distinct_{n}gram_{suffix}'
            summary[key] = result.get(key)

        step_hashes, step_total = compute_corpus_ngram_hashes(step_texts, n=n)
        distinct_in_step = len(step_hashes)
        corpus_total = global_total_ngram_counts.get(n, 0)

        summary[f'step_global_distinct_{n}gram_count'] = distinct_in_step
        summary[f'global_distinct_{n}gram_count'] = global_distinct_counts.get(n, 0)
        if corpus_total > 0:
            summary[f'step_global_distinct_{n}gram_ratio'] = distinct_in_step / corpus_total
        else:
            summary[f'step_global_distinct_{n}gram_ratio'] = 0.0
        summary[f'step_total_{n}gram_count'] = step_total
        summary[f'global_total_{n}gram_count'] = corpus_total

    return summary


def _analyze_step_file_task(
    step,
    file_path,
    filename,
    window_size,
    max_repetitions_limit,
    show_examples,
    compute_ngrams,
    ngram_sizes,
    tokenizer_path,
    compute_sentence_diversity,
    sentence_diversity_max_len,
    metrics_path,
):
    """Run the repetition analysis for a single step file.

    This is the callable that ``analyze_multiple_steps`` submits to its
    process pool, so every argument is positional.

    Args:
        step: Step number associated with the file (returned unchanged).
        file_path: Path of the file handed to ``analyze_repetition``.
        filename: Base name of the file (returned unchanged).
        window_size, max_repetitions_limit, show_examples, compute_ngrams,
        ngram_sizes, compute_sentence_diversity, sentence_diversity_max_len:
            Forwarded as-is to ``analyze_repetition``.
        tokenizer_path: Passed to ``_get_cached_tokenizer``; its result is
            forwarded as the ``tokenizer`` argument.
        metrics_path: When truthy, per-response details are requested and
            written to this path via ``save_response_metrics``.

    Returns:
        tuple: ``(step, filename, result)``. If per-response details were
        saved, ``result`` is a copy without the ``'response_details'`` key.
    """
    analysis_options = dict(
        verbose=False,
        show_examples=show_examples,
        compute_ngrams=compute_ngrams,
        ngram_sizes=ngram_sizes,
        tokenizer=_get_cached_tokenizer(tokenizer_path),
        compute_sentence_diversity=compute_sentence_diversity,
        sentence_diversity_max_len=sentence_diversity_max_len,
        return_response_details=bool(metrics_path),
    )
    outcome = analyze_repetition(
        file_path, window_size, max_repetitions_limit, **analysis_options)

    # Nothing to persist when the analysis failed or no metrics file is wanted.
    if not (outcome and metrics_path):
        return step, filename, outcome

    details = outcome.get('response_details')
    if details:
        save_response_metrics(details, metrics_path)
        # Work on a copy and drop the details once they are on disk.
        outcome = dict(outcome)
        outcome.pop('response_details', None)
    return step, filename, outcome


def _map_steps_to_paths(file_paths):
    """Map step number -> file path, using the integer prefix of each file name.

    The step is the text before the first underscore of the base name; files
    whose prefix is not an integer are ignored. Paths are visited in sorted
    order so that, when several files share a step prefix, the surviving entry
    does not depend on filesystem listing order.
    """
    step_to_path = {}
    for file_path in sorted(file_paths):
        prefix = os.path.basename(file_path).split('_')[0]
        try:
            step_to_path[int(prefix)] = file_path
        except ValueError:
            continue
    return step_to_path


def compare_two_directories(dir1, dir2, pattern="*_16384.jsonl", ngram_sizes=[3, 5], method='js', verbose=True):
    """
    Compare the n-gram distributions of matching step files in two directories.

    Files are paired by the integer step prefix of their file name; only steps
    present in both directories are compared.

    Args:
        dir1: Path of the first directory.
        dir2: Path of the second directory.
        pattern: Glob pattern used to select files in both directories.
        ngram_sizes: Iterable of n-gram sizes to compare (only iterated, never
            mutated).
        method: Divergence method name, forwarded to
            ``compare_sample_distributions``.
        verbose: Whether to print progress details.

    Returns:
        list | None: One result dict per compared step (keys: 'step', 'file1',
        'file2', 'num_samples1', 'num_samples2' and, for each n,
        '{n}gram_divergence', '{n}gram_entropy1', '{n}gram_entropy2').
        None when either directory has no matching file or the directories
        share no step.
    """
    # Collect all matching files from both directories.
    files1 = glob.glob(os.path.join(dir1, pattern))
    files2 = glob.glob(os.path.join(dir2, pattern))

    for directory, found in ((dir1, files1), (dir2, files2)):
        if not found:
            print(f"错误: 在 {directory} 中没有找到匹配 {pattern} 的文件")
            return None

    step_to_file1 = _map_steps_to_paths(files1)
    step_to_file2 = _map_steps_to_paths(files2)

    # Only steps available on both sides can be compared.
    common_steps = sorted(step_to_file1.keys() & step_to_file2.keys())

    if not common_steps:
        print("错误: 两个目录没有共同的 step 文件")
        return None

    if verbose:
        print(f"找到 {len(common_steps)} 个共同的 steps: {common_steps}")

    comparison_results = []
    for step in common_steps:
        file1 = step_to_file1[step]
        file2 = step_to_file2[step]
        name1 = os.path.basename(file1)
        name2 = os.path.basename(file2)

        if verbose:
            print(f"\n对比 Step {step}:")
            print(f"  文件1: {name1}")
            print(f"  文件2: {name2}")

        responses1 = read_responses_from_file(file1)
        responses2 = read_responses_from_file(file2)

        # A step with no responses on either side cannot be compared.
        if not responses1 or not responses2:
            if verbose:
                print("  警告: 某个文件为空，跳过")
            continue

        if verbose:
            print(f"  样本数: {len(responses1)} vs {len(responses2)}")

        result = {
            'step': step,
            'file1': name1,
            'file2': name2,
            'num_samples1': len(responses1),
            'num_samples2': len(responses2)
        }

        # Divergence and per-side entropy for every requested n-gram size.
        for n in ngram_sizes:
            divergence, entropy1, entropy2 = compare_sample_distributions(
                responses1, responses2, n=n, method=method, return_entropies=True
            )
            result[f'{n}gram_divergence'] = divergence
            result[f'{n}gram_entropy1'] = entropy1
            result[f'{n}gram_entropy2'] = entropy2
            if verbose:
                print(f"  {n}-gram {method.upper()} 散度: {divergence:.6f}")
                print(
                    f"  {n}-gram 熵 (Dir1): {entropy1:.6f}, 熵 (Dir2): {entropy2:.6f}")

        comparison_results.append(result)

    return comparison_results


def analyze_multiple_steps(
    directory,
    pattern="*_16384.jsonl",
    window_size=10,
    max_repetitions_limit=6,
    verbose=True,
    show_examples=False,
    compute_ngrams=False,
    ngram_sizes=[3, 5],
    tokenizer=None,
    tokenizer_path=None,
    compute_sentence_diversity=False,
    sentence_diversity_max_len=None,
    response_metrics_dir=None,
    num_workers=1,
):
    """
    Analyze the repetition metrics of every step file found in a directory.

    Files matching ``pattern`` are keyed by the integer prefix of their name
    (the text before the first underscore). Each file is analyzed with
    ``analyze_repetition`` -- sequentially, or in a process pool when
    ``num_workers > 1`` and more than one file was found -- and summarized by
    ``_build_step_result``.

    Args:
        directory: Directory that contains the step files.
        pattern: Glob pattern selecting the step files.
        window_size, max_repetitions_limit, show_examples,
        compute_sentence_diversity, sentence_diversity_max_len:
            Forwarded to ``analyze_repetition``.
        verbose: Whether to print progress information.
        compute_ngrams: When true, responses of all steps are loaded up front
            to compute n-gram statistics aggregated across steps.
        ngram_sizes: n-gram sizes to evaluate (never mutated here).
        tokenizer: Tokenizer passed to ``analyze_repetition`` in sequential
            mode. With ``num_workers > 1`` it is instead applied afterwards
            through ``_apply_token_lengths_after_parallel``.
        tokenizer_path: With ``num_workers > 1`` and no ``tokenizer``, used to
            load the tokenizer for that post-processing step.
        response_metrics_dir: Passed to ``_get_response_metrics_path`` to
            derive the per-step metrics file; per-response details are saved
            only when that path is truthy.
        num_workers: Requested worker count; ``None`` or values below 1 are
            treated as 1.

    Returns:
        None when no file matches or no step number can be parsed. Otherwise
        the list of per-step result dicts sorted by step, or -- when
        ``compute_ngrams`` is true -- a dict with keys ``'step_results'`` and
        ``'global_distinct_ngrams'``.
    """
    # Locate all matching files.
    files = glob.glob(os.path.join(directory, pattern))

    if not files:
        print(f"错误: 在 {directory} 中没有找到匹配 {pattern} 的文件")
        return None

    # Extract the step number from each file name. glob order depends on the
    # filesystem, so iterate in sorted order to keep warnings reproducible.
    step_files = []
    for file_path in sorted(files):
        filename = os.path.basename(file_path)
        try:
            step = int(filename.split('_')[0])
        except ValueError:
            if verbose:
                print(f"警告: 无法从文件名 {filename} 提取step编号，跳过")
            continue
        step_files.append((step, file_path, filename))

    if not step_files:
        print(f"错误: 在 {directory} 中未能从文件名提取任何 step")
        return None

    # Process (and log) steps in ascending order rather than listing order.
    step_files.sort()

    # --- Global n-gram statistics, aggregated across all steps ---
    # global_distinct_counts:    n -> number of distinct n-grams over all steps
    # global_total_ngram_counts: n -> total number of n-grams (duplicates included)
    # global_distinct_ngrams:    n -> distinct n-grams themselves (saved to file by callers)
    responses_by_step = {}
    global_distinct_counts = {}
    global_total_ngram_counts = {}
    global_distinct_ngrams = {}
    if compute_ngrams:
        for step, file_path, _filename in step_files:
            responses_by_step[step] = read_responses_from_file(file_path)

        all_responses = []
        for step, _file_path, _filename in step_files:
            all_responses.extend(responses_by_step.get(step, []))
        for n in ngram_sizes:
            global_hashes, global_total = compute_corpus_ngram_hashes(
                all_responses, n=n)
            global_distinct_counts[n] = len(global_hashes)
            global_total_ngram_counts[n] = global_total
            # Keep the actual distinct n-grams so they can be written out later.
            global_distinct_ngrams[n], _ = compute_corpus_ngram_texts(
                all_responses, n=n)
        if verbose:
            print("  全局 n-gram 统计（跨所有 step 聚合）:")
            for n in ngram_sizes:
                print(
                    f"    global distinct {n}-gram count = {global_distinct_counts.get(n, 0)}")
                print(
                    f"    global total    {n}-gram count = {global_total_ngram_counts.get(n, 0)}")

    worker_count = max(1, int(num_workers or 1))
    resolved_tokenizer_path = tokenizer_path or getattr(
        tokenizer, "name_or_path", None)
    # Workers are submitted without a tokenizer; with several workers the
    # token lengths are filled in afterwards by the parent process.
    deferred_tokenizer = None
    if worker_count > 1:
        if tokenizer is not None:
            deferred_tokenizer = tokenizer
        elif resolved_tokenizer_path and AutoTokenizer is not None:
            deferred_tokenizer = _get_cached_tokenizer(resolved_tokenizer_path)
        elif resolved_tokenizer_path and AutoTokenizer is None and verbose:
            print("[warn] transformers 不可用，跳过 token length 后处理。")
        if verbose and deferred_tokenizer is not None:
            print("[info] tokenizer 将在并行分析结束后统一计算 token length。")

    def _summarize(result_step, result_filename, result):
        # Single place that turns a raw analysis result into a step summary.
        return _build_step_result(
            result_step,
            result_filename,
            result,
            compute_sentence_diversity,
            compute_ngrams,
            ngram_sizes,
            responses_by_step,
            global_distinct_counts,
            global_total_ngram_counts,
        )

    # Per-step analysis: per-response averages plus the global-dimension metrics.
    step_results = []
    if worker_count > 1 and len(step_files) > 1:
        pool_size = min(worker_count, len(step_files))
        if verbose:
            print(
                f"使用并行模式分析 step 文件: workers={pool_size}")
        with ProcessPoolExecutor(max_workers=pool_size) as executor:
            future_to_meta = {}
            for step, file_path, filename in step_files:
                metrics_path = _get_response_metrics_path(
                    response_metrics_dir, step, filename)
                future = executor.submit(
                    _analyze_step_file_task,
                    step,
                    file_path,
                    filename,
                    window_size,
                    max_repetitions_limit,
                    show_examples,
                    compute_ngrams,
                    ngram_sizes,
                    None,  # tokenizer_path: tokenization is deferred (see above)
                    compute_sentence_diversity,
                    sentence_diversity_max_len,
                    metrics_path,
                )
                future_to_meta[future] = (step, filename)

            for future in as_completed(future_to_meta):
                step, filename = future_to_meta[future]
                task_step, task_filename, result = future.result()
                if result:
                    step_results.append(
                        _summarize(task_step, task_filename, result))
                if verbose:
                    print(f"[done] Step {step} ({filename})")
    else:
        for step, file_path, filename in step_files:
            result = analyze_repetition(
                file_path,
                window_size,
                max_repetitions_limit,
                verbose=verbose,
                show_examples=show_examples,
                compute_ngrams=compute_ngrams,
                ngram_sizes=ngram_sizes,
                tokenizer=tokenizer,
                compute_sentence_diversity=compute_sentence_diversity,
                sentence_diversity_max_len=sentence_diversity_max_len,
                return_response_details=bool(response_metrics_dir),
            )
            if result:
                metrics_path = _get_response_metrics_path(
                    response_metrics_dir, step, filename)
                if metrics_path and result.get('response_details'):
                    save_response_metrics(
                        result['response_details'], metrics_path)

                step_results.append(_summarize(step, filename, result))
            if verbose:
                print()

    if worker_count > 1 and deferred_tokenizer is not None:
        _apply_token_lengths_after_parallel(
            step_files=step_files,
            step_results=step_results,
            tokenizer=deferred_tokenizer,
            response_metrics_dir=response_metrics_dir,
        )

    # Sort by step; the file name breaks ties so that the order of results
    # does not depend on worker completion order.
    step_results.sort(key=lambda x: (x['step'], x['file']))

    # With n-grams enabled the distinct n-grams are returned alongside the
    # per-step results.
    if compute_ngrams:
        return {
            'step_results': step_results,
            'global_distinct_ngrams': global_distinct_ngrams
        }
    else:
        return step_results


def get_experiment_name(directory_path):
    """
    Derive an experiment name from a directory path.

    The name is the second-to-last path component (the parent directory of
    ``directory_path``). When the path has no usable parent component -- a
    bare directory name, or a directory directly under the root -- the
    directory's own name is returned instead of an empty string.

    Args:
        directory_path: Path of the directory, with or without trailing
            separators.

    Returns:
        str: The experiment name; ``""`` for an empty path.
    """
    if not directory_path:
        return ""

    # normpath collapses repeated/trailing separators and "./" segments, so
    # the parent lookup is not confused by inputs such as "exp//results/".
    normalized = os.path.normpath(directory_path)
    parent_name = os.path.basename(os.path.dirname(normalized))
    if parent_name:
        return parent_name
    return os.path.basename(normalized)


def save_distinct_ngrams(global_distinct_ngrams, output_dir, exp_name=None, ngram_sizes=[6, 10, 15], max_count=1000):
    """
    Write the distinct n-grams of each requested size to a text file.

    One file is produced per n-gram size, named ``distinct_{n}gram.txt`` or
    ``distinct_{n}gram_{exp_name}.txt``. Each file starts with a two-line
    comment header and a blank line, followed by one n-gram per line in
    sorted order.

    Args:
        global_distinct_ngrams: dict mapping n to a collection of distinct
            n-grams, each an iterable of strings.
        output_dir: Directory receiving the files (created if missing).
        exp_name: Optional experiment name appended to the file name.
        ngram_sizes: n-gram sizes to save; sizes absent from the dict or with
            an empty collection are skipped with a warning.
        max_count: Maximum number of n-grams written per size (default 1000).
    """
    if not global_distinct_ngrams:
        print("警告: 没有 distinct n-grams 可保存")
        return

    # Make sure the destination exists before any file is opened.
    os.makedirs(output_dir, exist_ok=True)

    name_suffix = f"_{exp_name}" if exp_name else ""

    for n in ngram_sizes:
        if n not in global_distinct_ngrams:
            print(f"警告: 没有 {n}-gram 数据可保存")
            continue

        ngram_collection = global_distinct_ngrams[n]
        if not ngram_collection:
            print(f"警告: {n}-gram 集合为空")
            continue

        filepath = os.path.join(
            output_dir, f"distinct_{n}gram{name_suffix}.txt")

        # Render every n-gram as a space-joined string, sort alphabetically,
        # then keep at most max_count entries.
        ordered = sorted(" ".join(ng) for ng in ngram_collection)
        total_count = len(ordered)
        kept = ordered[:max_count]
        saved_count = len(kept)

        header = (
            f"# Distinct {n}-grams (Total: {total_count}, Saved: {saved_count})\n"
            "# Generated from analysis_repetition_penalty.py\n\n"
        )

        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(header)
                f.writelines(f"{line}\n" for line in kept)
            print(
                f"已保存 {saved_count}/{total_count} 个 distinct {n}-grams 到: {filepath}")
        except Exception as e:
            # Best effort: report the failure and continue with the next size.
            print(f"保存 {n}-gram 文件失败: {e}")


def _save_experiment_comparison_figure(all_experiments_results, extract_value, ylabel, title, save_path):
    """Plot one line per experiment for a single per-step metric and save it.

    ``extract_value`` is called on every step-result dict to obtain the y
    value; the x value is the dict's ``'step'`` entry.
    """
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    markers = ['o', 's', '^', 'D', 'v', '<']

    plt.figure(figsize=(12, 7))
    for idx, (exp_name, step_results) in enumerate(all_experiments_results.items()):
        steps = [r['step'] for r in step_results]
        values = [extract_value(r) for r in step_results]
        # Colors and markers cycle when there are more experiments than styles.
        plt.plot(steps, values, marker=markers[idx % len(markers)], linewidth=2, markersize=8,
                 label=exp_name, color=colors[idx % len(colors)], alpha=0.8)

    plt.xlabel('Training Step', fontsize=13, fontweight='bold')
    plt.ylabel(ylabel, fontsize=13, fontweight='bold')
    plt.title(title, fontsize=15, fontweight='bold', pad=20)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True, alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()


def plot_multiple_experiments(all_experiments_results, output_path='repetition_analysis.png', plot_ngrams=False, ngram_sizes=[3, 5], plot_length=False, plot_accuracy=False):
    """
    Plot repetition and n-gram metrics for several experiments.

    Figures written (all into the directory of ``output_path``, which is
    created if missing, with names prefixed by its base name):
    1. One comparison figure of the repetition ratio across experiments.
    2. If ``plot_ngrams``: one comparison figure of the distinct n-gram ratio
       per n-gram size.
    3. If ``plot_ngrams``: for every experiment and n-gram size, one figure
       with the distinct ratio and count, plus optional token-length and
       accuracy curves on additional y axes.

    Args:
        all_experiments_results: dict mapping experiment name to a list of
            step-result dicts (each needs 'step' and 'repetition_rate').
        output_path: Only its directory and extension-less base name are used
            to derive the output file names.
        plot_ngrams: Whether to produce the n-gram figures.
        ngram_sizes: n-gram sizes to plot.
        plot_length: Add the 'avg_token_length' curve to per-experiment
            figures (only has an effect together with ``plot_ngrams``).
        plot_accuracy: Add the 'accuracy' curve to per-experiment figures
            (only has an effect together with ``plot_ngrams``).
    """
    if not all_experiments_results:
        print("没有数据可以绘图")
        return

    # Resolve the output directory and make sure it exists; savefig does not
    # create missing directories.
    output_dir = os.path.dirname(output_path)
    if not output_dir:
        output_dir = "."
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(output_path))[0]

    # --- 1. Overall repetition-ratio comparison ---
    ratio_output = os.path.join(
        output_dir, f"{base_name}_repetition_ratio.png")
    _save_experiment_comparison_figure(
        all_experiments_results,
        lambda r: r['repetition_rate'],
        'Average Repetition Ratio',
        'Comparison of Average Repetition Ratio',
        ratio_output,
    )
    print(f"重复率对比图已保存到: {ratio_output}")

    # --- Overall distinct-ratio comparison, one figure per n-gram size ---
    if plot_ngrams:
        for n in ngram_sizes:
            ratio_key = f'distinct_{n}gram_ratio'
            ratio_output = os.path.join(
                output_dir, f"{base_name}_{n}gram_distinct_ratio.png")
            _save_experiment_comparison_figure(
                all_experiments_results,
                lambda r, key=ratio_key: r.get(key, 0),
                f'Average Distinct {n}-gram Ratio',
                f'Comparison of Average Distinct {n}-gram Ratio',
                ratio_output,
            )
            print(f"Distinct {n}-gram 比例对比图已保存到: {ratio_output}")

    # --- 2. Per-experiment figures: n-gram + length + accuracy ---
    for exp_name, step_results in all_experiments_results.items():
        if not plot_ngrams:
            continue

        steps = [r['step'] for r in step_results]
        token_lengths = [r.get('avg_token_length', 0) for r in step_results]
        accuracies = [r.get('accuracy', 0) for r in step_results]

        for n in ngram_sizes:
            distinct_counts = [
                r.get(f'distinct_{n}gram_count', 0) for r in step_results]
            distinct_ratios = [
                r.get(f'distinct_{n}gram_ratio', 0) for r in step_results]

            # New figure with two to four y axes.
            _fig, ax1 = plt.subplots(figsize=(12, 7))

            # 1. n-gram ratio on the left axis.
            color_ratio = 'forestgreen'
            ax1.set_xlabel('Training Step', fontsize=12, fontweight='bold')
            ax1.set_ylabel('Ratio (Avg-per-response)',
                           color=color_ratio, fontsize=12, fontweight='bold')
            lns1 = ax1.plot(steps, distinct_ratios, marker='o',
                            color=color_ratio, linewidth=2, label=f'Avg {n}-gram Ratio')
            ax1.tick_params(axis='y', labelcolor=color_ratio)
            ax1.grid(True, alpha=0.2)

            # 2. n-gram count on the first right axis.
            ax_count = ax1.twinx()
            color_count = 'lime'
            ax_count.set_ylabel(
                f'Average Distinct {n}-gram Count', color=color_count, fontsize=12, fontweight='bold')
            lns_count = ax_count.plot(steps, distinct_counts, marker='h', color=color_count,
                                      linewidth=2, alpha=0.5, linestyle=':', label=f'{n}-gram Count')
            ax_count.tick_params(axis='y', labelcolor=color_count)
            lns = lns1 + lns_count

            # 3. Token length on a second right axis.
            if plot_length:
                ax2 = ax1.twinx()
                # Shift the spine outward so it does not overlap the count axis.
                ax2.spines['right'].set_position(('outward', 60))
                color_len = 'purple'
                ax2.set_ylabel(
                    'Average Token Length', color=color_len, fontsize=12, fontweight='bold')
                lns2 = ax2.plot(steps, token_lengths, marker='s', color=color_len,
                                linewidth=2, alpha=0.6, linestyle='--', label='Token Length')
                ax2.tick_params(axis='y', labelcolor=color_len)
                lns += lns2

            # 4. Accuracy on a third right axis, shifted further out.
            if plot_accuracy:
                ax3 = ax1.twinx()
                ax3.spines['right'].set_position(('outward', 120))
                color_acc = 'red'
                ax3.set_ylabel('Macro Average Accuracy',
                               color=color_acc, fontsize=12, fontweight='bold')
                lns3 = ax3.plot(steps, accuracies, marker='*', color=color_acc,
                                linewidth=2, alpha=0.8, label='Accuracy')
                ax3.tick_params(axis='y', labelcolor=color_acc)
                # Zoom the y range onto the observed values so fluctuations
                # stay visible; limits are clamped to [0, 1].
                if accuracies:
                    min_acc, max_acc = min(accuracies), max(accuracies)
                    margin = (max_acc - min_acc) * \
                        0.1 if max_acc > min_acc else 0.05
                    ax3.set_ylim(max(0, min_acc - margin),
                                 min(1.0, max_acc + margin))
                lns += lns3

            plt.title(f'{exp_name}: {n}-gram Metrics',
                      fontsize=14, fontweight='bold', pad=15)

            # One legend covering the lines of all axes.
            labs = [line.get_label() for line in lns]
            ax1.legend(lns, labs, loc='upper left')

            # Build a file-system friendly file name from the experiment name.
            sanitized_exp_name = exp_name.replace(
                " ", "_").replace("/", "_")
            ngram_output = os.path.join(
                output_dir, f"{base_name}_{sanitized_exp_name}_{n}gram_metrics.png")
            plt.tight_layout()
            plt.savefig(ngram_output, dpi=300)
            plt.close()
            print(f"实验图表已保存: {ngram_output}")


def plot_distribution_divergence(comparison_results, output_path='divergence_comparison.png', ngram_sizes=[3, 5], dir1_name='Dir1', dir2_name='Dir2', plot_entropies=True, method='js'):
    """
    Plot the n-gram distribution divergence between two directories per step.

    The first row of subplots shows the divergence for each n-gram size; when
    ``plot_entropies`` is true a second row shows the entropy of both
    directories.

    Args:
        comparison_results: List of per-step result dicts containing 'step',
            '{n}gram_divergence' and (for the entropy row) '{n}gram_entropy1'
            / '{n}gram_entropy2' for every n in ``ngram_sizes``.
        output_path: Path of the image to write; its directory is created if
            missing.
        ngram_sizes: n-gram sizes to plot, one subplot column each.
        dir1_name: Legend label for the first directory.
        dir2_name: Legend label for the second directory.
        plot_entropies: Whether to add the entropy row.
        method: Name of the divergence method, used only for the y-axis label
            (upper-cased); the default keeps the label "JS Divergence".
    """
    if not comparison_results:
        print("没有数据可以绘图")
        return

    num_ngrams = len(ngram_sizes)
    if num_ngrams == 0:
        # plt.subplots rejects a grid with zero columns; nothing to draw.
        print("没有数据可以绘图")
        return

    # Two rows (divergence, entropy) when entropies are plotted, else one.
    # squeeze=False always yields a 2-D array indexed as axes[row][column].
    num_rows = 2 if plot_entropies else 1
    _fig, axes = plt.subplots(
        num_rows, num_ngrams, figsize=(7 * num_ngrams, 6 * num_rows), squeeze=False)

    steps = [r['step'] for r in comparison_results]

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    # Divergence row.
    for idx, n in enumerate(ngram_sizes):
        divergences = [r[f'{n}gram_divergence'] for r in comparison_results]
        ax = axes[0][idx]

        ax.plot(steps, divergences, marker='o',
                linewidth=2, markersize=8, color=colors[0])
        ax.set_xlabel('Training Step', fontsize=13, fontweight='bold')
        ax.set_ylabel(
            f'{n}-gram {method.upper()} Divergence', fontsize=13, fontweight='bold')
        ax.set_title(
            f'{n}-gram Distribution Divergence', fontsize=14, fontweight='bold', pad=15)
        ax.grid(True, alpha=0.3, linestyle='--')

        # Annotate values only when the plot is sparse enough to stay readable.
        if len(steps) <= 15:
            for step, div in zip(steps, divergences):
                ax.annotate(f'{div:.4f}',
                            xy=(step, div),
                            xytext=(0, 8),
                            textcoords='offset points',
                            ha='center',
                            fontsize=8,
                            alpha=0.7)

    # Entropy row.
    if plot_entropies:
        for idx, n in enumerate(ngram_sizes):
            entropies1 = [r[f'{n}gram_entropy1'] for r in comparison_results]
            entropies2 = [r[f'{n}gram_entropy2'] for r in comparison_results]
            ax = axes[1][idx]

            ax.plot(steps, entropies1, marker='s', linewidth=2, markersize=8,
                    color=colors[2], label=f'{dir1_name}', alpha=0.8)
            ax.plot(steps, entropies2, marker='^', linewidth=2, markersize=8,
                    color=colors[3], label=f'{dir2_name}', alpha=0.8)

            ax.set_xlabel('Training Step', fontsize=13, fontweight='bold')
            ax.set_ylabel(f'{n}-gram Entropy', fontsize=13, fontweight='bold')
            ax.set_title(
                f'{n}-gram Distribution Entropy', fontsize=14, fontweight='bold', pad=15)
            ax.legend(fontsize=10, loc='best', framealpha=0.9)
            ax.grid(True, alpha=0.3, linestyle='--')

            # Two labels per step here, so use a lower threshold than above.
            if len(steps) <= 10:
                for step, ent1, ent2 in zip(steps, entropies1, entropies2):
                    ax.annotate(f'{ent1:.2f}',
                                xy=(step, ent1),
                                xytext=(0, 8),
                                textcoords='offset points',
                                ha='center',
                                fontsize=7,
                                alpha=0.6)
                    ax.annotate(f'{ent2:.2f}',
                                xy=(step, ent2),
                                xytext=(0, -12),
                                textcoords='offset points',
                                ha='center',
                                fontsize=7,
                                alpha=0.6)

    # savefig does not create missing directories.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\n图表已保存到: {output_path}")
    plt.close()


def plot_repetition_rates(step_results, output_path='repetition_analysis.png', plot_ngrams=False, ngram_sizes=None, plot_length=False, plot_accuracy=False):
    """
    Plot the per-step metrics of a single experiment as one row of line charts.

    The first panel always shows the repetition ratio. Optional panels follow
    in this order: for every n in ``ngram_sizes`` the average distinct n-gram
    count and then its ratio, the average token length, and the accuracy.

    Args:
        step_results: list of per-step dicts. Each must contain 'step' and
            'repetition_rate'; the optional metrics are read with ``.get``
            and fall back to 0 when a key is missing.
        output_path: path the figure is written to.
        plot_ngrams: add two panels (count, ratio) per n-gram size.
        ngram_sizes: n-gram sizes to plot; ``None`` means ``[3, 5]``.
        plot_length: add the 'avg_token_length' panel.
        plot_accuracy: add the 'accuracy' panel.

    Returns:
        None. When ``step_results`` is empty a message is printed and nothing
        is drawn or saved.
    """
    if not step_results:
        print("没有数据可以绘图")
        return

    # A None sentinel replaces the former mutable list default, which would
    # have been shared between calls.
    if ngram_sizes is None:
        ngram_sizes = [3, 5]

    def optional_metric(key):
        # Optional metrics default to 0 for steps that did not record them.
        return [r.get(key, 0) for r in step_results]

    # Each panel is (values, y-label, title, extra plot kwargs). The first
    # panel deliberately has no explicit colour (matplotlib's default is used).
    panels = [(
        [r['repetition_rate'] for r in step_results],
        'Average Repetition Ratio',
        'Average Repetition Ratio',
        {'marker': 'o'},
    )]
    if plot_ngrams:
        for n in ngram_sizes:
            count_label = f'Average Distinct {n}-gram Count'
            panels.append((
                optional_metric(f'distinct_{n}gram_count'),
                count_label, count_label,
                {'marker': 'o', 'color': 'green'},
            ))
            ratio_label = f'Average Distinct {n}-gram Ratio'
            panels.append((
                optional_metric(f'distinct_{n}gram_ratio'),
                ratio_label, ratio_label,
                {'marker': 's', 'color': 'darkgreen'},
            ))
    if plot_length:
        panels.append((
            optional_metric('avg_token_length'),
            'Average Token Length',
            'Average Response Length',
            {'marker': 'o', 'color': 'purple'},
        ))
    if plot_accuracy:
        panels.append((
            optional_metric('accuracy'),
            'Macro Average Accuracy',
            'Macro Average Accuracy',
            {'marker': 'o', 'color': 'red'},
        ))

    steps = [r['step'] for r in step_results]
    num_plots = len(panels)

    # squeeze=False always yields a 2-D axes array, so the single-panel case
    # needs no special handling.
    fig, axes = plt.subplots(
        1, num_plots, figsize=(7 * num_plots, 6), squeeze=False)
    try:
        for ax, (values, ylabel, title, style) in zip(axes[0], panels):
            ax.plot(steps, values, linewidth=2, markersize=8, **style)
            ax.set_xlabel('Training Step', fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            ax.set_title(title, fontsize=14)
            ax.grid(True, alpha=0.3)

        if plot_accuracy:
            # The accuracy panel is always the last one. Zoom the y axis onto
            # the observed range (clamped to [0, 1]) so that small
            # fluctuations stay visible.
            accuracies = panels[-1][0]
            min_acc, max_acc = min(accuracies), max(accuracies)
            margin = (max_acc - min_acc) * 0.1 if max_acc > min_acc else 0.05
            axes[0][-1].set_ylim(
                max(0, min_acc - margin), min(1.0, max_acc + margin))

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\n图表已保存到: {output_path}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point.
    #
    # Control flow:
    #   1. --load-data (if given) loads previously saved results.
    #   2. --compare-dirs runs the two-directory n-gram comparison; otherwise
    #      --directory (without --load-data) runs the batch analysis.
    #   3. If batch/loaded results exist and plotting is enabled, a summary is
    #      printed and the figure is drawn; otherwise --file runs the
    #      single-file analysis; otherwise the help text is shown when no mode
    #      flag was supplied at all.
    parser = argparse.ArgumentParser(description='分析JSONL文件中response的重复情况')
    parser.add_argument('--file', '-f', type=str,
                        help='单个JSONL文件路径')
    parser.add_argument('--directory', '-d', type=str, action='append',
                        help='包含多个step文件的目录路径（可以多次使用此参数来指定多个目录）')
    parser.add_argument('--compare-dirs', type=str, nargs=2, metavar=('DIR1', 'DIR2'),
                        help='对比两个目录中对应样本的n-gram分布差异')
    parser.add_argument('--pattern', '-p', type=str, default='*_16384.jsonl',
                        help='文件匹配模式 (默认: *_16384.jsonl)')
    parser.add_argument('--window_size', '-w', type=int, default=10,
                        help='滑动窗口大小 (默认: 10)')
    # The help text previously advertised a default of 6 while the real
    # default is 10; the text now matches the actual default.
    parser.add_argument('--max_repetitions_limit', '-m', type=int, default=10,
                        help='重复次数阈值，超过此值的窗口被认为是重复 (默认: 10)')
    parser.add_argument('--output', '-o', type=str, default='repetition_analysis.png',
                        help='输出图表文件路径 (默认: repetition_analysis.png)')
    parser.add_argument('--show-examples', '-s', action='store_true',
                        help='显示重复文本的具体示例')
    parser.add_argument('--compute-ngrams', '-n', action='store_true',
                        help='计算 n-gram 的 distinct 个数（由 --ngram-sizes 指定，用于衡量 exploration）')
    # The help text previously advertised "5 10 15" while the real default is
    # [10]; the text now matches the actual default.
    parser.add_argument('--ngram-sizes', type=int, nargs='+', default=[10],
                        help='用于计算 distinct 个数和分布对比的 n-gram 大小 (默认: 10)')
    parser.add_argument('--divergence-method', type=str, default='js', choices=['js', 'kl'],
                        help='散度计算方法: js (Jensen-Shannon) 或 kl (KL divergence) (默认: js)')
    parser.add_argument('--plot-entropies', action='store_true',
                        help='在对比目录时绘制分布熵的变化图')
    parser.add_argument('--tokenizer', type=str,
                        help='Tokenizer 路径，用于计算 token 长度曲线')
    parser.add_argument('--save-data', type=str,
                        help='将分析结果保存到 JSON 文件')
    parser.add_argument('--load-data', type=str,
                        help='从保存的 JSON 文件加载数据直接画图')
    parser.add_argument('--plot-accuracy', action='store_true',
                        help='在图表中绘制 Accuracy 曲线')
    parser.add_argument('--no-plot', action='store_true',
                        help='禁用绘图功能')
    parser.add_argument('--save-ngrams', type=str,
                        help='保存 distinct n-grams 到指定目录（每个 n-gram 大小保存为一个 txt 文件）')
    parser.add_argument('--save-ngrams-count', type=int, default=1000,
                        help='每个 n-gram 大小最多保存的数量（默认: 1000）')
    parser.add_argument('--compute-sentence-diversity', action='store_true',
                        help='计算单条 response 内部的文本多样性和公式多样性（句内 TD/ED）')
    parser.add_argument('--sentence-diversity-max-len', type=int, default=None,
                        help='计算句内 TD/ED 时截断的最大字符数；默认不截断，按完整 response 计算')
    parser.add_argument('--num-workers', type=int, default=1,
                        help='目录模式下并行分析 step 文件的进程数（默认: 1）')
    parser.add_argument('--save-response-metrics', type=str,
                        help='保存逐条 response 的明细指标到 jsonl；目录模式下该参数应为目录路径')

    args = parser.parse_args()

    # Tokenizer setup. In batch (directory) mode the tokenizer is not loaded
    # here; token lengths are computed in one pass after all directories have
    # been analysed (see the deferred jobs below).
    tokenizer = None
    defer_batch_tokenizer = bool(
        args.tokenizer and args.directory and not args.load_data)
    if args.tokenizer:
        if AutoTokenizer is None:
            print("错误: 无法导入 transformers，请安装 transformers 以使用 tokenizer 功能")
        elif defer_batch_tokenizer:
            print(f"[info] 将在所有目录分析结束后统一加载 tokenizer: {args.tokenizer}")
        else:
            print(f"正在加载 tokenizer: {args.tokenizer}")
            tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    # Maps experiment name -> list of per-step result dicts.
    all_experiments_results = {}

    if args.load_data:
        print(f"正在从文件加载数据: {args.load_data}")
        try:
            with open(args.load_data, 'r', encoding='utf-8') as f:
                all_experiments_results = json.load(f)
        except Exception as e:
            # Top-level boundary: report any load failure and exit non-zero.
            print(f"加载数据失败: {e}")
            sys.exit(1)

    if args.compare_dirs:
        # Comparison mode: n-gram distribution differences of two directories.
        dir1, dir2 = args.compare_dirs

        print("="*60)
        print("目录对比模式: n-gram 分布差异 analysis")
        print("="*60)
        print(f"目录1: {dir1}")
        print(f"目录2: {dir2}")
        print(f"文件模式: {args.pattern}")
        print(f"n-gram 大小: {args.ngram_sizes}")
        print(f"散度方法: {args.divergence_method.upper()}")
        print("="*60)
        print()

        comparison_results = compare_two_directories(
            dir1, dir2,
            pattern=args.pattern,
            ngram_sizes=args.ngram_sizes,
            method=args.divergence_method,
            verbose=True
        )

        if comparison_results:
            # Summary table: divergence per step and n-gram size.
            print("\n" + "="*80)
            print("对比结果汇总 - 分布散度")
            print("="*80)

            header = f"{'Step':<10} "
            for n in args.ngram_sizes:
                header += f"{n}-gram 散度{' '*5}"
            print(header)
            print("-" * 80)

            for r in comparison_results:
                row = f"{r['step']:<10} "
                for n in args.ngram_sizes:
                    row += f"{r[f'{n}gram_divergence']:<15.6f} "
                print(row)

            print("-" * 80)
            avg_row = f"{'平均':<10} "
            for n in args.ngram_sizes:
                avg_div = np.mean([r[f'{n}gram_divergence']
                                  for r in comparison_results])
                avg_row += f"{avg_div:<15.6f} "
            print(avg_row)
            print("="*80)

            # Summary table: distribution entropy of both directories.
            print("\n" + "="*100)
            print("对比结果汇总 - 分布熵")
            print("="*100)

            header = f"{'Step':<10} "
            for n in args.ngram_sizes:
                header += f"{n}-gram 熵1{' '*5}{n}-gram 熵2{' '*5}"
            print(header)
            print("-" * 100)

            for r in comparison_results:
                row = f"{r['step']:<10} "
                for n in args.ngram_sizes:
                    row += (
                        f"{r[f'{n}gram_entropy1']:<15.6f} "
                        f"{r[f'{n}gram_entropy2']:<15.6f} "
                    )
                print(row)

            print("-" * 100)
            avg_row = f"{'平均':<10} "
            for n in args.ngram_sizes:
                avg_ent1 = np.mean([r[f'{n}gram_entropy1']
                                   for r in comparison_results])
                avg_ent2 = np.mean([r[f'{n}gram_entropy2']
                                   for r in comparison_results])
                avg_row += f"{avg_ent1:<15.6f} {avg_ent2:<15.6f} "
            print(avg_row)
            print("="*100)

            dir1_name = get_experiment_name(dir1)
            dir2_name = get_experiment_name(dir2)
            if not args.no_plot:
                plot_distribution_divergence(
                    comparison_results,
                    output_path=args.output,
                    ngram_sizes=args.ngram_sizes,
                    dir1_name=dir1_name,
                    dir2_name=dir2_name,
                    # Honour --plot-entropies; it used to be hard-coded to
                    # True, which made the flag a no-op.
                    plot_entropies=args.plot_entropies
                )

    elif args.directory and not args.load_data:
        # Batch mode: analyse every given directory (one experiment each).
        deferred_token_length_jobs = []
        print("="*60)
        print("批量分析模式")
        print("="*60)
        print(f"模式: {args.pattern}")
        print(f"窗口大小: {args.window_size}")
        print(f"重复阈值: {args.max_repetitions_limit}")
        print(f"计算n-gram: {args.compute_ngrams}")
        print(f"计算句内TD/ED: {args.compute_sentence_diversity}")
        print(f"并行进程数: {args.num_workers}")
        print(f"实验数量: {len(args.directory)}")
        print("="*60)
        print()

        for directory in args.directory:
            exp_name = get_experiment_name(directory)
            print(f"\n{'='*60}")
            print(f"分析实验: {exp_name}")
            print(f"目录: {directory}")
            print(f"{'='*60}")

            # Per-experiment sub-directory for the per-response metrics.
            response_metrics_dir = None
            if args.save_response_metrics:
                response_metrics_dir = os.path.join(
                    args.save_response_metrics,
                    exp_name.replace(" ", "_").replace("/", "_")
                )

            result = analyze_multiple_steps(
                directory,
                args.pattern,
                args.window_size,
                args.max_repetitions_limit,
                verbose=True,
                show_examples=args.show_examples,
                compute_ngrams=args.compute_ngrams,
                ngram_sizes=args.ngram_sizes,
                tokenizer=None if defer_batch_tokenizer else tokenizer,
                tokenizer_path=None if defer_batch_tokenizer else args.tokenizer,
                compute_sentence_diversity=args.compute_sentence_diversity,
                sentence_diversity_max_len=args.sentence_diversity_max_len,
                response_metrics_dir=response_metrics_dir,
                num_workers=args.num_workers,
            )

            # The return value is either a dict (with 'step_results') or a
            # plain list of step results.
            if isinstance(result, dict):
                step_results = result['step_results']
                global_distinct_ngrams = result.get(
                    'global_distinct_ngrams', {})
            else:
                step_results = result
                global_distinct_ngrams = {}

            if step_results:
                all_experiments_results[exp_name] = step_results
                if defer_batch_tokenizer and args.tokenizer and AutoTokenizer is not None:
                    # Remember what is needed to fill in token lengths later.
                    deferred_token_length_jobs.append({
                        'directory': directory,
                        'pattern': args.pattern,
                        'step_results': step_results,
                        'response_metrics_dir': response_metrics_dir,
                    })

                # Save the distinct n-grams to files.
                if args.save_ngrams and global_distinct_ngrams:
                    save_dir = args.save_ngrams
                    save_distinct_ngrams(global_distinct_ngrams, save_dir, exp_name=exp_name,
                                         ngram_sizes=args.ngram_sizes, max_count=args.save_ngrams_count)

                # Print the statistics table of this experiment.
                print(f"\n{exp_name} - 统计结果:")
                header = f"{'Step':<10} {'文件名':<25} {'重复比例':<15} "
                if args.compute_sentence_diversity:
                    header += f"{'句内TD':<12} {'句内ED':<12} {'切分单元':<12} "
                if args.compute_ngrams:
                    for n in args.ngram_sizes:
                        header += f"{n}-gram Count{' '*(15-len(str(n))-11)} "
                        header += f"{n}-gram Ratio{' '*(15-len(str(n))-11)} "
                        header += f"{n}-gram StepDistinct{' '*(18-len(str(n))-16)} "
                        header += f"{n}-gram StepD/GlobalTotal{' '*(24-len(str(n))-20)} "
                        header += f"{n}-gram StepTotal{' '*(18-len(str(n))-13)} "
                        header += f"{n}-gram TotalRatio{' '*(18-len(str(n))-14)} "
                # Show the token-length column only when lengths were (or
                # could have been) computed during the analysis itself.
                show_token_length = (
                    (not defer_batch_tokenizer and tokenizer is not None)
                    or any(r.get('avg_token_length', 0) > 0 for r in step_results)
                )
                if show_token_length:
                    header += f"{'Token长度':<12} "
                header += f"{'Accuracy':<10} "
                header += f"{'总样本数':<10}"
                print(header)
                print("-" * len(header))
                for r in step_results:
                    row = f"{r['step']:<10} {r['file']:<25} {r['repetition_rate']:<15.4f} "
                    if args.compute_sentence_diversity:
                        row += f"{r.get('internal_textual_diversity', 0):<12.4f} "
                        row += f"{r.get('internal_equational_diversity', 0):<12.4f} "
                        row += f"{r.get('internal_unit_count', 0):<12.2f} "
                    if args.compute_ngrams:
                        for n in args.ngram_sizes:
                            row += f"{r.get(f'distinct_{n}gram_count', 0):<15.2f} "
                            row += f"{r.get(f'distinct_{n}gram_ratio', 0):<15.4f} "
                            row += f"{r.get(f'step_distinct_{n}gram_count', 0):<18.2f} "
                            row += f"{r.get(f'step_distinct_{n}gram_ratio', 0):<24.6f} "
                            row += f"{r.get(f'step_total_{n}gram_count', 0):<18.2f} "
                            row += f"{r.get(f'global_total_{n}gram_ratio', 0):<18.4f} "
                    if show_token_length:
                        row += f"{r.get('avg_token_length', 0):<12.2f} "
                    row += f"{r.get('accuracy', 0):<10.4f} "
                    row += f"{r['total_responses']:<10}"
                    print(row)

        # Deferred pass: load the tokenizer once and fill in token lengths for
        # every experiment analysed above.
        if defer_batch_tokenizer and deferred_token_length_jobs:
            print("\n" + "="*60)
            print("统一计算 Token Length")
            print("="*60)
            if tokenizer is None and AutoTokenizer is not None:
                print(f"正在加载 tokenizer: {args.tokenizer}")
                tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
            if tokenizer is not None:
                for job in deferred_token_length_jobs:
                    print(f"[tokenizer] {job['directory']}")
                    step_files = _collect_step_files(
                        job['directory'],
                        pattern=job['pattern'],
                        verbose=False,
                    )
                    _apply_token_lengths_after_parallel(
                        step_files=step_files,
                        step_results=job['step_results'],
                        tokenizer=tokenizer,
                        response_metrics_dir=job['response_metrics_dir'],
                    )

        if args.save_data and all_experiments_results:
            print(f"\n正在将分析结果保存到: {args.save_data}")
            with open(args.save_data, 'w', encoding='utf-8') as f:
                json.dump(all_experiments_results, f,
                          indent=2, ensure_ascii=False)

    # Summary and plotting, for freshly analysed as well as loaded results.
    if all_experiments_results and (args.directory or args.load_data) and not args.no_plot:
        print("\n" + "="*60)
        print("汇总对比与绘图")
        print("="*60)
        for exp_name, results in all_experiments_results.items():
            avg_rate = np.mean([r['repetition_rate'] for r in results])
            avg_acc = np.mean([r.get('accuracy', 0) for r in results])
            print(f"{exp_name}: 平均重复比例 = {avg_rate:.4f} ({avg_rate*100:.2f}%)")
            print(f"{exp_name}: 平均Accuracy = {avg_acc:.4f}")
            # Explicit None check, consistent with the other tokenizer tests
            # in this script (avoids relying on the tokenizer's truthiness).
            if tokenizer is not None or any('avg_token_length' in r for r in results):
                avg_len = np.mean([r.get('avg_token_length', 0)
                                  for r in results])
                print(f"{exp_name}: 平均Token长度 = {avg_len:.2f}")
            if args.compute_sentence_diversity:
                avg_internal_td = np.mean(
                    [r.get('internal_textual_diversity', 0) for r in results])
                avg_internal_ed = np.mean(
                    [r.get('internal_equational_diversity', 0) for r in results])
                print(f"{exp_name}: 平均句内文本多样性 = {avg_internal_td:.4f}")
                print(f"{exp_name}: 平均句内公式多样性 = {avg_internal_ed:.4f}")
            if args.compute_ngrams:
                for n in args.ngram_sizes:
                    avg_ngram = np.mean(
                        [r.get(f'distinct_{n}gram_count', 0) for r in results])
                    avg_ratio = np.mean(
                        [r.get(f'distinct_{n}gram_ratio', 0) for r in results])
                    print(
                        f"{exp_name}: 平均{n}-gram distinct个数 = {avg_ngram:.2f}, 比例 = {avg_ratio:.4f}")
                    avg_stepd_globaltotal = np.mean(
                        [r.get(f'step_distinct_{n}gram_ratio', 0) for r in results])
                    global_total = results[0].get(
                        f'global_total_{n}gram_count', 0) if results else 0
                    print(
                        f"{exp_name}: 平均(step distinct / global total) = {avg_stepd_globaltotal:.6f} (global total n-grams = {global_total})")

        # Draw the token-length panel when a tokenizer is loaded or any result
        # already carries a token length (e.g. data loaded from JSON).
        has_token_lengths = (
            tokenizer is not None
            or any('avg_token_length' in r
                   for res in all_experiments_results.values() for r in res)
        )
        if len(all_experiments_results) > 1:
            plot_multiple_experiments(
                all_experiments_results, args.output,
                plot_ngrams=args.compute_ngrams,
                ngram_sizes=args.ngram_sizes,
                plot_length=has_token_lengths,
                plot_accuracy=args.plot_accuracy)
        else:
            # Only one experiment: use the single-experiment plot.
            plot_repetition_rates(
                next(iter(all_experiments_results.values())), args.output,
                plot_ngrams=args.compute_ngrams,
                ngram_sizes=args.ngram_sizes,
                plot_length=has_token_lengths,
                plot_accuracy=args.plot_accuracy)

    elif args.file:
        # Single-file analysis mode.
        print("="*60)
        print("单文件分析模式")
        print("="*60)
        result = analyze_repetition(
            args.file,
            args.window_size,
            args.max_repetitions_limit,
            verbose=True,
            show_examples=args.show_examples,
            compute_ngrams=args.compute_ngrams,
            ngram_sizes=args.ngram_sizes,
            tokenizer=tokenizer,
            compute_sentence_diversity=args.compute_sentence_diversity,
            sentence_diversity_max_len=args.sentence_diversity_max_len,
            return_response_details=bool(args.save_response_metrics),
        )

        if result:
            print("\n" + "="*60)
            print("分析结果")
            print("="*60)
            print(f"总response数量: {result['total_responses']}")
            print(
                f"平均重复比例: {result['repetition_rate']:.4f} ({result['repetition_rate']*100:.2f}%)")
            print(f"Accuracy (Macro Avg): {result['accuracy']:.4f}")
            if tokenizer is not None:
                print(f"平均Token长度: {result.get('avg_token_length', 0):.2f}")
            if args.compute_sentence_diversity:
                print(
                    f"平均句内文本多样性 (intra-TD): {result.get('internal_textual_diversity', 0):.4f}")
                print(
                    f"平均句内文本相似度 (intra-BLEU): {result.get('internal_textual_similarity', 0):.4f}")
                print(f"平均句内切分单元数: {result.get('internal_unit_count', 0):.2f}")
                print(
                    f"平均句内公式多样性 (intra-ED): {result.get('internal_equational_diversity', 0):.4f}")
                print(
                    f"平均句内唯一公式数/总公式数: {result.get('formula_unique_count', 0):.2f}/{result.get('formula_total_count', 0):.2f}")
            if args.compute_ngrams:
                for n in args.ngram_sizes:
                    print(
                        f"平均{n}-gram distinct个数: {result.get(f'distinct_{n}gram_count', 0):.2f}")
                    print(
                        f"平均{n}-gram distinct比例: {result.get(f'distinct_{n}gram_ratio', 0):.4f}")
            if args.show_examples and result.get('repetition_examples'):
                print(f"检测到重复的样例数: {len(result['repetition_examples'])}")
            if args.save_response_metrics and result.get('response_details'):
                response_metrics_path = args.save_response_metrics
                # Anything that is an existing directory or does not end in
                # '.jsonl' is treated as a directory; the output file name is
                # then derived from the input file name.
                if os.path.isdir(response_metrics_path) or not response_metrics_path.endswith('.jsonl'):
                    os.makedirs(response_metrics_path, exist_ok=True)
                    file_stem = os.path.splitext(
                        os.path.basename(args.file))[0]
                    response_metrics_path = os.path.join(
                        response_metrics_path, f"{file_stem}_response_metrics.jsonl")
                save_response_metrics(
                    result['response_details'], response_metrics_path)
            print("="*60)

    else:
        # Show usage only when no mode flag was given at all; this branch is
        # also reached when a mode ran but produced nothing to plot.
        if not args.load_data and not args.directory and not args.compare_dirs:
            parser.print_help()
            print("\n示例用法:")
            print("  单文件分析:")
            print("    python analysis_repetition_penalty.py -f /path/to/file.jsonl")
            print("\n  批量分析并保存数据:")
            print("    python analysis_repetition_penalty.py -d ./exp1 -d ./exp2 --save-data results.json --plot-accuracy")
            print("\n  从保存的数据直接绘图:")
            print("    python analysis_repetition_penalty.py --load-data results.json --plot-accuracy --compute-ngrams -o comparison.png")
