import json
import sys
import argparse
import os
import glob
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy
from collections import Counter
import re
import unicodedata
try:
    from transformers import AutoTokenizer
except ImportError:
    AutoTokenizer = None



def detect_repetition_with_hash(text, window_size=10, return_text=False, max_repetitions_limit=6):
    """
    Measure how much of ``text`` consists of heavily repeated word windows.

    The text is split on whitespace and then on underscores. Every run of
    ``window_size`` consecutive words is one window; a window is "repeated"
    when its exact word sequence occurs at least ``max_repetitions_limit``
    times in the text.

    Args:
        text: Input text.
        window_size: Number of words per sliding window.
        return_text: When True, also return the most frequent window as text.
        max_repetitions_limit: Minimum number of occurrences for a window to
            be counted as repeated.

    Returns:
        If return_text is False: (repetition_ratio, repetition_count)
        If return_text is True: (repetition_ratio, repetition_count, repeated_text)

        repetition_ratio: repeated window positions / total window positions.
        repetition_count: number of window positions whose window is repeated.
        repeated_text: space-joined window that first reached the highest
            occurrence count ("" when the text has at most ``window_size`` words).
    """
    # Split text by both space and underscore
    words = []
    for segment in text.split():
        words.extend(segment.split('_'))

    if len(words) <= window_size:
        if return_text:
            return 0.0, 0, ""
        return 0.0, 0

    # Always >= 2 here because len(words) > window_size.
    total_windows = len(words) - window_size + 1

    # Key on the window tuple itself rather than on hash(window): the dict
    # already hashes its keys, and two different windows can never be merged
    # by an accidental hash collision.
    window_counts = {}
    max_repetitions = 0
    max_repeated_window = None
    for i in range(total_windows):
        window = tuple(words[i:i + window_size])
        count = window_counts.get(window, 0) + 1
        window_counts[window] = count

        # Strict ">" keeps the window that reached the top count first.
        if count > max_repetitions:
            max_repetitions = count
            max_repeated_window = window

    # Every occurrence of a sufficiently frequent window is one repeated
    # position, so summing the qualifying counts equals a second pass over
    # all positions.
    repetition_count = sum(
        count for count in window_counts.values() if count >= max_repetitions_limit
    )
    repetition_ratio = repetition_count / total_windows

    if return_text:
        repeated_text = " ".join(max_repeated_window) if max_repeated_window else ""
        return repetition_ratio, repetition_count, repeated_text

    return repetition_ratio, repetition_count

def compute_distinct_ngrams(text, n=3):
    """
    Count the distinct n-grams of a text as a measure of exploration.

    Words are obtained by splitting on whitespace and then on underscores.
    Distinctness is decided on the hash of each n-gram tuple.

    Args:
        text: Input text.
        n: n-gram size (e.g. 3 for 3-grams).

    Returns:
        (distinct_count, total_count, distinct_ratio)
        distinct_count: number of distinct n-grams.
        total_count: total number of n-grams.
        distinct_ratio: distinct_count / total_count.
        When the text has fewer than ``n`` words the result is (1, 1, 1.0).
    """
    tokens = [piece for chunk in text.split() for piece in chunk.split('_')]

    # Too short to form a single n-gram: report it as fully distinct.
    if len(tokens) < n:
        return 1, 1, 1.0

    total_ngrams = len(tokens) - n + 1
    seen_hashes = {
        hash(tuple(tokens[start:start + n])) for start in range(total_ngrams)
    }
    distinct_count = len(seen_hashes)

    if total_ngrams > 0:
        distinct_ratio = distinct_count / total_ngrams
    else:
        distinct_ratio = 0.0

    return distinct_count, total_ngrams, distinct_ratio

def compute_corpus_ngram_hashes(texts, n=3):
    """
    Collect corpus-level n-gram statistics for a group of texts.

    The group can be one step or all steps together. Distinct n-grams are
    represented by the hash of each n-gram tuple.

    Args:
        texts: List of texts.
        n: n-gram size.

    Returns:
        (distinct_hashes, total_ngrams)
        distinct_hashes: set of hashes of the n-grams seen in the corpus.
        total_ngrams: int, number of n-grams in the corpus, duplicates included.
    """
    seen_hashes = set()
    ngram_total = 0

    # A falsy ``texts`` (None or empty) yields the empty result.
    for text in texts or ():
        grams = extract_ngrams_from_text(text, n)
        ngram_total += len(grams)
        seen_hashes.update(hash(gram) for gram in grams)

    return seen_hashes, ngram_total

def extract_ngrams_from_text(text, n=3):
    """
    Extract every n-gram from a text, in order of appearance.

    Words are obtained by splitting on whitespace and then on underscores.

    Args:
        text: Input text.
        n: n-gram size.

    Returns:
        list of tuples: all n-grams; empty when the text has fewer than
        ``n`` words.
    """
    tokens = []
    for chunk in text.split():
        tokens += chunk.split('_')

    # With fewer than n tokens the range is empty, which gives [].
    return [tuple(tokens[start:start + n]) for start in range(len(tokens) - n + 1)]

def compute_ngram_distribution(texts, n=3):
    """
    Build the n-gram frequency distribution of a collection of texts.

    Args:
        texts: List of texts.
        n: n-gram size.

    Returns:
        Counter: maps each n-gram tuple to its number of occurrences across
        all texts.
    """
    frequencies = Counter()
    for text in texts:
        frequencies.update(extract_ngrams_from_text(text, n))
    return frequencies

def compute_distribution_entropy(dist):
    """
    Compute the entropy of a single frequency distribution.

    Args:
        dist: Counter holding the frequency distribution.

    Returns:
        float: entropy in nats (natural logarithm); 0.0 for an empty
        distribution or one whose counts sum to zero.
    """
    counts = list(dist.values())
    total = sum(counts)

    # Covers both the empty distribution and an all-zero one.
    if not counts or total == 0:
        return 0.0

    # Higher entropy means a more uniform, less predictable distribution.
    probabilities = np.array(counts, dtype=float) / total
    return entropy(probabilities)

def compute_distribution_divergence(dist1, dist2, method='js', return_entropies=False):
    """
    Compare two frequency distributions over their combined support.

    Args:
        dist1: Counter, first distribution.
        dist2: Counter, second distribution.
        method: 'js' or 'kl'.
            'js' returns the value of ``scipy.spatial.distance.jensenshannon``,
            which per SciPy's documentation is the Jensen-Shannon *distance*
            (the square root of the JS divergence, natural-log base).
            'kl' returns KL(dist1 || dist2) via ``scipy.stats.entropy``.
        return_entropies: Whether to also return each distribution's entropy.

    Returns:
        If return_entropies is False: float divergence value.
        If return_entropies is True: tuple (divergence, entropy1, entropy2).
        All values are 0.0 when either distribution is empty or sums to zero.

    Raises:
        ValueError: if ``method`` is neither 'js' nor 'kl' (only reached when
            both distributions are non-empty).
    """
    def _result(value, first_entropy, second_entropy):
        if return_entropies:
            return value, first_entropy, second_entropy
        return value

    support = set(dist1.keys()) | set(dist2.keys())
    total1 = sum(dist1.values())
    total2 = sum(dist2.values())

    # An empty support implies both totals are zero, so one guard covers
    # every degenerate case.
    if not support or total1 == 0 or total2 == 0:
        return _result(0.0, 0.0, 0.0)

    # Probability vectors aligned on the same iteration order of ``support``.
    p = np.array([dist1.get(gram, 0) for gram in support], dtype=float) / total1
    q = np.array([dist2.get(gram, 0) for gram in support], dtype=float) / total2

    # Entropies use the raw (unsmoothed) probabilities, non-zero entries only.
    entropy1 = entropy(p[p > 0])
    entropy2 = entropy(q[q > 0])

    # Additive smoothing so that neither vector contains an exact zero,
    # then renormalise.
    epsilon = 1e-10
    p_smooth = p + epsilon
    q_smooth = q + epsilon
    p_smooth = p_smooth / p_smooth.sum()
    q_smooth = q_smooth / q_smooth.sum()

    if method == 'js':
        divergence = jensenshannon(p_smooth, q_smooth)
    elif method == 'kl':
        divergence = entropy(p_smooth, q_smooth)
    else:
        raise ValueError(f"Unknown method: {method}")

    return _result(divergence, entropy1, entropy2)

