#!/usr/bin/env python3
"""
语义重复分析脚本
分析JSONL文件中responses的语义重复度
按512 tokens作为一个chunk进行分析
"""

import json
import argparse
from typing import List, Dict, Tuple
import numpy as np
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from collections import defaultdict
import matplotlib.pyplot as plt


def split_text_by_tokens(text: str, tokenizer, chunk_size: int = 512) -> List[str]:
    """
    Split a text into consecutive chunks of at most ``chunk_size`` tokens.

    The text is encoded once without special tokens, the token ids are sliced
    into fixed-size windows, and each window is decoded back into a string.
    The final chunk may contain fewer than ``chunk_size`` tokens.

    Args:
        text: Input text.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``decode(token_ids, skip_special_tokens=True)``.
        chunk_size: Maximum number of tokens per chunk; must be positive.

    Returns:
        The decoded chunk strings in document order. Empty when the text
        encodes to zero tokens.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Without this check a
            negative size would silently yield no chunks at all.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    token_ids = tokenizer.encode(text, add_special_tokens=False)

    return [
        tokenizer.decode(token_ids[start:start + chunk_size], skip_special_tokens=True)
        for start in range(0, len(token_ids), chunk_size)
    ]


def batch_split_texts_by_tokens(texts: List[str], tokenizer, chunk_size: int = 512) -> Tuple[List[str], List[Tuple[int, int]]]:
    """
    Split many texts into token-based chunks using one batched tokenizer call.

    All texts are tokenized together (no special tokens, no truncation, no
    padding). Each text's token ids are then sliced into windows of at most
    ``chunk_size`` tokens and decoded back to strings with ``batch_decode``.

    Args:
        texts: Input texts.
        tokenizer: Callable tokenizer returning a mapping with an
            ``"input_ids"`` entry (one token-id list per text) and exposing
            ``batch_decode(list_of_token_id_lists, skip_special_tokens=True)``.
        chunk_size: Maximum number of tokens per chunk; must be positive.

    Returns:
        A pair ``(all_chunks, chunk_boundaries)``. ``all_chunks`` is the flat
        list of chunk strings for every text, in input order.
        ``chunk_boundaries[i]`` is the half-open ``(start, end)`` range of
        ``all_chunks`` that belongs to ``texts[i]``; a text that encodes to
        zero tokens gets ``start == end``.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Without this check a
            negative size would silently report every text as empty.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # Nothing to split: answer directly instead of handing an empty batch to
    # the tokenizer, which is not guaranteed to accept one.
    if not texts:
        return [], []

    token_ids_per_text = tokenizer(
        texts,
        add_special_tokens=False,
        truncation=False,
        padding=False,
        return_attention_mask=False,
    )["input_ids"]

    all_chunks: List[str] = []
    chunk_boundaries: List[Tuple[int, int]] = []

    for token_ids in token_ids_per_text:
        start_idx = len(all_chunks)

        if len(token_ids) > 0:
            windows = [
                token_ids[offset:offset + chunk_size]
                for offset in range(0, len(token_ids), chunk_size)
            ]
            # Decode all windows of this text in a single call.
            all_chunks.extend(tokenizer.batch_decode(windows, skip_special_tokens=True))

        # For an empty text nothing was appended, so this records (start, start).
        chunk_boundaries.append((start_idx, len(all_chunks)))

    return all_chunks, chunk_boundaries


def calculate_semantic_similarity(chunks: List[str], model: SentenceTransformer, batch_size: int = 32) -> np.ndarray:
    """
    Compute the pairwise similarity matrix between text chunks.

    Args:
        chunks: Text chunks to compare.
        model: Embedding model used for ``encode`` and ``similarity``.
        batch_size: Batch size passed to ``model.encode``.

    Returns:
        The pairwise similarity matrix as produced by ``model.similarity``
        (presumably cosine similarity, the sentence-transformers default;
        confirm for the loaded model). An empty 1-D array is returned when
        there are no chunks, and ``[[1.0]]`` for a single chunk; in both
        cases the model is not called.
    """
    chunk_count = len(chunks)

    # Trivial inputs are answered without touching the model.
    if chunk_count == 0:
        return np.array([])
    if chunk_count == 1:
        return np.array([[1.0]])

    # Keep the embeddings as tensors so the similarity computation can stay
    # on the model's device.
    chunk_embeddings = model.encode(
        chunks,
        batch_size=batch_size,
        show_progress_bar=False,
        convert_to_tensor=True,
    )

    pairwise = model.similarity(chunk_embeddings, chunk_embeddings)

    if not torch.is_tensor(pairwise):
        return pairwise

    # Cast to float32 before converting: numpy has no bfloat16 dtype.
    return pairwise.float().cpu().numpy()


