#!/usr/bin/env python3
"""
训练数据语义重复度分析脚本（使用 vLLM OpenAI-compatible server）
分析训练过程中收集的生成数据，按step统计平均语义重复度并画图
"""

import json
import argparse
import os
import re
import sys
from typing import List, Dict, Tuple
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Make the project-local `semantic_repetition` module importable by putting its
# directory at the front of sys.path. The default is an absolute path on the
# shared cluster storage; set SEMANTIC_REPETITION_DIR to point at another
# checkout without editing this script.
_SEMANTIC_REPETITION_DIR = os.environ.get("SEMANTIC_REPETITION_DIR") or (
    '/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/verl/recipe/reward_ours'
)
sys.path.insert(0, _SEMANTIC_REPETITION_DIR)
from semantic_repetition import VLLMEmbeddingModel, calculate_semantic_repetition


def _read_jsonl_records(file_path: str) -> List[dict]:
    """Parse a JSONL file into a list of records, skipping blank/whitespace-only lines."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _empty_step_metrics(num_samples: int) -> Dict[str, float]:
    """Metrics dict for a step with no usable responses: every statistic is zero."""
    return {
        "avg_similarity": 0.0,
        "std_similarity": 0.0,
        "correct_repetition": 0.0,
        "incorrect_repetition": 0.0,
        "num_samples": num_samples,
        "num_valid": 0,
        "avg_similarity_correct": 0.0,
        "avg_similarity_incorrect": 0.0,
        "num_correct": 0,
        "num_incorrect": 0,
    }


def analyze_step_file_api(
    file_path: str,
    model: VLLMEmbeddingModel,
    chunk_size: int = 512,
    response_key: str = "responses",
    accuracy_key: str = "accuracies",
) -> Dict[str, float]:
    """
    Compute semantic-repetition metrics for one step file using the vLLM API.

    Each JSONL record is expected to hold a list of responses under
    ``response_key`` and a parallel list of accuracies under ``accuracy_key``.
    The two lists are paired positionally; pairing stops at the shorter list.
    Non-string or whitespace-only responses are dropped. The remaining responses
    are scored in one batch by ``calculate_semantic_repetition``.

    A scored response counts as *valid* only when its average similarity is not
    None and is strictly positive. All averages below, and all correct/incorrect
    buckets, use this same valid set. A response is *correct* when its accuracy
    is not None and is > 0.

    Args:
        file_path: Path of the JSONL file.
        model: VLLMEmbeddingModel instance passed through to the scorer.
        chunk_size: Chunk size (in words) passed through to the scorer.
        response_key: Record key holding the list of responses.
        accuracy_key: Record key holding the list of accuracies.

    Returns:
        A metrics dict with the keys:
        ``avg_similarity``, ``std_similarity``, ``correct_repetition``,
        ``incorrect_repetition``, ``num_samples``, ``num_valid``,
        ``avg_similarity_correct``, ``avg_similarity_incorrect``,
        ``num_correct`` and ``num_incorrect``.
        When nothing is usable, every statistic is 0.
    """
    records = _read_jsonl_records(file_path)
    if not records:
        return _empty_step_metrics(0)

    # Flatten responses and accuracies across all records.
    valid_responses: List[str] = []
    valid_accuracies: List[float] = []
    for data in records:
        for resp, acc in zip(data.get(response_key, []), data.get(accuracy_key, [])):
            if isinstance(resp, str) and resp.strip():
                valid_responses.append(resp)
                valid_accuracies.append(acc)

    print(f"Total responses: {len(valid_responses)}")

    if not valid_responses:
        # Nothing to score; report how many records were read.
        return _empty_step_metrics(len(records))

    # One batched call for the whole step.
    avg_scores, std_scores, _max_similarities, high_similarity_flags = calculate_semantic_repetition(
        valid_responses,
        chunk_size=chunk_size,
        model=model,
    )

    kept_avg: List[float] = []
    kept_std: List[float] = []
    correct_avg_similarities: List[float] = []
    incorrect_avg_similarities: List[float] = []
    correct_repetition: list = []
    incorrect_repetition: list = []

    for avg_sim, std_sim, flag, acc in zip(avg_scores, std_scores, high_similarity_flags, valid_accuracies):
        # Filter once, so the sums and the counts below describe the same responses.
        if avg_sim is None or avg_sim <= 0:
            continue
        kept_avg.append(avg_sim)
        # A missing std contributes 0 rather than dropping the response.
        kept_std.append(std_sim if std_sim is not None else 0.0)
        if acc is not None and acc > 0:
            correct_avg_similarities.append(avg_sim)
            correct_repetition.append(flag)
        else:
            incorrect_avg_similarities.append(avg_sim)
            incorrect_repetition.append(flag)

    num_valid = len(kept_avg)
    if num_valid == 0:
        return _empty_step_metrics(len(valid_responses))

    print("Total samples: ", len(valid_responses))
    print("Total valid: ", num_valid)
    return {
        "avg_similarity": float(np.mean(kept_avg)),
        "std_similarity": float(np.mean(kept_std)),
        "correct_repetition": float(np.mean(correct_repetition)) if correct_repetition else 0.0,
        "incorrect_repetition": float(np.mean(incorrect_repetition)) if incorrect_repetition else 0.0,
        "num_samples": len(valid_responses),
        "num_valid": num_valid,
        "avg_similarity_correct": float(np.mean(correct_avg_similarities)) if correct_avg_similarities else 0.0,
        "avg_similarity_incorrect": float(np.mean(incorrect_avg_similarities)) if incorrect_avg_similarities else 0.0,
        "num_correct": len(correct_avg_similarities),
        "num_incorrect": len(incorrect_avg_similarities),
    }


def find_step_files(data_dir: str, pattern: str = r"step_(\d+)_traindata\.jsonl") -> List[Tuple[int, str]]:
    """查找所有step文件并提取step编号"""
    step_files = []
    
    for filename in os.listdir(data_dir):
        match = re.match(pattern, filename)
        if match:
            step_number = int(match.group(1))
            file_path = os.path.join(data_dir, filename)
            step_files.append((step_number, file_path))
    
    step_files.sort(key=lambda x: x[0])
    return step_files


def _output_paths(output_file: str, chunk_size: int) -> Dict[str, str]:
    """Derive all artifact paths from one stem: ``<base>_vllm_<chunk_size>``.

    A trailing ``.png`` or ``.json`` extension on *output_file* is stripped, so
    every artifact ends up with exactly one extension. The chunk size is part of
    the stem so that results for different chunk sizes never share a cache file.
    """
    root, ext = os.path.splitext(output_file)
    base = root if ext.lower() in ('.png', '.json') else output_file
    stem = f"{base}_vllm_{chunk_size}"
    return {
        "json": stem + ".json",
        "plot": stem + ".png",
        "repetition": stem + "_repetition_rate.png",
        "detailed": stem + "_detailed.png",
    }


def _load_or_compute_results(
    json_output_file: str,
    step_files: List[Tuple[int, str]],
    api_base_url: str,
    chunk_size: int,
    response_key: str,
    accuracy_key: str,
) -> List[dict]:
    """Load per-step metrics from the JSON cache, or compute them via vLLM and write the cache."""
    if os.path.exists(json_output_file):
        print(f"加载已有结果: {json_output_file}")
        with open(json_output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
        results.sort(key=lambda x: x["step"])
        # The cache is keyed only by its file name, so it may cover a different
        # step selection than this call asked for. Report that instead of hiding it.
        requested = {step for step, _ in step_files}
        cached = {r["step"] for r in results}
        if requested != cached:
            print(f"警告：缓存中的step与本次选择的step不一致 (缓存 {len(cached)} 个, 选择 {len(requested)} 个)")
        return results

    print("使用 vLLM OpenAI-compatible server")
    model = VLLMEmbeddingModel(
        base_url=api_base_url,  # expected to already include /v1
        max_retries=3,
        timeout=120
    )

    results = []
    print("开始分析各step文件...")
    for step_number, file_path in tqdm(step_files, desc="分析step文件"):
        print(f"\n处理 step {step_number}: {os.path.basename(file_path)}")
        metrics = analyze_step_file_api(
            file_path,
            model,
            chunk_size=chunk_size,
            response_key=response_key,
            accuracy_key=accuracy_key,
        )
        results.append({
            "step": step_number,
            "file": os.path.basename(file_path),
            **metrics
        })
        print(f"  Step {step_number}: 平均相似度={metrics['avg_similarity']:.4f}, "
              f"标准差={metrics['std_similarity']:.4f}, "
              f"有效样本={metrics['num_valid']}/{metrics['num_samples']}")

    with open(json_output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\n结果已保存: {json_output_file}")
    return results


def _extract_series(results: List[dict]) -> Dict:
    """Turn the per-step result dicts into parallel per-metric lists for plotting."""
    num_correct = [r.get("num_correct", 0) for r in results]
    num_incorrect = [r.get("num_incorrect", 0) for r in results]
    rep_correct = [r.get("correct_repetition", 0.0) for r in results]
    rep_incorrect = [r.get("incorrect_repetition", 0.0) for r in results]
    # Overall high-similarity rate per step: per-bucket rates weighted by bucket size.
    rep_total = [
        (cr * nc + ir * ni) / max(r.get("num_valid", 1), 1)
        for r, cr, nc, ir, ni in zip(results, rep_correct, num_correct, rep_incorrect, num_incorrect)
    ]
    return {
        "steps": [r["step"] for r in results],
        "avg": [r["avg_similarity"] for r in results],
        "std": [r["std_similarity"] for r in results],
        "num_samples": [r["num_samples"] for r in results],
        "num_valid": [r["num_valid"] for r in results],
        "avg_correct": [r.get("avg_similarity_correct", 0.0) for r in results],
        "avg_incorrect": [r.get("avg_similarity_incorrect", 0.0) for r in results],
        "num_correct": num_correct,
        "num_incorrect": num_incorrect,
        "rep_correct": rep_correct,
        "rep_incorrect": rep_incorrect,
        "rep_total": rep_total,
        "has_correct": any(n > 0 for n in num_correct),
        "has_incorrect": any(n > 0 for n in num_incorrect),
    }


def _plot_similarity(path: str, s: Dict) -> None:
    """Main chart: average similarity per step, with trend line and max/min markers."""
    steps, avg = s["steps"], s["avg"]
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(
            steps, avg, 'b-o', linewidth=2.5, markersize=10,
            label='Average semantic similarity',
            markerfacecolor='lightblue', markeredgecolor='blue', markeredgewidth=2,
        )
        if s["has_correct"]:
            ax.plot(steps, s["avg_correct"], 'g--s', linewidth=2, markersize=8,
                    label=f'Correct avg sim (n={sum(s["num_correct"])})')
        if s["has_incorrect"]:
            ax.plot(steps, s["avg_incorrect"], 'r--^', linewidth=2, markersize=8,
                    label=f'Incorrect avg sim (n={sum(s["num_incorrect"])})')

        ax.set_xlabel('Step', fontsize=14, fontweight='bold')
        ax.set_ylabel('Average semantic similarity', fontsize=14, fontweight='bold')
        ax.set_title('Average semantic similarity vs Step (vLLM)', fontsize=16, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')

        # Linear trend line; it needs at least two points.
        if len(steps) > 1:
            z = np.polyfit(steps, avg, 1)
            ax.plot(steps, np.poly1d(z)(steps), "r--", alpha=0.7, linewidth=2,
                    label=f'Trend (slope={z[0]:.6f})')

        max_idx = int(np.argmax(avg))
        min_idx = int(np.argmin(avg))
        ax.plot(steps[max_idx], avg[max_idx], 'ro', markersize=12,
                label=f'Max: {avg[max_idx]:.4f}')
        ax.plot(steps[min_idx], avg[min_idx], 'go', markersize=12,
                label=f'Min: {avg[min_idx]:.4f}')
        ax.legend(fontsize=12)

        fig.tight_layout()
        fig.savefig(path, dpi=300, bbox_inches='tight')
        print(f"图表已保存: {path}")
    finally:
        plt.close(fig)


def _plot_repetition(path: str, s: Dict) -> None:
    """Chart of the high-similarity (repetition) rate per step: total, correct and incorrect."""
    steps = s["steps"]
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(steps, s["rep_total"], 'b-o', linewidth=2.5, markersize=10,
                label='Total Repetition Rate', markerfacecolor='lightblue')
        if s["has_correct"]:
            ax.plot(steps, s["rep_correct"], 'g--s', linewidth=2, markersize=8,
                    label='Correct Repetition Rate')
        if s["has_incorrect"]:
            ax.plot(steps, s["rep_incorrect"], 'r--^', linewidth=2, markersize=8,
                    label='Incorrect Repetition Rate')

        ax.set_xlabel('Step', fontsize=14, fontweight='bold')
        ax.set_ylabel('Repetition Rate (High Similarity %)', fontsize=14, fontweight='bold')
        ax.set_title('Repetition Rate vs Step', fontsize=16, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.legend(fontsize=12)

        fig.tight_layout()
        fig.savefig(path, dpi=300, bbox_inches='tight')
        print(f"重复率图表已保存: {path}")
    finally:
        plt.close(fig)


def _plot_detailed(path: str, s: Dict) -> None:
    """2x3 detail grid showing six per-step panels.

    The panels are: average similarity, std similarity, valid-sample count,
    correct vs incorrect similarity, correct vs incorrect repetition rate, and
    correct vs incorrect sample counts.
    """
    steps = s["steps"]
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    try:
        single_series = [
            (axes[0, 0], s["avg"], 'b-o', 'Average similarity', 'Average similarity vs Step'),
            (axes[0, 1], s["std"], 'g-o', 'Std similarity', 'Std similarity vs Step'),
            (axes[0, 2], s["num_valid"], 'm-o', 'Valid samples', 'Valid samples vs Step'),
        ]
        for ax, values, fmt, ylabel, title in single_series:
            ax.plot(steps, values, fmt, linewidth=2, markersize=8)
            ax.set_xlabel('Step')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        comparisons = [
            (axes[1, 0], s["avg_correct"], s["avg_incorrect"], 'Avg similarity',
             'Similarity: Correct vs Incorrect'),
            (axes[1, 1], s["rep_correct"], s["rep_incorrect"], 'Repetition Rate',
             'Repetition: Correct vs Incorrect'),
        ]
        for ax, correct, incorrect, ylabel, title in comparisons:
            if s["has_correct"]:
                ax.plot(steps, correct, 'g-s', linewidth=2, markersize=8, label='Correct')
            if s["has_incorrect"]:
                ax.plot(steps, incorrect, 'r-^', linewidth=2, markersize=8, label='Incorrect')
            ax.set_xlabel('Step')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            # Avoid matplotlib's "no artists with labels" warning on empty panels.
            if s["has_correct"] or s["has_incorrect"]:
                ax.legend()

        # Side-by-side bars. Width and offset scale with the smallest gap between
        # steps (a gap of 10 gives width 4 and offset 2), so bars do not overlap
        # for other step strides.
        ax6 = axes[1, 2]
        unique_steps = sorted(set(steps))
        spacing = float(np.min(np.diff(unique_steps))) if len(unique_steps) > 1 else 10.0
        x = np.asarray(steps, dtype=float)
        ax6.bar(x - 0.2 * spacing, s["num_correct"], width=0.4 * spacing,
                color='g', alpha=0.6, label='Correct')
        ax6.bar(x + 0.2 * spacing, s["num_incorrect"], width=0.4 * spacing,
                color='r', alpha=0.6, label='Incorrect')
        ax6.set_xlabel('Step')
        ax6.set_ylabel('Count')
        ax6.set_title('Sample Counts: Correct vs Incorrect')
        ax6.grid(True, alpha=0.3)
        ax6.legend()

        fig.tight_layout()
        fig.savefig(path, dpi=300, bbox_inches='tight')
        print(f"详细图表已保存: {path}")
    finally:
        plt.close(fig)


def _print_summary(s: Dict) -> None:
    """Print aggregate statistics over all steps; *s* must contain at least one step."""
    steps, avg, std = s["steps"], s["avg"], s["std"]
    print("\n=== 统计信息 ===")
    print(f"总step数: {len(steps)}")
    print(f"Step范围: {min(steps)} - {max(steps)}")
    print(f"平均相似度范围: {min(avg):.4f} - {max(avg):.4f}")
    print(f"平均相似度均值: {np.mean(avg):.4f} ± {np.std(avg):.4f}")
    print(f"标准差均值: {np.mean(std):.4f} ± {np.std(std):.4f}")
    print(f"总样本数: {sum(s['num_samples'])}")
    print(f"有效样本数: {sum(s['num_valid'])}")
    if s["has_correct"]:
        print(f"正确样本平均相似度: {np.mean(s['avg_correct']):.4f} ± {np.std(s['avg_correct']):.4f}")
    if s["has_incorrect"]:
        print(f"错误样本平均相似度: {np.mean(s['avg_incorrect']):.4f} ± {np.std(s['avg_incorrect']):.4f}")

    print(f"平均重复率: {np.mean(s['rep_total']):.4f} ± {np.std(s['rep_total']):.4f}")
    if s["has_correct"]:
        print(f"正确样本平均重复率: {np.mean(s['rep_correct']):.4f} ± {np.std(s['rep_correct']):.4f}")
    if s["has_incorrect"]:
        print(f"错误样本平均重复率: {np.mean(s['rep_incorrect']):.4f} ± {np.std(s['rep_incorrect']):.4f}")

    print(f"总正确样本: {sum(s['num_correct'])}, 总错误样本: {sum(s['num_incorrect'])}")
    print("\n完成！")


def analyze_training_data_by_step_api(
    data_dir: str,
    output_file: str,
    api_base_url: str = "http://10.102.212.39:8000/v1",
    chunk_size: int = 512,
    response_key: str = "responses",
    accuracy_key: str = "accuracies",
    file_pattern: str = r"step_(\d+)_traindata\.jsonl",
    max_files: int = None,
    start_step: int = None,
    end_step: int = None,
):
    """
    Analyze training data with the vLLM API and chart per-step semantic repetition.

    Per-step metrics come from ``analyze_step_file_api``. They are cached as JSON;
    if the cache file already exists it is loaded and the server is not queried.

    All artifacts share the stem ``<output_file minus .png/.json>_vllm_<chunk_size>``:

    * ``.json`` — the results cache
    * ``.png`` — the main chart
    * ``_repetition_rate.png`` — the repetition-rate chart
    * ``_detailed.png`` — the detail grid

    Args:
        data_dir: Directory containing the step files.
        output_file: Base output path; a trailing .png/.json is stripped.
        api_base_url: vLLM OpenAI-compatible server URL (including /v1).
        chunk_size: Chunk size in words; also part of the output file names.
        response_key: Record key holding the responses.
        accuracy_key: Record key holding the accuracies.
        file_pattern: Regex for step file names; group 1 is the step number.
        max_files: If > 0, analyze at most this many files, taken after step filtering.
        start_step: Inclusive lower bound on the step (None = unbounded).
        end_step: Inclusive upper bound on the step (None = unbounded).

    Returns:
        None. Returns early without writing anything when no files match or no
        results are available.
    """
    print(f"查找step文件: {data_dir}")
    step_files = find_step_files(data_dir, file_pattern)

    # Keep only steps inside [start_step, end_step]; either bound may be open.
    step_files = [
        (step, path) for step, path in step_files
        if (start_step is None or step >= start_step) and (end_step is None or step <= end_step)
    ]
    if max_files is not None and max_files > 0:
        step_files = step_files[:max_files]

    if not step_files:
        print(f"错误：在 {data_dir} 中未找到匹配的文件")
        return

    print(f"找到 {len(step_files)} 个step文件")
    print(f"API服务器: {api_base_url}")
    print(f"Chunk大小: {chunk_size} words\n")

    paths = _output_paths(output_file, chunk_size)
    results = _load_or_compute_results(
        paths["json"], step_files, api_base_url, chunk_size, response_key, accuracy_key,
    )
    if not results:
        # e.g. an empty cached list; the summary would crash on min()/max().
        print(f"错误：没有可用的分析结果: {paths['json']}")
        return

    series = _extract_series(results)
    _plot_similarity(paths["plot"], series)
    _plot_repetition(paths["repetition"], series)
    _plot_detailed(paths["detailed"], series)
    _print_summary(series)


def main():
    """Parse command-line arguments and run the per-step semantic-repetition analysis."""
    parser = argparse.ArgumentParser(description="使用 vLLM API 分析训练数据的语义重复度")
    parser.add_argument("--data_dir", type=str, required=True, help="训练数据目录")
    parser.add_argument("--output", type=str, required=True, help="输出文件路径")
    parser.add_argument("--api_url", type=str, default="http://10.102.212.39:8000/v1",
                        help="vLLM OpenAI-compatible server 地址 (默认包含 /v1)")
    # %(default)s keeps the help text in sync with the real default (previously it said 512).
    parser.add_argument("--chunk_size", type=int, default=50, help="每个chunk的词数 (默认: %(default)s)")
    parser.add_argument("--response_key", type=str, default="responses", help="response字段的键名 (默认: responses)")
    parser.add_argument("--accuracy_key", type=str, default="accuracies", help="accuracy字段的键名 (默认: accuracies)")
    parser.add_argument("--file_pattern", type=str, default=r"step_(\d+)_traindata\.jsonl",
                        help="文件名匹配模式")
    parser.add_argument("--max_files", type=int, default=None, help="最多分析的文件数")
    parser.add_argument("--start_step", type=int, default=None, help="起始step")
    parser.add_argument("--end_step", type=int, default=None, help="结束step")

    args = parser.parse_args()

    analyze_training_data_by_step_api(
        data_dir=args.data_dir,
        output_file=args.output,
        api_base_url=args.api_url,
        chunk_size=args.chunk_size,
        response_key=args.response_key,
        accuracy_key=args.accuracy_key,
        file_pattern=args.file_pattern,
        max_files=args.max_files,
        start_step=args.start_step,
        end_step=args.end_step,
    )


if __name__ == "__main__":
    main()