def compare_sample_distributions(responses1, responses2, n=3, method='js', return_entropies=False):
    """
    Compare the n-gram distributions of two collections of samples.

    Args:
        responses1: First collection of samples (list of texts).
        responses2: Second collection of samples (list of texts).
        n: n-gram size.
        method: Divergence method, forwarded to compute_distribution_divergence.
        return_entropies: Whether to also return each distribution's entropy.

    Returns:
        If return_entropies is False: float divergence value.
        If return_entropies is True: tuple (divergence, entropy1, entropy2).
    """
    first_dist = compute_ngram_distribution(responses1, n)
    second_dist = compute_ngram_distribution(responses2, n)

    outcome = compute_distribution_divergence(
        first_dist, second_dist, method, return_entropies=bool(return_entropies)
    )
    if return_entropies:
        divergence, entropy1, entropy2 = outcome
        return divergence, entropy1, entropy2
    return outcome

def analyze_repetition(file_path, window_size=10, max_repetitions_limit=6, verbose=True, show_examples=False, compute_ngrams=False, ngram_sizes=[3, 5], tokenizer=None):
    """
    Analyze repetition in the responses of one JSONL file and average its scores.

    Each non-blank line is parsed as JSON. Lines that are not valid JSON or
    not JSON objects are skipped. Responses are taken from a ``responses``
    list (every non-empty string is one sample) or otherwise from
    ``output`` / ``response``. Scores are taken from ``accuracies`` or
    otherwise ``final_scores``; a record with neither key contributes its
    responses but no scores.

    Args:
        file_path: Path of the JSONL file.
        window_size: Window size passed to detect_repetition_with_hash.
        max_repetitions_limit: Repetition threshold passed to
            detect_repetition_with_hash.
        verbose: Whether to print progress information.
        show_examples: Whether to collect and print the most repetitive
            responses (repetition ratio > 5%).
        compute_ngrams: Whether to compute per-response distinct n-gram stats.
        ngram_sizes: n-gram sizes used when ``compute_ngrams`` is True.
        tokenizer: Optional callable invoked as
            ``tokenizer(responses, add_special_tokens=False, padding=False,
            truncation=False)`` whose result exposes ``['input_ids']``
            (presumably a Hugging Face tokenizer - confirm with the caller).

    Returns:
        None if the file does not exist; otherwise a dict with
        ``total_responses``, ``repetition_rate`` (mean repetition ratio),
        ``repetition_examples``, ``avg_token_length``, ``accuracy`` (mean of
        all collected scores) and, when ``compute_ngrams`` is True,
        ``distinct_{n}gram_count`` / ``distinct_{n}gram_ratio`` per size.
    """
    responses = []
    scores = []  # flat list of every score found in the file

    if verbose:
        print(f"正在读取文件: {file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines are not records; skip them quietly.
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    if verbose:
                        print(f"解析JSON失败: {e}")
                    continue
                if not isinstance(data, dict):
                    # Valid JSON that is not an object has no fields to read.
                    continue

                # Train data stores a list of responses; each one is treated
                # as an independent sample.
                resp_list = data.get('responses')
                if isinstance(resp_list, list):
                    responses.extend(
                        resp for resp in resp_list if isinstance(resp, str) and resp
                    )
                else:
                    resp = data.get('output') or data.get('response')
                    if resp is None:
                        continue
                    responses.append(resp)

                # Records without any score field must not abort the whole
                # analysis; they simply contribute no scores.
                record_scores = data.get('accuracies', data.get('final_scores'))
                if isinstance(record_scores, (list, tuple)):
                    scores.extend(record_scores)
                elif isinstance(record_scores, (int, float)):
                    scores.append(record_scores)
    except FileNotFoundError:
        print(f"错误: 文件未找到: {file_path}")
        return None

    if verbose:
        print(f"  总共读取了 {len(responses)} 个response")

    repetition_ratio_list = []  # repetition ratio of each response
    repetition_examples = []  # highly repetitive responses, for display
    token_lengths = []  # token length of each response

    # Tokenize all responses in one batch; only the lengths are needed, so
    # no padding or truncation is requested.
    if tokenizer and responses:
        if verbose:
            print(f"  正在批量计算 {len(responses)} 个 response 的 token 长度...")
        batch_encodings = tokenizer(responses, add_special_tokens=False, padding=False, truncation=False)
        token_lengths = [len(ids) for ids in batch_encodings['input_ids']]

    ngram_distinct_counts = {n: [] for n in ngram_sizes}
    ngram_distinct_ratios = {n: [] for n in ngram_sizes}

    for i, response in enumerate(responses):
        if show_examples:
            repetition_ratio, repetition_count, repeated_text = detect_repetition_with_hash(
                response,
                window_size=window_size,
                return_text=True,
                max_repetitions_limit=max_repetitions_limit
            )
            if repetition_ratio > 0.05:  # keep only responses above 5% repetition
                repetition_examples.append({
                    'index': i,
                    'ratio': repetition_ratio,
                    'count': repetition_count,
                    'repeated_text': repeated_text,
                    'response_preview': response[:300] + '...' if len(response) > 300 else response
                })
        else:
            repetition_ratio, repetition_count = detect_repetition_with_hash(
                response,
                window_size=window_size,
                max_repetitions_limit=max_repetitions_limit
            )
        repetition_ratio_list.append(repetition_ratio)

        if compute_ngrams:
            for n in ngram_sizes:
                distinct_count, _, distinct_ratio = compute_distinct_ngrams(response, n=n)
                ngram_distinct_counts[n].append(distinct_count)
                ngram_distinct_ratios[n].append(distinct_ratio)

    # Aggregate statistics; empty inputs fall back to 0.0 instead of NaN.
    total_responses = len(repetition_ratio_list)
    avg_repetition_ratio = np.mean(repetition_ratio_list) if total_responses > 0 else 0.0
    avg_token_length = np.mean(token_lengths) if token_lengths else 0.0
    macro_acc = np.mean(scores) if scores else 0.0

    avg_ngram_distinct_counts = {}
    avg_ngram_distinct_ratios = {}
    if compute_ngrams:
        for n in ngram_sizes:
            avg_count = np.mean(ngram_distinct_counts[n]) if ngram_distinct_counts[n] else 0.0
            avg_ratio = np.mean(ngram_distinct_ratios[n]) if ngram_distinct_ratios[n] else 0.0
            avg_ngram_distinct_counts[n] = avg_count
            avg_ngram_distinct_ratios[n] = avg_ratio
            if verbose:
                print(f"  平均{n}-gram distinct个数: {avg_count:.2f}, 比例: {avg_ratio:.4f}")

    # Show the five most repetitive examples, highest ratio first.
    if show_examples and repetition_examples:
        print(f"\n  重复文本示例 (前5个，按重复比例降序):")
        print("  " + "-"*60)
        repetition_examples.sort(key=lambda x: x['ratio'], reverse=True)
        for example in repetition_examples[:5]:
            print(f"  样例 #{example['index']}: 重复比例={example['ratio']:.2%}, 重复窗口数={example['count']}")
            print(f"  重复片段: \"{example['repeated_text']}\"")
            print("  " + "-"*60)

    result = {
        'total_responses': total_responses,
        'repetition_rate': avg_repetition_ratio,  # key kept for compatibility; value is the mean ratio
        'repetition_examples': repetition_examples if show_examples else [],
        'avg_token_length': avg_token_length,
        'accuracy': macro_acc
    }

    if compute_ngrams:
        for n, avg_count in avg_ngram_distinct_counts.items():
            result[f'distinct_{n}gram_count'] = avg_count
            result[f'distinct_{n}gram_ratio'] = avg_ngram_distinct_ratios[n]

    return result

def parse_step_from_filename(filename):
    """
    Parse the step number from a file name.

    Supported shapes include:
    - "10_16384.jsonl" -> 10
    - "step_10_traindata.jsonl" -> 10

    Returns:
        int step number, or None when the name contains no digits.
    """
    # NFKC normalisation folds full-width characters to their ASCII forms so
    # they cannot break the matching below.
    normalized = unicodedata.normalize("NFKC", filename).strip().lower()

    attempts = (
        (re.search, r"step\D*(\d+)"),  # digits following "step", any separator
        (re.match, r"^(\d+)"),         # legacy names that start with the step
        (re.search, r"(\d+)"),         # fallback: first digit run anywhere
    )
    for matcher, pattern in attempts:
        found = matcher(pattern, normalized)
        if found is not None:
            return int(found.group(1))

    return None

def read_responses_from_file(file_path):
    """
    Read all responses from a JSONL file.

    Each non-blank line is parsed as JSON. Lines that are not valid JSON or
    not JSON objects are skipped. For an object with a ``responses`` list,
    every non-empty string in it is collected; otherwise the value of
    ``output`` (or, failing that, ``response``) is collected as-is.

    Args:
        file_path: Path of the file.

    Returns:
        list: collected responses; empty when the file does not exist.
    """
    responses = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    # e.g. a bare list or number: valid JSON, but it has no
                    # fields, and calling .get() on it would raise.
                    continue

                resp_list = data.get('responses')
                if isinstance(resp_list, list):
                    # Train data: ``responses`` holds several samples.
                    responses.extend(
                        resp for resp in resp_list if isinstance(resp, str) and resp
                    )
                elif 'output' in data:
                    responses.append(data['output'])
                elif 'response' in data:
                    responses.append(data['response'])
    except FileNotFoundError:
        return []
    return responses

def compare_two_directories(dir1, dir2, pattern="*_16384.jsonl", ngram_sizes=[3, 5], method='js', verbose=True):
    """
    Compare the n-gram distributions of matching step files in two directories.

    Files are paired by the step number parsed from their names; only steps
    present in both directories are compared.

    Args:
        dir1: Path of the first directory.
        dir2: Path of the second directory.
        pattern: Glob pattern selecting the files.
        ngram_sizes: List of n-gram sizes.
        method: Divergence method.
        verbose: Whether to print details.

    Returns:
        list: one result dict per compared step, or None when either
        directory has no matching file or the directories share no step.
    """
    matches_a = glob.glob(os.path.join(dir1, pattern))
    matches_b = glob.glob(os.path.join(dir2, pattern))

    for directory, matches in ((dir1, matches_a), (dir2, matches_b)):
        if not matches:
            print(f"错误: 在 {directory} 中没有找到匹配 {pattern} 的文件")
            return None

    def _index_by_step(paths):
        # Map step number -> path; files without a parsable step are dropped.
        mapping = {}
        for path in paths:
            step_id = parse_step_from_filename(os.path.basename(path))
            if step_id is not None:
                mapping[step_id] = path
        return mapping

    by_step_a = _index_by_step(matches_a)
    by_step_b = _index_by_step(matches_b)

    shared_steps = sorted(by_step_a.keys() & by_step_b.keys())
    if not shared_steps:
        print(f"错误: 两个目录没有共同的 step 文件")
        return None

    if verbose:
        print(f"找到 {len(shared_steps)} 个共同的 steps: {shared_steps}")

    outcomes = []
    for step in shared_steps:
        path_a = by_step_a[step]
        path_b = by_step_b[step]
        name_a = os.path.basename(path_a)
        name_b = os.path.basename(path_b)

        if verbose:
            print(f"\n对比 Step {step}:")
            print(f"  文件1: {name_a}")
            print(f"  文件2: {name_b}")

        samples_a = read_responses_from_file(path_a)
        samples_b = read_responses_from_file(path_b)

        # A step is only comparable when both sides have at least one sample.
        if not (samples_a and samples_b):
            if verbose:
                print(f"  警告: 某个文件为空，跳过")
            continue

        if verbose:
            print(f"  样本数: {len(samples_a)} vs {len(samples_b)}")

        row = {
            'step': step,
            'file1': name_a,
            'file2': name_b,
            'num_samples1': len(samples_a),
            'num_samples2': len(samples_b)
        }

        # Divergence and per-side entropy for every requested n-gram size.
        for n in ngram_sizes:
            divergence, entropy_a, entropy_b = compare_sample_distributions(
                samples_a, samples_b, n=n, method=method, return_entropies=True
            )
            row[f'{n}gram_divergence'] = divergence
            row[f'{n}gram_entropy1'] = entropy_a
            row[f'{n}gram_entropy2'] = entropy_b
            if verbose:
                print(f"  {n}-gram {method.upper()} 散度: {divergence:.6f}")
                print(f"  {n}-gram 熵 (Dir1): {entropy_a:.6f}, 熵 (Dir2): {entropy_b:.6f}")

        outcomes.append(row)

    return outcomes

def analyze_multiple_steps(directory, pattern="*_16384.jsonl", window_size=10, max_repetitions_limit=6, verbose=True, show_examples=False, compute_ngrams=False, ngram_sizes=[3, 5], tokenizer=None, step_interval=None, step_offset=None):
    """
    Analyze repetition across the step files found in a directory.

    Files matching ``pattern`` are ordered by the step parsed from their
    names, optionally sub-sampled, and each one is analysed with
    analyze_repetition. When ``compute_ngrams`` is True, corpus-level
    n-gram statistics are added per step.

    Args:
        directory: Directory containing the step files.
        pattern: Glob pattern selecting the files.
        window_size: Window size for repetition detection.
        max_repetitions_limit: Repetition threshold for repetition detection.
        verbose: Whether to print progress information.
        show_examples: Whether to print repetitive examples per file.
        compute_ngrams: Whether to compute distinct n-gram statistics.
        ngram_sizes: n-gram sizes used when ``compute_ngrams`` is True.
        tokenizer: Optional tokenizer forwarded to analyze_repetition.
        step_interval: Keep only steps where ``(step - step_offset)`` is a
            multiple of this value. When None, falls back to
            ``args.step_interval`` if a module-level ``args`` exists, else 10.
            A value of 0/1 disables sampling.
        step_offset: Offset used by the sampling rule. When None, falls back
            to ``args.step_offset`` if a module-level ``args`` exists, else 0.

    Returns:
        list of per-step result dicts sorted by step, or None when no file
        matches or no step can be parsed from any file name.

        With ``compute_ngrams`` each dict also holds, per size n:
        ``distinct_{n}gram_count`` / ``distinct_{n}gram_ratio`` (per-response
        averages), ``step_global_distinct_{n}gram_count`` (distinct n-grams
        over the whole step), ``global_distinct_{n}gram_count`` (distinct
        n-grams over all sampled steps), ``step_global_distinct_{n}gram_ratio``
        (step distinct count / total n-grams over all sampled steps),
        ``step_total_{n}gram_count`` and ``global_total_{n}gram_count``.
    """
    files = glob.glob(os.path.join(directory, pattern))

    if not files:
        print(f"错误: 在 {directory} 中没有找到匹配 {pattern} 的文件")
        return None

    # Pair every file with the step parsed from its name.
    step_files = []
    for file_path in files:
        filename = os.path.basename(file_path)
        step = parse_step_from_filename(filename)
        if step is None:
            if verbose:
                print(f"警告: 无法从文件名 {filename} 提取step编号，跳过")
            continue
        step_files.append((step, file_path, filename))

    if not step_files:
        print(f"错误: 在 {directory} 中未能从文件名提取任何 step")
        return None

    step_files.sort(key=lambda x: x[0])

    # Explicit arguments win; otherwise keep the historical behaviour of
    # reading the sampling settings from a module-level ``args`` object.
    # getattr on a missing (None) ``args`` simply yields the default.
    cli_args = globals().get("args")
    interval = step_interval if step_interval is not None else getattr(cli_args, "step_interval", 10)
    offset = step_offset if step_offset is not None else getattr(cli_args, "step_offset", 0)
    if offset is None:
        offset = 0

    if interval and interval > 1:
        sampled = [(s, p, f) for (s, p, f) in step_files if (s - offset) % interval == 0]
        # If nothing lands on the sampling grid, keep the step closest to
        # the offset so that the analysis is never empty.
        if not sampled and step_files:
            sampled = [min(step_files, key=lambda x: abs(x[0] - offset))]
        if verbose:
            orig_steps = [s for (s, _, _) in step_files]
            samp_steps = [s for (s, _, _) in sampled]
            print(f"  Step 抽样: interval={interval}, offset={offset}")
            print(f"  原始 steps 数={len(orig_steps)}，抽样后 steps 数={len(samp_steps)}")
        step_files = sampled

    # The raw responses are only needed for the corpus-level n-gram
    # statistics, so they are only loaded when those are requested.
    responses_by_step = {}
    if compute_ngrams:
        for step, file_path, _filename in step_files:
            responses_by_step[step] = read_responses_from_file(file_path)

    # Statistics aggregated across all sampled steps.
    global_distinct_counts = {}        # n -> distinct n-grams over all steps
    global_total_ngram_counts = {}     # n -> total n-grams, duplicates included
    if compute_ngrams:
        all_responses = []
        for step, _, _ in step_files:
            all_responses.extend(responses_by_step.get(step, []))
        for n in ngram_sizes:
            global_hashes, global_total = compute_corpus_ngram_hashes(all_responses, n=n)
            global_distinct_counts[n] = len(global_hashes)
            global_total_ngram_counts[n] = global_total
        if verbose:
            print("  全局 n-gram 统计（跨所有 step 聚合）:")
            for n in ngram_sizes:
                print(f"    global distinct {n}-gram count = {global_distinct_counts.get(n, 0)}")
                print(f"    global total    {n}-gram count = {global_total_ngram_counts.get(n, 0)}")

    step_results = []
    for step, file_path, filename in step_files:
        result = analyze_repetition(
            file_path,
            window_size,
            max_repetitions_limit,
            verbose=verbose,
            show_examples=show_examples,
            compute_ngrams=compute_ngrams,
            ngram_sizes=ngram_sizes,
            tokenizer=tokenizer
        )
        if result:
            step_result = {
                'step': step,
                'file': filename,
                'repetition_rate': result['repetition_rate'],
                'total_responses': result['total_responses'],
                'avg_token_length': result['avg_token_length'],
                'accuracy': result['accuracy']
            }
            if compute_ngrams:
                step_responses = responses_by_step.get(step, [])
                for n in ngram_sizes:
                    # Per-response averages computed by analyze_repetition.
                    step_result[f'distinct_{n}gram_count'] = result.get(f'distinct_{n}gram_count')
                    step_result[f'distinct_{n}gram_ratio'] = result.get(f'distinct_{n}gram_ratio')

                    # Distinct n-grams over all responses of this step.
                    step_hashes, step_total_ngrams = compute_corpus_ngram_hashes(step_responses, n=n)
                    step_union_distinct = len(step_hashes)
                    step_result[f'step_global_distinct_{n}gram_count'] = step_union_distinct
                    # Same constant in every row: distinct n-grams over all steps.
                    step_result[f'global_distinct_{n}gram_count'] = global_distinct_counts.get(n, 0)

                    global_total_ngrams = global_total_ngram_counts.get(n, 0)
                    step_result[f'step_global_distinct_{n}gram_ratio'] = (step_union_distinct / global_total_ngrams) if global_total_ngrams > 0 else 0.0
                    step_result[f'step_total_{n}gram_count'] = step_total_ngrams
                    step_result[f'global_total_{n}gram_count'] = global_total_ngrams
            step_results.append(step_result)
        if verbose:
            print()

    step_results.sort(key=lambda x: x['step'])

    return step_results

def get_experiment_name(directory_path):
    """
    Derive the experiment name from a directory path.

    The name is the second-to-last '/'-separated component (the parent of
    the final directory). When there is no parent component, or it is empty
    (e.g. "/eval"), the final component is used instead.

    Args:
        directory_path: Directory path; trailing slashes are ignored.

    Returns:
        str: the experiment name.
    """
    trimmed = directory_path.rstrip('/')
    parts = trimmed.split('/')
    # An empty parent component would produce a blank experiment name, so
    # fall through to the final component in that case.
    if len(parts) >= 2 and parts[-2]:
        return parts[-2]
    return os.path.basename(trimmed)

def plot_multiple_experiments(all_experiments_results, output_path='repetition_analysis.png', plot_ngrams=False, ngram_sizes=[3, 5], plot_length=False, plot_accuracy=False):
    """
    Plot repetition and n-gram metrics for several experiments and save them as PNG files.

    Files written next to ``output_path`` (its base name is used as a prefix):
    1. ``{base}_repetition_ratio.png``: repetition ratio of every experiment.
    2. When ``plot_ngrams`` is True, ``{base}_{n}gram_distinct_ratio.png``
       per n-gram size: distinct ratio of every experiment.
    3. When ``plot_ngrams`` is True, ``{base}_{experiment}_{n}gram_metrics.png``
       per experiment and n-gram size: distinct ratio and count, optionally
       with token length and accuracy on extra y-axes.

    Args:
        all_experiments_results: dict mapping experiment name to its list of
            per-step result dicts (each read for 'step', 'repetition_rate'
            and the optional metric keys used below).
        output_path: Path whose directory and base name determine the output
            file names.
        plot_ngrams: Whether to produce the n-gram figures.
        ngram_sizes: n-gram sizes to plot.
        plot_length: Whether to add the token-length curve to the
            per-experiment figures.
        plot_accuracy: Whether to add the accuracy curve to the
            per-experiment figures.
    """
    if not all_experiments_results:
        print("没有数据可以绘图")
        return

    # Resolve the output directory (current directory when none is given)
    output_dir = os.path.dirname(output_path)
    if not output_dir:
        output_dir = "."
    base_name = os.path.splitext(os.path.basename(output_path))[0]

    # Colors and marker styles, cycled per experiment
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    markers = ['o', 's', '^', 'D', 'v', '<']

    # --- 1. Overall repetition-ratio comparison figure ---
    plt.figure(figsize=(12, 7))
    for idx, (exp_name, step_results) in enumerate(all_experiments_results.items()):
        steps = [r['step'] for r in step_results]
        repetition_rates = [r['repetition_rate'] for r in step_results]
        color = colors[idx % len(colors)]
        marker = markers[idx % len(markers)]
        
        plt.plot(steps, repetition_rates, marker=marker, linewidth=2, markersize=8, 
                 label=exp_name, color=color, alpha=0.8)

    plt.xlabel('Training Step', fontsize=13, fontweight='bold')
    plt.ylabel('Average Repetition Ratio', fontsize=13, fontweight='bold')
    plt.title('Comparison of Average Repetition Ratio', fontsize=15, fontweight='bold', pad=20)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True, alpha=0.3, linestyle='--')
    
    ratio_output = os.path.join(output_dir, f"{base_name}_repetition_ratio.png")
    plt.tight_layout()
    plt.savefig(ratio_output, dpi=300)
    plt.close()
    print(f"重复率对比图已保存到: {ratio_output}")

    # --- Overall distinct-ratio comparison figure, one per n-gram size ---
    if plot_ngrams:
        for n in ngram_sizes:
            plt.figure(figsize=(12, 7))
            for idx, (exp_name, step_results) in enumerate(all_experiments_results.items()):
                steps = [r['step'] for r in step_results]
                # Steps missing the metric are plotted as 0
                distinct_ratios = [r.get(f'distinct_{n}gram_ratio', 0) for r in step_results]
                color = colors[idx % len(colors)]
                marker = markers[idx % len(markers)]
                
                plt.plot(steps, distinct_ratios, marker=marker, linewidth=2, markersize=8, 
                         label=exp_name, color=color, alpha=0.8)

            plt.xlabel('Training Step', fontsize=13, fontweight='bold')
            plt.ylabel(f'Average Distinct {n}-gram Ratio', fontsize=13, fontweight='bold')
            plt.title(f'Comparison of Average Distinct {n}-gram Ratio', fontsize=15, fontweight='bold', pad=20)
            plt.legend(fontsize=10, loc='best')
            plt.grid(True, alpha=0.3, linestyle='--')
            
            ratio_output = os.path.join(output_dir, f"{base_name}_{n}gram_distinct_ratio.png")
            plt.tight_layout()
            plt.savefig(ratio_output, dpi=300)
            plt.close()
            print(f"Distinct {n}-gram 比例对比图已保存到: {ratio_output}")
        
    # --- 2. Per-experiment figures: n-gram + length + accuracy ---
    for exp_idx, (exp_name, step_results) in enumerate(all_experiments_results.items()):
        steps = [r['step'] for r in step_results]
        token_lengths = [r.get('avg_token_length', 0) for r in step_results]
        accuracies = [r.get('accuracy', 0) for r in step_results]
        
        if plot_ngrams:
            for n in ngram_sizes:
                distinct_counts = [r.get(f'distinct_{n}gram_count', 0) for r in step_results]
                distinct_ratios = [r.get(f'distinct_{n}gram_ratio', 0) for r in step_results]
                # New figure with two to four y-axes sharing one x-axis
                fig, ax1 = plt.subplots(figsize=(12, 7))
                
                # 1. n-gram ratio on the left axis
                color_ratio = 'forestgreen'
                ax1.set_xlabel('Training Step', fontsize=12, fontweight='bold')
                ax1.set_ylabel(f'Ratio (Avg-per-response)', color=color_ratio, fontsize=12, fontweight='bold')
                lns1 = ax1.plot(steps, distinct_ratios, marker='o', color=color_ratio, linewidth=2, label=f'Avg {n}-gram Ratio')
                ax1.tick_params(axis='y', labelcolor=color_ratio)
                ax1.grid(True, alpha=0.2)

                # 2. n-gram count on the first right axis
                ax_count = ax1.twinx()
                color_count = 'lime'
                ax_count.set_ylabel(f'Average Distinct {n}-gram Count', color=color_count, fontsize=12, fontweight='bold')
                lns_count = ax_count.plot(steps, distinct_counts, marker='h', color=color_count, linewidth=2, alpha=0.5, linestyle=':', label=f'{n}-gram Count')
                ax_count.tick_params(axis='y', labelcolor=color_count)
                # ``lns`` accumulates the line handles of every axis for a single legend
                lns = lns1 + lns_count

                # 3. Token length on the second right axis
                if plot_length:
                    ax2 = ax1.twinx()
                    # Shift this axis outward so it does not overlap the count axis
                    ax2.spines['right'].set_position(('outward', 60))
                    color_len = 'purple'
                    ax2.set_ylabel('Average Token Length', color=color_len, fontsize=12, fontweight='bold')
                    lns2 = ax2.plot(steps, token_lengths, marker='s', color=color_len, linewidth=2, alpha=0.6, linestyle='--', label='Token Length')
                    ax2.tick_params(axis='y', labelcolor=color_len)
                    lns += lns2

                # 4. Accuracy on the third right axis (shifted further)
                if plot_accuracy:
                    ax3 = ax1.twinx()
                    # Shift this axis further outward than the length axis
                    ax3.spines['right'].set_position(('outward', 120))
                    color_acc = 'red'
                    ax3.set_ylabel('Macro Average Accuracy', color=color_acc, fontsize=12, fontweight='bold')
                    lns3 = ax3.plot(steps, accuracies, marker='*', color=color_acc, linewidth=2, alpha=0.8, label='Accuracy')
                    ax3.tick_params(axis='y', labelcolor=color_acc)
                    # Fit the y-range to the data (10% margin, or 0.05 when flat),
                    # clamped to [0, 1], so small fluctuations stay visible
                    if accuracies:
                        min_acc, max_acc = min(accuracies), max(accuracies)
                        margin = (max_acc - min_acc) * 0.1 if max_acc > min_acc else 0.05
                        ax3.set_ylim(max(0, min_acc - margin), min(1.0, max_acc + margin))
                    lns += lns3

                plt.title(f'{exp_name}: {n}-gram Metrics', fontsize=14, fontweight='bold', pad=15)
                
                # One legend covering the lines of all axes
                labs = [l.get_label() for l in lns]
                ax1.legend(lns, labs, loc='upper left')

                # Build the file name; spaces and slashes in the experiment name become underscores
                sanitized_exp_name = exp_name.replace(" ", "_").replace("/", "_")
                ngram_output = os.path.join(output_dir, f"{base_name}_{sanitized_exp_name}_{n}gram_metrics.png")
                plt.tight_layout()
                plt.savefig(ngram_output, dpi=300)
                plt.close()
                print(f"实验图表已保存: {ngram_output}")