def calculate_repetition_metrics(
    similarity_matrix: np.ndarray,
    thresholds: Tuple[float, ...] = (0.8, 0.9),
) -> Dict[str, float]:
    """
    Derive repetition metrics from a pairwise chunk-similarity matrix.

    Only the strict upper triangle is used, so each unordered chunk pair is
    counted once and self-similarity is excluded.

    Args:
        similarity_matrix: Square ``(n, n)`` similarity matrix (any
            array-like), or an empty array when there are no chunks.
        thresholds: Similarity cut-offs. For every value ``t`` the result
            contains ``"repetition_ratio_<t>"``: the fraction of chunk pairs
            whose similarity is ``>= t``. The default reproduces the
            ``repetition_ratio_0.8`` / ``repetition_ratio_0.9`` keys.

    Returns:
        Dict with ``avg_similarity``, ``max_similarity``, one
        ``repetition_ratio_<t>`` entry per threshold, and ``num_chunks``, in
        that order. With no chunks every value is zero. With exactly one
        chunk there are no pairs: the ratios are 0.0 while
        ``avg_similarity`` / ``max_similarity`` are reported as 1.0 (the
        chunk's self-similarity), which is the established output format.

    Raises:
        ValueError: If a non-empty ``similarity_matrix`` is not 2-D square.
    """
    similarity_matrix = np.asarray(similarity_matrix)
    ratio_keys = [f"repetition_ratio_{threshold}" for threshold in thresholds]

    def build(avg: float, peak: float, ratios: List[float], num_chunks: int) -> Dict[str, float]:
        # Assemble the result in a fixed key order so serialized output is stable.
        metrics: Dict[str, float] = {"avg_similarity": avg, "max_similarity": peak}
        metrics.update(zip(ratio_keys, ratios))
        metrics["num_chunks"] = num_chunks
        return metrics

    no_pairs = [0.0] * len(ratio_keys)

    if similarity_matrix.size == 0:
        return build(0.0, 0.0, no_pairs, 0)

    if similarity_matrix.ndim != 2 or similarity_matrix.shape[0] != similarity_matrix.shape[1]:
        raise ValueError(
            f"similarity_matrix must be a square 2-D array, got shape {similarity_matrix.shape}"
        )

    n = int(similarity_matrix.shape[0])

    if n == 1:
        return build(1.0, 1.0, no_pairs, 1)

    # Strict upper triangle (k=1): every pair once, diagonal excluded.
    pair_similarities = similarity_matrix[np.triu_indices(n, k=1)]
    pair_count = len(pair_similarities)

    ratios = [
        float(np.sum(pair_similarities >= threshold) / pair_count)
        for threshold in thresholds
    ]

    return build(
        float(np.mean(pair_similarities)),
        float(np.max(pair_similarities)),
        ratios,
        n,
    )


def analyze_response(response: str, tokenizer, model: SentenceTransformer, chunk_size: int = 512, batch_size: int = 32) -> Dict:
    """
    Measure the semantic repetition of a single response.

    Runs the three pipeline stages in order: token-based chunking, pairwise
    chunk similarity, and metric extraction.

    Args:
        response: Response text to analyse.
        tokenizer: Tokenizer used to split the response into chunks.
        model: Embedding model used to compare the chunks.
        chunk_size: Maximum number of tokens per chunk.
        batch_size: Batch size used when encoding the chunks.

    Returns:
        The metrics dict produced by ``calculate_repetition_metrics``.
    """
    return calculate_repetition_metrics(
        calculate_semantic_similarity(
            split_text_by_tokens(response, tokenizer, chunk_size),
            model,
            batch_size,
        )
    )


