#!/usr/bin/env python3
"""
训练数据语义重复度分析脚本（按step）
分析训练过程中收集的生成数据，按step统计平均语义重复度并画图
横坐标：step
纵坐标：平均语义重复度
"""

import json
import argparse
import os
import re
from typing import List, Dict, Tuple
import numpy as np
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from collections import defaultdict
import matplotlib.pyplot as plt
import multiprocessing
from multiprocessing import Process, Queue
import math

# Set the multiprocessing start method to 'spawn' so that CUDA can be used in child processes.
# This only runs when the file is executed as a script, not when it is imported.
if __name__ == '__main__':
    multiprocessing.set_start_method('spawn', force=True)


def batch_split_texts_by_tokens(texts: List[str], tokenizer, chunk_size: int = 512) -> Tuple[List[str], List[Tuple[int, int]]]:
    """
    Split a batch of texts into chunks of at most ``chunk_size`` tokens each.

    All texts are tokenized in a single tokenizer call, the token ids of each
    text are sliced into consecutive windows, and every window is decoded back
    to a string.

    Args:
        texts: Input texts.
        tokenizer: Callable tokenizer that accepts a list of texts and returns a
            mapping with an ``"input_ids"`` entry, and that provides
            ``batch_decode`` (the interface used below).
        chunk_size: Maximum number of tokens per chunk; must be positive.

    Returns:
        ``(all_chunks, chunk_boundaries)`` where ``chunk_boundaries[i]`` is the
        half-open ``(start, end)`` range of text ``i`` inside ``all_chunks``.
        A text that produces no tokens gets an empty range.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # A negative step would silently produce no chunks and a zero step
        # would fail deep inside range(); reject both up front.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    if len(texts) == 0:
        # Nothing to split; also avoids handing an empty batch to the
        # tokenizer, which some tokenizer implementations do not accept.
        return [], []

    # Tokenize every text in one call (much faster than one call per text).
    all_tokens_list = tokenizer(
        texts,
        add_special_tokens=False,
        truncation=False,
        padding=False,
        return_attention_mask=False,
    )["input_ids"]

    all_chunks = []
    chunk_boundaries = []

    for tokens in all_tokens_list:
        start_idx = len(all_chunks)

        if len(tokens) > 0:
            windows = [
                tokens[offset:offset + chunk_size]
                for offset in range(0, len(tokens), chunk_size)
            ]
            # Decode all windows of this text in one call.
            all_chunks.extend(tokenizer.batch_decode(windows, skip_special_tokens=True))

        # Empty texts fall through with start_idx == end_idx.
        chunk_boundaries.append((start_idx, len(all_chunks)))

    return all_chunks, chunk_boundaries


def batch_split_texts_by_text(texts: List[str], tokenizer=None) -> Tuple[List[str], List[Tuple[int, int]]]:
    """
    Split a batch of texts into word-level chunks.

    Each text is first split on whitespace and every resulting segment is then
    split on underscores; empty pieces are dropped. Every remaining word is one
    chunk.

    Args:
        texts: Input texts.
        tokenizer: Accepted for signature compatibility; not used here.

    Returns:
        ``(all_chunks, chunk_boundaries)`` where ``chunk_boundaries[i]`` is the
        half-open ``(start, end)`` range of text ``i`` inside ``all_chunks``.
        Empty or whitespace-only texts, and texts that yield no words, get an
        empty range.
    """
    flat_words = []
    spans = []

    for raw in texts:
        begin = len(flat_words)

        # Falsy or whitespace-only texts contribute no words at all.
        if raw and raw.strip():
            flat_words.extend(
                piece
                for segment in raw.split()
                for piece in segment.split('_')
                if piece
            )

        spans.append((begin, len(flat_words)))

    return flat_words, spans


def calculate_repetition_metrics(similarity_matrix: np.ndarray) -> Dict[str, float]:
    """
    Derive repetition metrics from a pairwise similarity matrix.

    Only the entries strictly above the diagonal are used, so each pair is
    counted once and self-similarity is excluded.

    Args:
        similarity_matrix: Pairwise similarity matrix; its first dimension is
            taken as the number of chunks.

    Returns:
        Dict with ``avg_similarity``, ``max_similarity``,
        ``repetition_ratio_0.8`` / ``repetition_ratio_0.9`` (fraction of pairs
        with similarity >= 0.8 / >= 0.9) and ``num_chunks``. An empty matrix
        yields all zeros; a single chunk yields similarity 1.0 and ratios 0.0.
    """
    def _pack(avg, peak, ratio_08, ratio_09, count):
        return {
            "avg_similarity": avg,
            "max_similarity": peak,
            "repetition_ratio_0.8": ratio_08,
            "repetition_ratio_0.9": ratio_09,
            "num_chunks": count,
        }

    if similarity_matrix.size == 0:
        return _pack(0.0, 0.0, 0.0, 0.0, 0)

    n = similarity_matrix.shape[0]
    if n == 1:
        # No pairs exist for a single chunk.
        return _pack(1.0, 1.0, 0.0, 0.0, 1)

    # Strict upper triangle: one entry per unordered pair, diagonal excluded.
    pair_sims = similarity_matrix[np.triu_indices(n, k=1)]
    pair_count = len(pair_sims)

    share_08 = np.sum(pair_sims >= 0.8) / pair_count
    share_09 = np.sum(pair_sims >= 0.9) / pair_count

    return _pack(
        float(np.mean(pair_sims)),
        float(np.max(pair_sims)),
        float(share_08),
        float(share_09),
        n,
    )


def analyze_step_file(
    file_path: str,
    tokenizer,
    model: SentenceTransformer,
    chunk_size: int = 512,
    batch_size: int = 32,
    response_key: str = "response",
    accuracy_key: str = "accuracies",
) -> Dict[str, float]:
    """
    Compute the average semantic-repetition metrics for one step file.

    Every non-blank line of the file is parsed as a JSON object. The responses
    stored under ``response_key`` are split into word-level chunks with
    ``batch_split_texts_by_text``, all chunks are embedded in one
    ``model.encode`` call, and for each response the pairwise similarity of its
    own chunks is summarized by ``calculate_repetition_metrics``. The
    per-response metrics are then averaged over the file.

    Args:
        file_path: Path to the JSONL file.
        tokenizer: Forwarded to ``batch_split_texts_by_text``.
        model: Embedding model; ``encode`` and ``similarity`` are used.
        chunk_size: Not used by this function (splitting is word-based).
        batch_size: Batch size passed to ``model.encode``.
        response_key: Key holding the response(s) of a record. A list of
            strings or a single string is accepted.
        accuracy_key: Key holding the accuracy value(s) of a record, aligned
            by position with the responses. A response without a matching
            accuracy is still analyzed but is counted neither as correct nor
            as incorrect.

    Returns:
        Dict with the averaged ``avg_similarity``, ``max_similarity``,
        ``repetition_ratio_0.8`` and ``repetition_ratio_0.9``; ``num_samples``
        (number of JSON records); ``valid_samples`` (responses that produced at
        least one chunk); and ``avg_similarity_correct`` /
        ``avg_similarity_incorrect`` / ``num_correct`` / ``num_incorrect`` for
        responses whose accuracy is > 0 / not > 0. All metric values are 0 when
        nothing could be analyzed.
    """
    def _zero_metrics(num_samples: int) -> Dict[str, float]:
        # Same key set as the regular return value so callers see one schema.
        return {
            "avg_similarity": 0.0,
            "max_similarity": 0.0,
            "repetition_ratio_0.8": 0.0,
            "repetition_ratio_0.9": 0.0,
            "num_samples": num_samples,
            "valid_samples": 0,
            "avg_similarity_correct": 0.0,
            "avg_similarity_incorrect": 0.0,
            "num_correct": 0,
            "num_incorrect": 0,
        }

    def _as_list(value) -> list:
        # Normalize a missing value, a scalar/string, or a sequence to a list.
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
    with open(file_path, 'r', encoding='utf-8') as f:
        all_data = [json.loads(line) for line in f if line.strip()]

    if not all_data:
        return _zero_metrics(0)

    # Flatten responses and accuracies, keeping them aligned by position.
    valid_responses: List[str] = []
    valid_accuracies: list = []
    for data in all_data:
        responses = _as_list(data.get(response_key))
        accuracies = _as_list(data.get(accuracy_key))

        for pos, resp in enumerate(responses):
            if isinstance(resp, str) and resp.strip():
                valid_responses.append(resp)
                # None marks "no accuracy available" for this response.
                valid_accuracies.append(accuracies[pos] if pos < len(accuracies) else None)

    # Split every response into chunks in one pass.
    all_chunks, chunk_boundaries = batch_split_texts_by_text(valid_responses, tokenizer)

    if len(all_chunks) == 0:
        return _zero_metrics(len(all_data))

    # Embed all chunks of all responses in a single call.
    all_embeddings = model.encode(
        all_chunks,
        convert_to_tensor=True,
        show_progress_bar=False,
        batch_size=batch_size
    )

    all_avg_similarities = []
    all_max_similarities = []
    all_repetition_ratios_0_8 = []
    all_repetition_ratios_0_9 = []
    correct_avg_similarities = []
    incorrect_avg_similarities = []

    for (start_idx, end_idx), acc_val in zip(chunk_boundaries, valid_accuracies):
        if start_idx == end_idx:  # response produced no chunks
            continue

        # Similarity of this response's chunks against each other.
        response_embeddings = all_embeddings[start_idx:end_idx]
        similarity_matrix = model.similarity(response_embeddings, response_embeddings)

        if torch.is_tensor(similarity_matrix):
            # Cast to float32 first: numpy has no bfloat16 dtype.
            similarity_matrix = similarity_matrix.float().cpu().numpy()

        metrics = calculate_repetition_metrics(similarity_matrix)

        all_avg_similarities.append(metrics["avg_similarity"])
        all_max_similarities.append(metrics["max_similarity"])
        all_repetition_ratios_0_8.append(metrics["repetition_ratio_0.8"])
        all_repetition_ratios_0_9.append(metrics["repetition_ratio_0.9"])

        # Bucket by correctness; responses without an accuracy go in neither bucket.
        if acc_val is None:
            continue
        if acc_val > 0:
            correct_avg_similarities.append(metrics["avg_similarity"])
        else:
            incorrect_avg_similarities.append(metrics["avg_similarity"])

    # Release the embeddings before the next file is processed.
    del all_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if len(all_avg_similarities) == 0:
        return _zero_metrics(len(all_data))

    return {
        "avg_similarity": float(np.mean(all_avg_similarities)),
        "max_similarity": float(np.mean(all_max_similarities)),
        "repetition_ratio_0.8": float(np.mean(all_repetition_ratios_0_8)),
        "repetition_ratio_0.9": float(np.mean(all_repetition_ratios_0_9)),
        "num_samples": len(all_data),
        "valid_samples": len(all_avg_similarities),
        "avg_similarity_correct": float(np.mean(correct_avg_similarities)) if correct_avg_similarities else 0.0,
        "avg_similarity_incorrect": float(np.mean(incorrect_avg_similarities)) if incorrect_avg_similarities else 0.0,
        "num_correct": len(correct_avg_similarities),
        "num_incorrect": len(incorrect_avg_similarities),
    }


def worker_process(
    gpu_id: int,
    step_files: List[Tuple[int, str]],
    model_name: str,
    chunk_size: int,
    batch_size: int,
    response_key: str,
    accuracy_key: str,
    result_queue: Queue,
    start_step: int = None,
    end_step: int = None,
):
    """
    Worker entry point: analyze the assigned step files on one device.

    The model and tokenizer are loaded once, then every assigned file is
    analyzed with ``analyze_step_file`` and one result dict per file
    (``step``, ``file`` and the metrics) is put on ``result_queue``. A file
    that fails is reported and skipped so the remaining files are still
    processed. Errors are printed, never raised.

    Args:
        gpu_id: Index of the GPU to use (``cuda:<gpu_id>``); CPU is used when
            CUDA is not available.
        step_files: ``(step_number, file_path)`` pairs assigned to this worker.
        model_name: Name or path of the embedding model.
        chunk_size: Forwarded to ``analyze_step_file``.
        batch_size: Forwarded to ``analyze_step_file``.
        response_key: Key of the response field in each record.
        accuracy_key: Key of the accuracy field in each record.
        result_queue: Queue receiving one result dict per analyzed file.
        start_step: Not used by the worker (the file list is already filtered).
        end_step: Not used by the worker (the file list is already filtered).
    """
    import traceback

    try:
        # Under the 'spawn' start method each process starts fresh, so import here.
        import torch
        from sentence_transformers import SentenceTransformer
        from transformers import AutoTokenizer

        device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"

        print(f"[GPU {gpu_id}] 开始处理 {len(step_files)} 个文件，使用设备: {device}")

        # Try flash-attention with bf16, then fp16, then fall back to defaults.
        model = None
        for dtype in (torch.bfloat16, torch.float16):
            try:
                model = SentenceTransformer(
                    model_name,
                    model_kwargs={
                        "attn_implementation": "flash_attention_2",
                        "torch_dtype": dtype,
                    },
                    tokenizer_kwargs={"padding_side": "left"},
                    device=device,
                )
                break
            except Exception:
                continue
        if model is None:
            model = SentenceTransformer(model_name, device=device)

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        for step_number, file_path in step_files:
            try:
                metrics = analyze_step_file(
                    file_path,
                    tokenizer,
                    model,
                    chunk_size=chunk_size,
                    batch_size=batch_size,
                    response_key=response_key,
                    accuracy_key=accuracy_key,
                )
            except Exception as file_err:
                # One bad file must not discard the rest of this worker's share.
                print(f"[GPU {gpu_id}] 错误 (step {step_number}, {file_path}): {file_err}")
                traceback.print_exc()
                continue

            result_queue.put({
                "step": step_number,
                "file": os.path.basename(file_path),
                **metrics
            })

        print(f"[GPU {gpu_id}] 完成处理")

    except Exception as e:
        # Setup failure (imports, model or tokenizer loading): report and exit.
        print(f"[GPU {gpu_id}] 错误: {e}")
        traceback.print_exc()


def find_step_files(data_dir: str, pattern: str = r"step_(\d+)_traindata\.jsonl") -> List[Tuple[int, str]]:
    """
    Collect the step files in a directory together with their step numbers.

    Args:
        data_dir: Directory to scan (not recursive).
        pattern: Regular expression matched against the start of each file
            name; its first capture group must hold the step number.

    Returns:
        ``[(step_number, file_path), ...]`` sorted by ``step_number``.
    """
    candidates = ((re.match(pattern, name), name) for name in os.listdir(data_dir))

    found = [
        (int(hit.group(1)), os.path.join(data_dir, name))
        for hit, name in candidates
        if hit
    ]

    # Stable sort on the step number only.
    return sorted(found, key=lambda entry: entry[0])


def analyze_training_data_by_step(
    data_dir: str,
    output_file: str,
    model_name: str = "Qwen/Qwen3-Embedding-0.6B",
    chunk_size: int = 512,
    batch_size: int = 32,
    response_key: str = "response",
    accuracy_key: str = "accuracies",
    file_pattern: str = r"step_(\d+)_traindata\.jsonl",
    max_files: int = None,
    start_step: int = None,
    end_step: int = None,
    gpu_id: int = 0,
    num_gpus: int = None,
): 
    """
    Analyze per-step training data and plot semantic-repetition metrics vs. step.

    Step files are discovered with ``find_step_files`` and filtered by
    ``start_step`` / ``end_step`` / ``max_files``. Each file is analyzed with
    ``analyze_step_file``, either in this process or in one worker process per
    GPU. The per-step results are written to ``<output>_split_by_text.json``
    and plotted to ``<output>.png`` (four panels) and ``<output>_main.png``.

    If the JSON result file already exists it is loaded and plotted as-is: no
    model is loaded and the step filters are not applied to the cached results.

    Args:
        data_dir: Directory containing the step files.
        output_file: Output path; a trailing ``.png`` is treated as the plot
            file name and replaced to derive the other output names.
        model_name: Name or path of the embedding model.
        chunk_size: Forwarded to ``analyze_step_file``.
        batch_size: Batch size used when encoding.
        response_key: Key of the response field in each record.
        accuracy_key: Key of the accuracy field in each record.
        file_pattern: Regular expression for step file names; the first capture
            group is the step number.
        max_files: Keep only the first ``max_files`` files after sorting by
            step; ``None`` or a non-positive value means no limit.
        start_step: Only analyze files with ``step >= start_step``.
        end_step: Only analyze files with ``step <= end_step``.
        gpu_id: GPU index used in single-GPU mode.
        num_gpus: Number of GPUs for parallel mode; ``None`` or a value <= 1
            selects single-GPU mode.

    Returns:
        None. Results are written to disk and summarized on stdout.
    """
    # multiprocessing queues signal a get() timeout with queue.Empty.
    from queue import Empty

    # Locate the step files and apply the optional filters.
    print(f"查找step文件: {data_dir}")
    step_files = find_step_files(data_dir, file_pattern)
    if start_step is not None or end_step is not None:
        step_files = [
            (s, p)
            for s, p in step_files
            if (start_step is None or s >= start_step)
            and (end_step is None or s <= end_step)
        ]
    if max_files is not None and max_files > 0:
        step_files = step_files[:max_files]

    if len(step_files) == 0:
        print(f"错误：在 {data_dir} 中未找到匹配的文件（模式: {file_pattern}）")
        return

    print(f"找到 {len(step_files)} 个step文件")
    print(f"编码batch_size: {batch_size}")
    print(f"Chunk大小: {chunk_size} words\n")

    # Result file path, computed once and reused for loading and saving.
    json_output_file = output_file.replace('.png', '_split_by_text.json') if output_file.endswith('.png') else output_file + '_split_by_text.json'

    if os.path.exists(json_output_file):
        # Cached results exist: load them and only redo the plots.
        print(f"Found existing JSON result, load and plot only: {json_output_file}")
        with open(json_output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
        results.sort(key=lambda x: x["step"])
    elif num_gpus is None or num_gpus <= 1:
        # Single-GPU mode (a non-positive num_gpus is treated the same way).
        print(f"使用单GPU模式 (GPU {gpu_id})")
        if torch.cuda.is_available():
            device = f"cuda:{gpu_id}"
            print(f"使用设备: {device}")
            print(f"GPU设备: {torch.cuda.get_device_name(gpu_id)}")
            print(f"GPU显存: {torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3:.2f} GB")
        else:
            device = "cpu"
            print(f"使用设备: {device}")

        # Load the model: flash-attention with bf16, then fp16, then defaults.
        print(f"加载模型: {model_name}")
        try:
            model = SentenceTransformer(
                model_name,
                model_kwargs={
                    "attn_implementation": "flash_attention_2",
                    "torch_dtype": torch.bfloat16,
                },
                tokenizer_kwargs={"padding_side": "left"},
                device=device,
            )
            print("已启用 flash_attention_2 加速 (bf16)")
        except Exception as e:
            print(f"无法启用 flash_attention_2 with bf16: {e}")
            try:
                model = SentenceTransformer(
                    model_name,
                    model_kwargs={
                        "attn_implementation": "flash_attention_2",
                        "torch_dtype": torch.float16,
                    },
                    tokenizer_kwargs={"padding_side": "left"},
                    device=device,
                )
                print("已启用 flash_attention_2 加速 (fp16)")
            except Exception as e2:
                print(f"无法启用 flash_attention_2 with fp16: {e2}")
                print("使用默认配置加载模型")
                model = SentenceTransformer(model_name, device=device)

        print("加载tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Analyze the step files one after another.
        results = []
        print("开始分析各step文件...")
        for step_number, file_path in tqdm(step_files, desc="分析step文件"):
            print(f"\n处理 step {step_number}: {os.path.basename(file_path)}")

            metrics = analyze_step_file(
                file_path,
                tokenizer,
                model,
                chunk_size=chunk_size,
                batch_size=batch_size,
                response_key=response_key,
                accuracy_key=accuracy_key,
            )

            results.append({
                "step": step_number,
                "file": os.path.basename(file_path),
                **metrics
            })

            print(f"  Step {step_number}: 平均相似度={metrics['avg_similarity']:.4f}, "
                  f"最大相似度={metrics['max_similarity']:.4f}, "
                  f"样本数={metrics['num_samples']}")
    else:
        # Multi-GPU parallel mode.
        if not torch.cuda.is_available():
            print("警告：CUDA不可用，回退到单GPU模式")
            num_gpus = 1
        else:
            available_gpus = torch.cuda.device_count()
            num_gpus = min(num_gpus, available_gpus)
            print(f"使用多GPU并行模式: {num_gpus} 个GPU")
            for i in range(num_gpus):
                print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

        if num_gpus == 1:
            # Only one usable device: rerun in single-GPU mode.
            analyze_training_data_by_step(
                data_dir, output_file, model_name, chunk_size, batch_size,
                response_key, accuracy_key, file_pattern, max_files, start_step, end_step, gpu_id, None
            )
            return

        # Give each GPU a contiguous slice of the (sorted) file list.
        chunk_size_per_gpu = math.ceil(len(step_files) / num_gpus)
        processes = []
        result_queue = Queue()

        for gpu_idx in range(num_gpus):
            start_idx = gpu_idx * chunk_size_per_gpu
            end_idx = min((gpu_idx + 1) * chunk_size_per_gpu, len(step_files))
            assigned_files = step_files[start_idx:end_idx]

            if len(assigned_files) == 0:
                continue

            worker = Process(
                target=worker_process,
                args=(
                    gpu_idx,
                    assigned_files,
                    model_name,
                    chunk_size,
                    batch_size,
                    response_key,
                    accuracy_key,
                    result_queue,
                    start_step,
                    end_step,
                )
            )
            worker.start()
            processes.append(worker)
            print(f"启动进程处理 GPU {gpu_idx}: {len(assigned_files)} 个文件")

        # Collect one result per file, or stop once every worker has exited.
        results = []
        total_files = len(step_files)

        print(f"\n等待所有GPU完成处理...")
        with tqdm(total=total_files, desc="收集结果") as pbar:
            while len(results) < total_files:
                try:
                    results.append(result_queue.get(timeout=1))
                    pbar.update(1)
                except Empty:
                    if any(proc.is_alive() for proc in processes):
                        continue
                    # All workers are gone. Pick up anything that was enqueued
                    # between the timed-out get() and the liveness check.
                    while True:
                        try:
                            results.append(result_queue.get(timeout=0.1))
                            pbar.update(1)
                        except Empty:
                            break
                    break

        for proc in processes:
            proc.join()

        if len(results) < total_files:
            print(f"Warning: only {len(results)} of {total_files} step files produced results "
                  f"(see worker errors above)")

        results.sort(key=lambda x: x["step"])

    if not results:
        # Nothing to plot; also avoids caching an empty result file that would
        # make every later run fail while plotting.
        print("Error: no results available, nothing to save or plot")
        return

    # Extract the series used for plotting and statistics.
    steps = [r["step"] for r in results]
    avg_similarities = [r["avg_similarity"] for r in results]
    max_similarities = [r["max_similarity"] for r in results]
    repetition_ratios_0_8 = [r["repetition_ratio_0.8"] for r in results]
    repetition_ratios_0_9 = [r["repetition_ratio_0.9"] for r in results]
    num_samples_list = [r["num_samples"] for r in results]
    # Missing correct/incorrect entries are treated as 0.0 / 0.
    avg_similarities_correct = [r.get("avg_similarity_correct", 0.0) for r in results]
    avg_similarities_incorrect = [r.get("avg_similarity_incorrect", 0.0) for r in results]
    num_correct_list = [r.get("num_correct", 0) for r in results]
    num_incorrect_list = [r.get("num_incorrect", 0) for r in results]

    # Save the results unless a result file is already present.
    if os.path.exists(json_output_file):
        print(f"\nFound existing JSON result, skip writing: {json_output_file}")
    else:
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"\nSaved results to: {json_output_file}")

    # Four-panel overview figure.
    plot_output_file = output_file if output_file.endswith('.png') else output_file + '.png'

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    panel_specs = [
        (axes[0, 0], avg_similarities, 'b-o', 'Avg similarity',
         'Average semantic similarity', 'Average semantic similarity vs Step'),
        (axes[0, 1], max_similarities, 'g-o', 'Max similarity',
         'Max semantic similarity', 'Max semantic similarity vs Step'),
        (axes[1, 0], repetition_ratios_0_8, 'm-o', 'High-sim ratio (>=0.8)',
         'High similarity ratio', 'High similarity ratio (>=0.8) vs Step'),
        (axes[1, 1], repetition_ratios_0_9, 'c-o', 'High-sim ratio (>=0.9)',
         'High similarity ratio', 'High similarity ratio (>=0.9) vs Step'),
    ]
    for panel, values, line_fmt, label, ylabel, title in panel_specs:
        panel.plot(steps, values, line_fmt, linewidth=2, markersize=8, label=label)
        panel.set_xlabel('Step', fontsize=12, fontweight='bold')
        panel.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        panel.set_title(title, fontsize=14, fontweight='bold')
        panel.grid(True, alpha=0.3)
        panel.legend(fontsize=10)

    # Linear trend line on the average-similarity panel (needs >= 2 points).
    if len(steps) > 1:
        ax1 = axes[0, 0]
        z = np.polyfit(steps, avg_similarities, 1)
        trend = np.poly1d(z)
        ax1.plot(steps, trend(steps), "r--", alpha=0.5, linewidth=1, label=f'Trend (slope={z[0]:.6f})')
        ax1.legend(fontsize=10)

    plt.tight_layout()
    plt.savefig(plot_output_file, dpi=300, bbox_inches='tight')
    print(f"图表已保存到: {plot_output_file}")
    plt.close()

    # Summary statistics over all steps.
    print("\n=== Stats ===")
    print(f"Total steps: {len(steps)}")
    print(f"Step range: {min(steps)} - {max(steps)}")
    print(f"Avg similarity range: {min(avg_similarities):.4f} - {max(avg_similarities):.4f}")
    print(f"Avg similarity mean: {np.mean(avg_similarities):.4f} ± {np.std(avg_similarities):.4f}")
    print(f"Max similarity mean: {np.mean(max_similarities):.4f} ± {np.std(max_similarities):.4f}")
    print(f"High-sim ratio (>=0.8) mean: {np.mean(repetition_ratios_0_8):.4f} ± {np.std(repetition_ratios_0_8):.4f}")
    print(f"High-sim ratio (>=0.9) mean: {np.mean(repetition_ratios_0_9):.4f} ± {np.std(repetition_ratios_0_9):.4f}")
    print(f"Total samples: {sum(num_samples_list)}")
    print(f"Avg similarity (correct): {np.mean(avg_similarities_correct):.4f} ± {np.std(avg_similarities_correct):.4f}")
    print(f"Avg similarity (incorrect): {np.mean(avg_similarities_incorrect):.4f} ± {np.std(avg_similarities_incorrect):.4f}")
    print(f"Total correct: {sum(num_correct_list)}, total incorrect: {sum(num_incorrect_list)}")

    # Main figure: average similarity, overall and split by correctness.
    main_plot_file = output_file.replace('.png', '_main.png') if output_file.endswith('.png') else output_file + '_main.png'
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(
        steps,
        avg_similarities,
        'b-o',
        linewidth=2.5,
        markersize=10,
        label='Average semantic similarity',
        markerfacecolor='lightblue',
        markeredgecolor='blue',
        markeredgewidth=2,
    )
    # Draw the correct/incorrect curves only when at least one step has such samples.
    has_correct = any(n > 0 for n in num_correct_list)
    has_incorrect = any(n > 0 for n in num_incorrect_list)
    if has_correct:
        ax.plot(
            steps,
            avg_similarities_correct,
            'g--s',
            linewidth=2,
            markersize=8,
            label=f'Correct avg sim (n={sum(num_correct_list)})',
        )
    if has_incorrect:
        ax.plot(
            steps,
            avg_similarities_incorrect,
            'r--^',
            linewidth=2,
            markersize=8,
            label=f'Incorrect avg sim (n={sum(num_incorrect_list)})',
        )
    ax.set_xlabel('Step', fontsize=14, fontweight='bold')
    ax.set_ylabel('Average semantic similarity', fontsize=14, fontweight='bold')
    ax.set_title('Average semantic similarity vs Step', fontsize=16, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=12)

    # Linear trend line (needs >= 2 points).
    if len(steps) > 1:
        z = np.polyfit(steps, avg_similarities, 1)
        trend = np.poly1d(z)
        ax.plot(steps, trend(steps), "r--", alpha=0.7, linewidth=2, label=f'Trend (slope={z[0]:.6f})')
        ax.legend(fontsize=12)

    # Highlight the steps with the highest and lowest average similarity.
    max_idx = np.argmax(avg_similarities)
    min_idx = np.argmin(avg_similarities)
    ax.plot(steps[max_idx], avg_similarities[max_idx], 'ro', markersize=12, label=f'Max: {avg_similarities[max_idx]:.4f}')
    ax.plot(steps[min_idx], avg_similarities[min_idx], 'go', markersize=12, label=f'Min: {avg_similarities[min_idx]:.4f}')
    ax.legend(fontsize=12)

    plt.tight_layout()
    plt.savefig(main_plot_file, dpi=300, bbox_inches='tight')
    print(f"主图表已保存到: {main_plot_file}")
    plt.close()

    print(f"\n完成！")


def main():
    """
    Command-line entry point.

    Sets the multiprocessing start method to 'spawn', parses the command-line
    arguments and forwards them to ``analyze_training_data_by_step``.
    """
    # 'spawn' is needed for CUDA in child processes and must be set before
    # any process is created.
    try:
        multiprocessing.set_start_method('spawn', force=True)
    except RuntimeError:
        # Start method was already fixed; keep going.
        pass

    # Help texts use %(default)s so the documented default always matches the
    # actual one.
    parser = argparse.ArgumentParser(description="分析训练数据，按step统计平均语义重复度")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="训练数据目录（包含step_X_traindata.jsonl文件）"
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="输出文件路径（会自动添加.png和.json后缀）"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-Embedding-0.6B",
        help="Embedding模型名称 (默认: %(default)s)"
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=512,
        help="每个chunk的token数量 (默认: %(default)s)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=100,
        help="编码时的batch size，增大可提高GPU利用率 (默认: %(default)s)"
    )
    parser.add_argument(
        "--response_key",
        type=str,
        default="responses",
        help="response字段的键名 (默认: %(default)s)"
    )
    parser.add_argument(
        "--accuracy_key",
        type=str,
        default="accuracies",
        help="accuracy字段的键名 (默认: %(default)s)"
    )
    parser.add_argument(
        "--file_pattern",
        type=str,
        default=r"step_(\d+)_traindata\.jsonl",
        help="文件名匹配模式（正则表达式，需包含一个捕获组用于提取step编号）(默认: %(default)s)"
    )
    parser.add_argument(
        "--max_files",
        type=int,
        default=None,
        help="最多分析的文件数（按step排序后截断）。None表示不限制。"
    )
    parser.add_argument(
        "--start_step",
        type=int,
        default=None,
        help="仅分析 step >= start_step 的文件"
    )
    parser.add_argument(
        "--end_step",
        type=int,
        default=None,
        help="仅分析 step <= end_step 的文件"
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="使用的GPU ID（默认: 0，仅在单GPU模式下使用）"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=None,
        help="使用的GPU数量（None表示单GPU模式，指定数字则使用多GPU并行模式）"
    )

    args = parser.parse_args()

    analyze_training_data_by_step(
        data_dir=args.data_dir,
        output_file=args.output,
        model_name=args.model,
        chunk_size=args.chunk_size,
        batch_size=args.batch_size,
        response_key=args.response_key,
        accuracy_key=args.accuracy_key,
        file_pattern=args.file_pattern,
        max_files=args.max_files,
        start_step=args.start_step,
        end_step=args.end_step,
        gpu_id=args.gpu_id,
        num_gpus=args.num_gpus,
    )


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()