def plot_distribution_divergence(comparison_results, output_path='divergence_comparison.png', ngram_sizes=[3, 5], dir1_name='Dir1', dir2_name='Dir2', plot_entropies=True):
    """
    Plot the per-step n-gram distribution divergence between two directories.

    One column of subplots is drawn per n-gram size. The first row shows the
    divergence curve; when ``plot_entropies`` is true, a second row shows the
    entropy curves of both directories. The figure is saved to
    ``output_path`` and then closed.

    Args:
        comparison_results: List of per-step dicts. Each dict is read for
            ``'step'``, ``'{n}gram_divergence'`` and, when entropies are
            plotted, ``'{n}gram_entropy1'`` / ``'{n}gram_entropy2'``.
        output_path: Path of the image file to write.
        ngram_sizes: n-gram sizes to plot (only read, never mutated).
        dir1_name: Legend label for the first directory's entropy curve.
        dir2_name: Legend label for the second directory's entropy curve.
        plot_entropies: Whether to add the entropy row.

    Returns:
        None. Prints a message and returns early when there is nothing to
        plot (no results, or no n-gram sizes).

    Raises:
        KeyError: If a result dict lacks one of the keys listed above.
    """
    # plt.subplots() rejects a zero-column grid, so treat an empty
    # ngram_sizes the same way as empty results.
    if not comparison_results or not ngram_sizes:
        print("没有数据可以绘图")
        return

    num_ngrams = len(ngram_sizes)
    num_rows = 2 if plot_entropies else 1

    # squeeze=False always returns a 2-D grid of Axes, so axes[row][col]
    # works uniformly for one or many n-gram sizes and one or two rows.
    fig, axes = plt.subplots(
        num_rows,
        num_ngrams,
        figsize=(7 * num_ngrams, 6 * num_rows),
        squeeze=False,
    )

    divergence_color = '#1f77b4'
    entropy1_color = '#2ca02c'
    entropy2_color = '#d62728'

    try:
        steps = [r['step'] for r in comparison_results]

        # Row 0: divergence curves.
        for idx, n in enumerate(ngram_sizes):
            ax = axes[0][idx]
            divergences = [r[f'{n}gram_divergence'] for r in comparison_results]

            ax.plot(steps, divergences, marker='o', linewidth=2, markersize=8, color=divergence_color)
            ax.set_xlabel('Training Step', fontsize=13, fontweight='bold')
            # NOTE(review): the label always says "JS" even though this
            # function is not told which divergence method was used.
            ax.set_ylabel(f'{n}-gram JS Divergence', fontsize=13, fontweight='bold')
            ax.set_title(f'{n}-gram Distribution Divergence', fontsize=14, fontweight='bold', pad=15)
            ax.grid(True, alpha=0.3, linestyle='--')

            # Value labels are only drawn when there are few points.
            if len(steps) <= 15:
                for step, div in zip(steps, divergences):
                    ax.annotate(f'{div:.4f}',
                                xy=(step, div),
                                xytext=(0, 8),
                                textcoords='offset points',
                                ha='center',
                                fontsize=8,
                                alpha=0.7)

        # Row 1: entropy curves of both directories.
        if plot_entropies:
            for idx, n in enumerate(ngram_sizes):
                ax = axes[1][idx]
                entropies1 = [r[f'{n}gram_entropy1'] for r in comparison_results]
                entropies2 = [r[f'{n}gram_entropy2'] for r in comparison_results]

                ax.plot(steps, entropies1, marker='s', linewidth=2, markersize=8,
                        color=entropy1_color, label=f'{dir1_name}', alpha=0.8)
                ax.plot(steps, entropies2, marker='^', linewidth=2, markersize=8,
                        color=entropy2_color, label=f'{dir2_name}', alpha=0.8)

                ax.set_xlabel('Training Step', fontsize=13, fontweight='bold')
                ax.set_ylabel(f'{n}-gram Entropy', fontsize=13, fontweight='bold')
                ax.set_title(f'{n}-gram Distribution Entropy', fontsize=14, fontweight='bold', pad=15)
                ax.legend(fontsize=10, loc='best', framealpha=0.9)
                ax.grid(True, alpha=0.3, linestyle='--')

                # Two labels per step: first series above the point,
                # second series below, and only for short runs.
                if len(steps) <= 10:
                    for step, ent1, ent2 in zip(steps, entropies1, entropies2):
                        ax.annotate(f'{ent1:.2f}',
                                    xy=(step, ent1),
                                    xytext=(0, 8),
                                    textcoords='offset points',
                                    ha='center',
                                    fontsize=7,
                                    alpha=0.6)
                        ax.annotate(f'{ent2:.2f}',
                                    xy=(step, ent2),
                                    xytext=(0, -12),
                                    textcoords='offset points',
                                    ha='center',
                                    fontsize=7,
                                    alpha=0.6)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\n图表已保存到: {output_path}")
    finally:
        # Close the figure even if a result dict was missing a key, so a
        # failed call does not leave an open figure behind.
        plt.close(fig)