def _load_embedding_model(model_name: str, device: str) -> "SentenceTransformer":
    """Load the embedding model, preferring flash_attention_2 in reduced precision.

    Tries bf16 first and then fp16, both with
    ``attn_implementation="flash_attention_2"`` and ``device_map="auto"``.
    If both attempts raise, the model is loaded with the default
    configuration on ``device``.
    """
    for dtype, label in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")):
        try:
            model = SentenceTransformer(
                model_name,
                model_kwargs={
                    "attn_implementation": "flash_attention_2",
                    "device_map": "auto",
                    "torch_dtype": dtype,
                },
                tokenizer_kwargs={"padding_side": "left"},
            )
        except Exception as e:
            # Deliberate best-effort: any load failure (missing flash-attn,
            # unsupported dtype, ...) just moves on to the next configuration.
            print(f"无法启用 flash_attention_2 with {label}: {e}")
        else:
            print(f"已启用 flash_attention_2 加速 ({label})")
            return model

    print("使用默认配置加载模型")
    return SentenceTransformer(model_name, device=device)


def _derive_output_path(output_file: str, suffix: str) -> str:
    """Return a companion path built from ``output_file`` plus ``suffix``.

    A trailing ``.jsonl`` extension is stripped first. When the output file
    has a different extension the suffix is simply appended, so the derived
    path can never be identical to ``output_file`` itself.
    """
    extension = ".jsonl"
    base = output_file[:-len(extension)] if output_file.endswith(extension) else output_file
    return base + suffix