def plot_repetition_rates(step_results, output_path='repetition_analysis.png', plot_ngrams=False, ngram_sizes=[3, 5], plot_length=False, plot_accuracy=False):
    """
    Plot the per-step metrics of a single experiment as a row of panels.

    The first panel always shows the average repetition ratio. Optional
    panels follow in this order: for each n-gram size a distinct-count panel
    and a distinct-ratio panel (``plot_ngrams``), an average token length
    panel (``plot_length``), and an accuracy panel (``plot_accuracy``).
    The figure is saved to ``output_path`` and closed.

    Args:
        step_results: List of per-step dicts. ``'step'`` and
            ``'repetition_rate'`` are required; the optional metrics are read
            with a default of 0.
        output_path: Path of the image file to write.
        plot_ngrams: Whether to add the distinct n-gram panels.
        ngram_sizes: n-gram sizes used for the distinct n-gram panels.
        plot_length: Whether to add the token length panel.
        plot_accuracy: Whether to add the accuracy panel.

    Returns:
        None. Prints a message and returns early when ``step_results`` is
        empty.
    """
    if not step_results:
        print("没有数据可以绘图")
        return

    # Count the panels up front so the figure can be sized accordingly.
    panel_total = 1
    if plot_ngrams:
        panel_total += 2 * len(ngram_sizes)  # per n-gram size: count + ratio
    if plot_length:
        panel_total += 1
    if plot_accuracy:
        panel_total += 1

    fig, axes = plt.subplots(1, panel_total, figsize=(7 * panel_total, 6))
    # A single subplot comes back as a bare Axes; wrap it so iteration works.
    panel_axes = iter([axes] if panel_total == 1 else axes)

    steps = [record['step'] for record in step_results]

    def draw_panel(axis, values, label, title, **line_style):
        """Draw one metric curve and apply the shared axis decoration."""
        axis.plot(steps, values, linewidth=2, markersize=8, **line_style)
        axis.set_xlabel('Training Step', fontsize=12)
        axis.set_ylabel(label, fontsize=12)
        axis.set_title(title, fontsize=14)
        axis.grid(True, alpha=0.3)

    def column(key):
        """Collect an optional metric, substituting 0 where it is absent."""
        return [record.get(key, 0) for record in step_results]

    # Panel 1: repetition ratio (required key, default line colour).
    rates = [record['repetition_rate'] for record in step_results]
    draw_panel(next(panel_axes), rates,
               'Average Repetition Ratio', 'Average Repetition Ratio',
               marker='o')

    if plot_ngrams:
        for n in ngram_sizes:
            # Distinct n-gram count.
            count_caption = f'Average Distinct {n}-gram Count'
            draw_panel(next(panel_axes), column(f'distinct_{n}gram_count'),
                       count_caption, count_caption,
                       marker='o', color='green')

            # Distinct n-gram ratio.
            ratio_caption = f'Average Distinct {n}-gram Ratio'
            draw_panel(next(panel_axes), column(f'distinct_{n}gram_ratio'),
                       ratio_caption, ratio_caption,
                       marker='s', color='darkgreen')

    if plot_length:
        draw_panel(next(panel_axes), column('avg_token_length'),
                   'Average Token Length', 'Average Response Length',
                   marker='o', color='purple')

    if plot_accuracy:
        accuracy_axis = next(panel_axes)
        accuracies = column('accuracy')
        draw_panel(accuracy_axis, accuracies,
                   'Macro Average Accuracy', 'Macro Average Accuracy',
                   marker='o', color='red')
        # Tighten the y-range around the data so small fluctuations are visible.
        if accuracies:
            lowest, highest = min(accuracies), max(accuracies)
            if highest > lowest:
                padding = (highest - lowest) * 0.1
            else:
                padding = 0.05
            accuracy_axis.set_ylim(max(0, lowest - padding), min(1.0, highest + padding))

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\n图表已保存到: {output_path}")
    plt.close()