def _coerce_response_text(value) -> str:
    """Turn a raw response field into text the tokenizer can accept.

    ``None`` (a JSON ``null``) becomes an empty string; any other non-string
    value is converted with ``str``.
    """
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _plot_adjacent_similarity(
    chunk_indices: List[int],
    avg_similarities: List[float],
    std_similarities: List[float],
    sample_counts: List[int],
    plot_output_file: str,
) -> None:
    """Plot mean adjacent-chunk similarity (±1 std) and sample count per chunk position."""
    fig, ax1 = plt.subplots(figsize=(12, 6))
    try:
        means = np.array(avg_similarities)
        spreads = np.array(std_similarities)

        # Left y-axis: similarity between chunk[idx] and chunk[idx + 1].
        color1 = 'tab:blue'
        ax1.set_xlabel('Chunk Index', fontsize=12)
        ax1.set_ylabel('Average Similarity (chunk[idx] vs chunk[idx+1])', fontsize=12, color=color1)
        ax1.plot(chunk_indices, avg_similarities, 'b-', linewidth=2, label='Average Similarity')
        ax1.fill_between(
            chunk_indices,
            means - spreads,
            means + spreads,
            alpha=0.3,
            label='±1 Std Dev'
        )
        ax1.tick_params(axis='y', labelcolor=color1)
        ax1.grid(True, alpha=0.3)

        # Highlight the position with the highest mean similarity.
        max_idx = np.argmax(avg_similarities)
        max_value = avg_similarities[max_idx]
        max_chunk_idx = chunk_indices[max_idx]
        ax1.plot(max_chunk_idx, max_value, 'ro', markersize=10, label=f'Max: {max_value:.3f}')
        ax1.annotate(
            f'Max: {max_value:.3f}\nChunk: {max_chunk_idx}',
            xy=(max_chunk_idx, max_value),
            xytext=(10, 10),
            textcoords='offset points',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7),
            arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
            fontsize=10
        )

        # Right y-axis: number of responses contributing to each position.
        ax2 = ax1.twinx()
        color2 = 'tab:orange'
        ax2.set_ylabel('Average Sample Count', fontsize=12, color=color2)
        ax2.plot(chunk_indices, sample_counts, 'o-', color=color2, linewidth=2, markersize=6, label='Sample Count', alpha=0.7)
        ax2.tick_params(axis='y', labelcolor=color2)

        # One combined legend for both axes.
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left', fontsize=10)

        plt.title('Adjacent Chunks Similarity and Sample Count by Position', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.savefig(plot_output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


def process_jsonl_file(
    input_file: str,
    output_file: str,
    model_name: str = "Qwen/Qwen3-Embedding-0.6B",
    chunk_size: int = 512,
    batch_size: int = 32,
    response_key: str = "response",
):
    """
    Analyse the semantic repetition of every response in a JSONL file.

    Each non-blank input line is parsed as a JSON object. The text under
    ``response_key`` is split into token chunks, all chunks are embedded in
    one batched pass, and per-response metrics are stored under the
    ``"semantic_repetition_analysis"`` key of each record. The enriched
    records are written to ``output_file``. Per-position statistics of the
    similarity between adjacent chunks are saved next to it as
    ``<base>_adjacent_similarity_stats.json`` and
    ``<base>_adjacent_similarity_plot.png``, where ``<base>`` is
    ``output_file`` without a trailing ``.jsonl``.

    Args:
        input_file: Path of the input JSONL file.
        output_file: Path of the output JSONL file.
        model_name: Name of the embedding model; also used to load the tokenizer.
        chunk_size: Maximum number of tokens per chunk; must be positive.
        batch_size: Batch size used when encoding chunks.
        response_key: Key of the response text inside each record. A missing
            key or ``null`` value is treated as an empty response.

    Raises:
        ValueError: If ``chunk_size`` is not positive (checked before the
            model is loaded).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    print(f"加载模型: {model_name}")

    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    print(f"使用设备: {device}")

    if cuda_available:
        print(f"GPU设备: {torch.cuda.get_device_name(0)}")
        print(f"GPU显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    model = _load_embedding_model(model_name, device)

    # A separate tokenizer instance is used only for chunking the texts.
    print("加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Blank lines (e.g. a trailing newline at the end of the file) are not
    # records and would make json.loads fail, so they are skipped.
    print(f"读取输入文件: {input_file}")
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = [line for line in f if line.strip()]

    print(f"共 {len(lines)} 条数据")
    print(f"编码batch_size: {batch_size}")
    print(f"Chunk大小: {chunk_size} tokens\n")

    all_data = [json.loads(line) for line in lines]

    print("收集所有responses...")
    all_responses = [_coerce_response_text(data.get(response_key, "")) for data in all_data]

    print(f"批量分割所有responses为chunks (chunk_size={chunk_size})...")
    if all_responses:
        all_chunks, chunk_boundaries = batch_split_texts_by_tokens(all_responses, tokenizer, chunk_size)
    else:
        # No records: do not hand an empty batch to the tokenizer.
        all_chunks, chunk_boundaries = [], []

    print(f"总共 {len(all_chunks)} 个chunks")

    # Encode every chunk of every response in one pass.
    print(f"批量编码所有chunks (batch_size={batch_size})...")
    if cuda_available:
        print(f"开始编码前GPU显存使用: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    if len(all_chunks) > 0:
        all_embeddings = model.encode(
            all_chunks,
            convert_to_tensor=True,
            show_progress_bar=True,
            batch_size=batch_size
        )
        if cuda_available:
            print(f"编码完成后GPU显存使用: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    else:
        all_embeddings = None

    # Maps chunk position i -> similarities between chunk i and chunk i+1,
    # collected across all responses.
    adjacent_similarities_by_idx = defaultdict(list)
    results = []

    print("计算每个response的语义重复度...")
    for idx, data in enumerate(tqdm(all_data, desc="计算指标")):
        start_idx, end_idx = chunk_boundaries[idx]

        if start_idx == end_idx:
            # Empty response: reuse the metric function's "no chunks" result.
            metrics = calculate_repetition_metrics(np.array([]))
        else:
            response_embeddings = all_embeddings[start_idx:end_idx]
            similarity_matrix = model.similarity(response_embeddings, response_embeddings)

            # Cast to float32 before converting: numpy has no bfloat16 dtype.
            if torch.is_tensor(similarity_matrix):
                similarity_matrix = similarity_matrix.float().cpu().numpy()

            metrics = calculate_repetition_metrics(similarity_matrix)

            # The first super-diagonal holds sim(chunk[i], chunk[i+1]).
            for chunk_idx, adjacent_sim in enumerate(np.diagonal(similarity_matrix, offset=1)):
                adjacent_similarities_by_idx[chunk_idx].append(float(adjacent_sim))

        data["semantic_repetition_analysis"] = metrics
        results.append(data)

    # Free the embeddings held on the GPU.
    if cuda_available:
        del all_embeddings
        torch.cuda.empty_cache()
        print(f"清理后GPU显存使用: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    print(f"写入输出文件: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')

    print("\n=== 统计信息 ===")
    if results:
        analyses = [r["semantic_repetition_analysis"] for r in results]
        summary_rows = (
            ("平均相似度", "avg_similarity", 4),
            ("最大相似度", "max_similarity", 4),
            ("高相似度比例(>=0.8)", "repetition_ratio_0.8", 4),
            ("高相似度比例(>=0.9)", "repetition_ratio_0.9", 4),
            ("平均chunk数量", "num_chunks", 2),
        )
        for label, key, digits in summary_rows:
            values = [analysis[key] for analysis in analyses]
            print(f"{label}: {np.mean(values):.{digits}f} ± {np.std(values):.{digits}f}")
    else:
        # Averaging an empty list would only produce NaN and warnings.
        print("没有可统计的数据")

    print("\n=== 相邻Chunks相似度统计 ===")
    if adjacent_similarities_by_idx:
        chunk_indices = sorted(adjacent_similarities_by_idx)
        avg_adjacent_similarities = []
        std_adjacent_similarities = []
        sample_counts = []

        for chunk_idx in chunk_indices:
            similarities = adjacent_similarities_by_idx[chunk_idx]
            avg_adjacent_similarities.append(float(np.mean(similarities)))
            std_adjacent_similarities.append(float(np.std(similarities)))
            sample_counts.append(len(similarities))

        print(f"统计了 {len(chunk_indices)} 个chunk位置")
        print(f"每个位置的样本数范围: {min(sample_counts)} - {max(sample_counts)}")

        # Derive companion paths safely: str.replace would return the path
        # unchanged for a non-.jsonl output and overwrite the results file.
        stats_output_file = _derive_output_path(output_file, '_adjacent_similarity_stats.json')
        stats_data = {
            "chunk_indices": chunk_indices,
            "avg_similarities": avg_adjacent_similarities,
            "std_similarities": std_adjacent_similarities,
            "sample_counts": sample_counts
        }
        with open(stats_output_file, 'w', encoding='utf-8') as f:
            json.dump(stats_data, f, ensure_ascii=False, indent=2)
        print(f"统计数据已保存到: {stats_output_file}")

        plot_output_file = _derive_output_path(output_file, '_adjacent_similarity_plot.png')
        _plot_adjacent_similarity(
            chunk_indices,
            avg_adjacent_similarities,
            std_adjacent_similarities,
            sample_counts,
            plot_output_file,
        )
        print(f"可视化图表已保存到: {plot_output_file}")
    else:
        print("没有找到相邻chunks的相似度数据")

    print(f"\n完成！结果已保存到: {output_file}")


def main():
    """Parse the command-line options and run the JSONL repetition analysis."""
    cli = argparse.ArgumentParser(description="分析JSONL文件中responses的语义重复度")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--input", dict(type=str, required=True, help="输入JSONL文件路径")),
        ("--output", dict(type=str, required=True, help="输出JSONL文件路径")),
        (
            "--model",
            dict(
                type=str,
                default="Qwen/Qwen3-Embedding-0.6B",
                help="Embedding模型名称 (默认: Qwen/Qwen3-Embedding-0.6B)",
            ),
        ),
        (
            "--chunk_size",
            dict(type=int, default=512, help="每个chunk的token数量 (默认: 512)"),
        ),
        (
            "--batch_size",
            dict(type=int, default=32, help="编码时的batch size，增大可提高GPU利用率 (默认: 32)"),
        ),
        (
            "--response_key",
            dict(type=str, default="output", help="response字段的键名 (默认: output)"),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    process_jsonl_file(
        input_file=opts.input,
        output_file=opts.output,
        model_name=opts.model,
        chunk_size=opts.chunk_size,
        batch_size=opts.batch_size,
        response_key=opts.response_key,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