if __name__ == "__main__":
    # Command-line entry point. The flags select one of these modes:
    #   --compare-dirs DIR1 DIR2  compare the n-gram distributions of two directories
    #   --directory/-d (repeatable)  batch-analyse the step files of each experiment
    #   --load-data               re-plot previously saved batch results
    #   --file/-f                 analyse a single JSONL file
    # When no mode is selected, the help text and usage examples are printed.
    #
    # NOTE: argparse %-formats help strings, so a literal percent sign in a
    # help text must be written as "%%".
    parser = argparse.ArgumentParser(description='分析JSONL文件中response的重复情况')
    parser.add_argument('--file', '-f', type=str, 
                       help='单个JSONL文件路径')
    parser.add_argument('--directory', '-d', type=str, action='append',
                       help='包含多个step文件的目录路径（可以多次使用此参数来指定多个目录）')
    parser.add_argument('--compare-dirs', type=str, nargs=2, metavar=('DIR1', 'DIR2'),
                       help='对比两个目录中对应样本的n-gram分布差异')
    parser.add_argument('--pattern', '-p', type=str, default='*_16384.jsonl',
                       help='文件匹配模式 (默认: *_16384.jsonl)')
    parser.add_argument('--window_size', '-w', type=int, default=10,
                       help='滑动窗口大小 (默认: 10)')
    parser.add_argument('--max_repetitions_limit', '-m', type=int, default=10,
                       help='重复次数阈值，超过此值的窗口被认为是重复 (默认: 10)')
    parser.add_argument('--output', '-o', type=str, default='repetition_analysis.png',
                       help='输出图表文件路径 (默认: repetition_analysis.png)')
    parser.add_argument('--show-examples', '-s', action='store_true',
                       help='显示重复文本的具体示例')
    parser.add_argument('--compute-ngrams', '-n', action='store_true',
                       help='计算 n-gram 的 distinct 个数（由 --ngram-sizes 指定，用于衡量 exploration）')
    parser.add_argument('--ngram-sizes', type=int, nargs='+', default=[10],
                       help='用于计算 distinct 个数和分布对比的 n-gram 大小 (默认: 10)')
    parser.add_argument('--divergence-method', type=str, default='js', choices=['js', 'kl'],
                       help='散度计算方法: js (Jensen-Shannon) 或 kl (KL divergence) (默认: js)')
    parser.add_argument('--plot-entropies', action='store_true',
                       help='在对比目录时绘制分布熵的变化图')
    parser.add_argument('--tokenizer', type=str,
                       help='Tokenizer 路径，用于计算 token 长度曲线')
    parser.add_argument('--save-data', type=str,
                       help='将分析结果保存到 JSON 文件')
    parser.add_argument('--load-data', type=str,
                       help='从保存的 JSON 文件加载数据直接画图')
    parser.add_argument('--plot-accuracy', action='store_true',
                       help='在图表中绘制 Accuracy 曲线')
    parser.add_argument('--no-plot', action='store_true',
                       help='禁用绘图功能')
    # NOTE(review): --step-interval / --step-offset are parsed but are not
    # forwarded to any call in this block; confirm whether
    # analyze_multiple_steps is expected to receive them.
    parser.add_argument('--step-interval', type=int, default=10,
                       help='批量分析时按 step 编号抽样的间隔（默认每隔10步取一个；设为1表示不抽样）')
    parser.add_argument('--step-offset', type=int, default=0,
                       help='抽样的起始 offset：选择满足 (step-offset) %% interval == 0 的 step（默认0）')
    
    args = parser.parse_args()
    
    # Load the tokenizer if one was requested. When transformers is not
    # installed, the run continues without token-length statistics.
    tokenizer = None
    if args.tokenizer:
        if AutoTokenizer is None:
            print("错误: 无法导入 transformers，请安装 transformers 以使用 tokenizer 功能")
        else:
            print(f"正在加载 tokenizer: {args.tokenizer}")
            tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    
    # Maps experiment name -> list of per-step result dicts.
    all_experiments_results = {}

    if args.load_data:
        print(f"正在从文件加载数据: {args.load_data}")
        try:
            with open(args.load_data, 'r', encoding='utf-8') as f:
                all_experiments_results = json.load(f)
        except Exception as e:
            # Top-level boundary: report the failure and exit non-zero.
            print(f"加载数据失败: {e}")
            sys.exit(1)
    
    if args.compare_dirs:
        # Directory comparison mode: distribution difference between two runs.
        dir1, dir2 = args.compare_dirs
        
        print("="*60)
        print("目录对比模式: n-gram 分布差异 analysis")
        print("="*60)
        print(f"目录1: {dir1}")
        print(f"目录2: {dir2}")
        print(f"文件模式: {args.pattern}")
        print(f"n-gram 大小: {args.ngram_sizes}")
        print(f"散度方法: {args.divergence_method.upper()}")
        print("="*60)
        print()
        
        comparison_results = compare_two_directories(
            dir1, dir2, 
            pattern=args.pattern,
            ngram_sizes=args.ngram_sizes,
            method=args.divergence_method,
            verbose=True
        )
        
        if comparison_results:
            # Summary table: divergence per step and n-gram size.
            print("\n" + "="*80)
            print("对比结果汇总 - 分布散度")
            print("="*80)
            
            header = f"{'Step':<10} "
            for n in args.ngram_sizes:
                header += f"{n}-gram 散度{' '*5}"
            print(header)
            print("-" * 80)
            
            for r in comparison_results:
                row = f"{r['step']:<10} "
                for n in args.ngram_sizes:
                    row += f"{r[f'{n}gram_divergence']:<15.6f} "
                print(row)
            
            print("-" * 80)
            avg_row = f"{'平均':<10} "
            for n in args.ngram_sizes:
                avg_div = np.mean([r[f'{n}gram_divergence'] for r in comparison_results])
                avg_row += f"{avg_div:<15.6f} "
            print(avg_row)
            print("="*80)
            
            # Summary table: entropy of each directory per step and n-gram size.
            print("\n" + "="*100)
            print("对比结果汇总 - 分布熵")
            print("="*100)
            
            header = f"{'Step':<10} "
            for n in args.ngram_sizes:
                header += f"{n}-gram 熵1{' '*5}{n}-gram 熵2{' '*5}"
            print(header)
            print("-" * 100)
            
            for r in comparison_results:
                row = f"{r['step']:<10} "
                for n in args.ngram_sizes:
                    row += f"{r[f'{n}gram_entropy1']:<15.6f} {r[f'{n}gram_entropy2']:<15.6f} "
                print(row)
            
            print("-" * 100)
            avg_row = f"{'平均':<10} "
            for n in args.ngram_sizes:
                avg_ent1 = np.mean([r[f'{n}gram_entropy1'] for r in comparison_results])
                avg_ent2 = np.mean([r[f'{n}gram_entropy2'] for r in comparison_results])
                avg_row += f"{avg_ent1:<15.6f} {avg_ent2:<15.6f} "
            print(avg_row)
            print("="*100)
            
            dir1_name = get_experiment_name(dir1)
            dir2_name = get_experiment_name(dir2)
            if not args.no_plot:
                plot_distribution_divergence(
                    comparison_results, 
                    output_path=args.output,
                    ngram_sizes=args.ngram_sizes,
                    dir1_name=dir1_name,
                    dir2_name=dir2_name,
                    # Honour the --plot-entropies flag instead of always
                    # drawing the entropy row.
                    plot_entropies=args.plot_entropies
                )
    
    elif args.directory and not args.load_data:
        # Batch analysis mode (possibly several experiments).
        print("="*60)
        print("批量分析模式")
        print("="*60)
        print(f"模式: {args.pattern}")
        print(f"窗口大小: {args.window_size}")
        print(f"重复阈值: {args.max_repetitions_limit}")
        print(f"计算n-gram: {args.compute_ngrams}")
        print(f"实验数量: {len(args.directory)}")
        print("="*60)
        print()
        
        for directory in args.directory:
            exp_name = get_experiment_name(directory)
            print(f"\n{'='*60}")
            print(f"分析实验: {exp_name}")
            print(f"目录: {directory}")
            print(f"{'='*60}")
            
            step_results = analyze_multiple_steps(directory, args.pattern, args.window_size, args.max_repetitions_limit, verbose=True, show_examples=args.show_examples, compute_ngrams=args.compute_ngrams, ngram_sizes=args.ngram_sizes, tokenizer=tokenizer)
            
            if step_results:
                all_experiments_results[exp_name] = step_results
                
                # Per-experiment statistics table.
                print(f"\n{exp_name} - 统计结果:")
                header = f"{'Step':<10} {'文件名':<25} {'重复比例':<15} "
                if args.compute_ngrams:
                    for n in args.ngram_sizes:
                        header += f"{n}-gram Count{' '*(15-len(str(n))-11)} "
                        header += f"{n}-gram Ratio{' '*(15-len(str(n))-11)} "
                        header += f"{n}-gram StepGlobalDistinct{' '*(24-len(str(n))-21)} "
                        header += f"{n}-gram GlobalDistinct{' '*(20-len(str(n))-17)} "
                        header += f"{n}-gram StepD/GlobalTotal{' '*(24-len(str(n))-20)} "
                        header += f"{n}-gram StepTotal{' '*(18-len(str(n))-13)} "
                        header += f"{n}-gram TotalRatio{' '*(18-len(str(n))-14)} "
                if tokenizer or any('avg_token_length' in r for r in step_results):
                    header += f"{'Token长度':<12} "
                header += f"{'Accuracy':<10} "
                header += f"{'总样本数':<10}"
                print(header)
                print("-" * len(header))
                for r in step_results:
                    row = f"{r['step']:<10} {r['file']:<25} {r['repetition_rate']:<15.4f} "
                    if args.compute_ngrams:
                        for n in args.ngram_sizes:
                            row += f"{r.get(f'distinct_{n}gram_count', 0):<15.2f} "
                            row += f"{r.get(f'distinct_{n}gram_ratio', 0):<15.4f} "
                            row += f"{r.get(f'step_global_distinct_{n}gram_count', 0):<24.2f} "
                            row += f"{r.get(f'global_distinct_{n}gram_count', 0):<20.2f} "
                            row += f"{r.get(f'step_global_distinct_{n}gram_ratio', 0):<24.6f} "
                            row += f"{r.get(f'step_total_{n}gram_count', 0):<18.2f} "
                            row += f"{r.get(f'global_total_{n}gram_ratio', 0):<18.4f} "
                    if tokenizer or 'avg_token_length' in r:
                        row += f"{r.get('avg_token_length', 0):<12.2f} "
                    row += f"{r.get('accuracy', 0):<10.4f} "
                    row += f"{r['total_responses']:<10}"
                    print(row)
        
        if args.save_data and all_experiments_results:
            print(f"\n正在将分析结果保存到: {args.save_data}")
            with open(args.save_data, 'w', encoding='utf-8') as f:
                json.dump(all_experiments_results, f, indent=2, ensure_ascii=False)

    # Summary and plotting, for freshly analysed or loaded batch results.
    if all_experiments_results and (args.directory or args.load_data) and not args.no_plot:
        print("\n" + "="*60)
        print("汇总对比与绘图")
        print("="*60)
        for exp_name, results in all_experiments_results.items():
            avg_rate = np.mean([r['repetition_rate'] for r in results])
            avg_acc = np.mean([r.get('accuracy', 0) for r in results])
            print(f"{exp_name}: 平均重复比例 = {avg_rate:.4f} ({avg_rate*100:.2f}%)")
            print(f"{exp_name}: 平均Accuracy = {avg_acc:.4f}")
            if tokenizer or any('avg_token_length' in r for r in results):
                avg_len = np.mean([r.get('avg_token_length', 0) for r in results])
                print(f"{exp_name}: 平均Token长度 = {avg_len:.2f}")
            if args.compute_ngrams:
                for n in args.ngram_sizes:
                    avg_ngram = np.mean([r.get(f'distinct_{n}gram_count', 0) for r in results])
                    avg_ratio = np.mean([r.get(f'distinct_{n}gram_ratio', 0) for r in results])
                    print(f"{exp_name}: 平均{n}-gram distinct个数 = {avg_ngram:.2f}, 比例 = {avg_ratio:.4f}")
                    avg_stepd_globaltotal = np.mean([r.get(f'step_global_distinct_{n}gram_ratio', 0) for r in results])
                    global_total = results[0].get(f'global_total_{n}gram_count', 0) if results else 0
                    global_distinct = results[0].get(f'global_distinct_{n}gram_count', 0) if results else 0
                    print(f"{exp_name}: 平均(step_global_distinct/global_total) = {avg_stepd_globaltotal:.6f} (global distinct={global_distinct}, global total={global_total})")
        
        # Draw the token-length curve when a tokenizer is loaded or any
        # record already carries a token length (e.g. loaded from JSON).
        plot_length = tokenizer is not None or any(
            'avg_token_length' in r
            for res in all_experiments_results.values()
            for r in res
        )
        if len(all_experiments_results) > 1:
            plot_multiple_experiments(all_experiments_results, args.output, plot_ngrams=args.compute_ngrams, ngram_sizes=args.ngram_sizes, plot_length=plot_length, plot_accuracy=args.plot_accuracy)
        else:
            # Only one experiment: use the single-experiment plot.
            plot_repetition_rates(list(all_experiments_results.values())[0], args.output, plot_ngrams=args.compute_ngrams, ngram_sizes=args.ngram_sizes, plot_length=plot_length, plot_accuracy=args.plot_accuracy)
    
    elif args.file:
        # Single-file analysis mode.
        print("="*60)
        print("单文件分析模式")
        print("="*60)
        result = analyze_repetition(args.file, args.window_size, args.max_repetitions_limit, verbose=True, show_examples=args.show_examples, compute_ngrams=args.compute_ngrams, ngram_sizes=args.ngram_sizes, tokenizer=tokenizer)
        
        if result:
            print("\n" + "="*60)
            print("分析结果")
            print("="*60)
            print(f"总response数量: {result['total_responses']}")
            print(f"平均重复比例: {result['repetition_rate']:.4f} ({result['repetition_rate']*100:.2f}%)")
            # Read accuracy with a default, as the batch tables do.
            print(f"Accuracy (Macro Avg): {result.get('accuracy', 0):.4f}")
            if tokenizer:
                print(f"平均Token长度: {result.get('avg_token_length', 0):.2f}")
            if args.compute_ngrams:
                for n in args.ngram_sizes:
                    print(f"平均{n}-gram distinct个数: {result.get(f'distinct_{n}gram_count', 0):.2f}")
                    print(f"平均{n}-gram distinct比例: {result.get(f'distinct_{n}gram_ratio', 0):.4f}")
            if args.show_examples and result.get('repetition_examples'):
                print(f"检测到重复的样例数: {len(result['repetition_examples'])}")
            print("="*60)
    
    elif not (args.load_data or args.compare_dirs or args.directory):
        # No mode was selected: show help and usage examples. This is skipped
        # when a comparison or batch run already happened (e.g. with
        # --no-plot), so a successful run is not followed by the help text.
        parser.print_help()
        print("\n示例用法:")
        print("  单文件分析:")
        print("    python analysis_repetition_penalty.py -f /path/to/file.jsonl")
        print("\n  批量分析并保存数据:")
        print("    python analysis_repetition_penalty.py -d ./exp1 -d ./exp2 --save-data results.json --plot-accuracy")
        print("\n  从保存的数据直接绘图:")
        print("    python analysis_repetition_penalty.py --load-data results.json --plot-accuracy --compute-ngrams -o comparison.png")
