#!/usr/bin/env python3
"""
对比多个训练数据语义重复度分析结果
绘制三个图：avg similarity、正确case的similarity、错误case的similarity
每个图上显示多条线，分别对应多个文件
"""

import json
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict

# 尝试导入scipy，如果不可用则使用numpy实现
try:
    from scipy.ndimage import uniform_filter1d
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


def load_json_file(file_path: str) -> List[Dict]:
    """
    Read a JSON results file and return its parsed content.

    When the top-level JSON value is a list, its entries are ordered in place
    by their "step" field (entries without that key sort as step 0). Any other
    top-level value is returned untouched.

    Args:
        file_path: Path of the JSON file to read (decoded as UTF-8).

    Returns:
        The parsed JSON content, sorted by step when it is a list.
    """
    def step_of(entry):
        # Entries without a "step" key are treated as step 0.
        return entry.get("step", 0)

    with open(file_path, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)

    # Only lists can be ordered; hand anything else back as-is.
    if not isinstance(parsed, list):
        return parsed

    parsed.sort(key=step_of)
    return parsed


def smooth_data(data: np.ndarray, window_size: int = 3) -> np.ndarray:
    """
    Smooth a 1-D series with a centered moving average.

    The input is returned unchanged (same object) when the window is 1 or
    smaller, or when the series is shorter than the requested window.
    Otherwise an even window is widened by one so that it is centered, the
    values are converted to float, and each output sample is the mean of the
    window around it. At both ends the series is extended by repeating the
    edge value, so the SciPy path and the NumPy fallback produce the same
    result.

    Args:
        data: 1-D sequence of numeric samples.
        window_size: Requested width of the averaging window.

    Returns:
        The smoothed series as a float array of the same length as ``data``,
        or ``data`` itself when no smoothing is applied.
    """
    if len(data) < window_size or window_size <= 1:
        return data

    # Force an odd window so that it is symmetric around each sample.
    if window_size % 2 == 0:
        window_size += 1

    # Average in floating point: both uniform_filter1d and an integer output
    # buffer would otherwise keep an integer dtype and truncate the means.
    values = np.asarray(data, dtype=float)

    if HAS_SCIPY:
        # Use SciPy's filter when the module-level import succeeded.
        return uniform_filter1d(values, size=window_size, mode='nearest')

    # NumPy fallback: repeat the edge values (the equivalent of SciPy's
    # mode='nearest') and take a 'valid' convolution with a uniform kernel,
    # which yields exactly len(data) samples.
    half_window = window_size // 2
    padded = np.pad(values, half_window, mode='edge')
    kernel = np.full(window_size, 1.0 / window_size)
    return np.convolve(padded, kernel, mode='valid')


def extract_data(results: List[Dict], smooth_window: int = 0) -> Dict:
    """
    Collect the series needed for plotting from a list of result records.

    Each record contributes its "step", "avg_similarity",
    "avg_similarity_correct" and "avg_similarity_incorrect" values; a missing
    step defaults to 0 and a missing similarity to 0.0. The similarity series
    are always stored as float arrays, so integer-valued input is not
    truncated by later averaging and a JSON null becomes NaN instead of
    forcing an object-dtype array.

    Args:
        results: List of per-step result dictionaries.
        smooth_window: Moving-average window passed to ``smooth_data``;
            values of 0 or less disable smoothing.

    Returns:
        Dict with the keys "steps", "avg_similarities",
        "avg_similarities_correct" and "avg_similarities_incorrect", each
        mapping to a NumPy array with one entry per record.
    """
    def column(key: str) -> np.ndarray:
        # One float array per similarity field, defaulting absent keys to 0.0.
        return np.array([record.get(key, 0.0) for record in results], dtype=float)

    steps = np.array([record.get("step", 0) for record in results])

    series = {
        "avg_similarities": column("avg_similarity"),
        "avg_similarities_correct": column("avg_similarity_correct"),
        "avg_similarities_incorrect": column("avg_similarity_incorrect"),
    }

    # Steps are never smoothed; only the similarity curves are.
    if smooth_window > 0:
        series = {
            name: smooth_data(values, smooth_window)
            for name, values in series.items()
        }

    return {"steps": steps, **series}


def _default_label(file_path: str) -> str:
    """Return the file's base name with a trailing ".json" extension removed."""
    name = os.path.basename(file_path)
    if name.endswith('.json'):
        return name[:-len('.json')]
    return name


def plot_comparison(
    file_paths: List[str],
    output_file: str,
    labels: List[str] = None,
    smooth_window: int = 3,
):
    """
    Plot the similarity curves of several result files side by side.

    Every file is loaded with ``load_json_file`` and reduced to plottable
    series with ``extract_data``. Three subplots are drawn (overall average
    similarity, similarity of correct cases, similarity of incorrect cases)
    with one line per file, the figure is saved to ``output_file``, and
    per-file summary statistics are printed.

    Args:
        file_paths: Paths of the JSON result files to compare. Nothing is
            plotted when the list is empty.
        output_file: Path of the image file to write.
        labels: Legend labels, one per file. Missing labels (or all of them
            when ``None``) fall back to the file's base name without the
            ".json" extension. The list passed in is never modified.
        smooth_window: Moving-average window forwarded to ``extract_data``
            (0 disables smoothing).
    """
    if not file_paths:
        print("错误：至少需要提供一个文件")
        return

    # Load the raw records of every file.
    all_results = []
    for i, file_path in enumerate(file_paths):
        print(f"加载文件{i+1}: {file_path}")
        results = load_json_file(file_path)
        print(f"  找到 {len(results)} 个数据点")
        all_results.append(results)

    # Reduce each file to its (optionally smoothed) series.
    all_data = [extract_data(results, smooth_window=smooth_window) for results in all_results]

    if smooth_window > 0:
        print(f"使用窗口大小 {smooth_window} 进行平滑处理")

    # Work on a copy so the caller's list is not extended in place, then fill
    # any missing labels from the corresponding file names.
    labels = list(labels) if labels is not None else []
    labels += [_default_label(fp) for fp in file_paths[len(labels):]]

    # Per-file line styles; each list is cycled when there are more files
    # than entries.
    colors = ['b', 'r', 'g', 'm', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive']
    markers = ['o', 's', '^', 'v', 'D', 'p', '*', 'h', 'X', 'd']
    linestyles = ['-', '--', '-.', ':', '-', '--', '-.', ':', '-', '--']

    # One subplot per series: (key in the extracted data, subplot title).
    panels = [
        ("avg_similarities", 'Average semantic similarity'),
        ("avg_similarities_correct", 'Correct case similarity'),
        ("avg_similarities_incorrect", 'Incorrect case similarity'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for idx, (data, label) in enumerate(zip(all_data, labels)):
            line_style = dict(
                color=colors[idx % len(colors)],
                marker=markers[idx % len(markers)],
                linestyle=linestyles[idx % len(linestyles)],
                linewidth=2,
                markersize=6,
                label=label,
                alpha=0.8,
            )
            for ax, (key, _) in zip(axes, panels):
                ax.plot(data["steps"], data[key], **line_style)

        for ax, (_, title) in zip(axes, panels):
            ax.set_xlabel('Step', fontsize=12, fontweight='bold')
            ax.set_ylabel('Average semantic similarity', fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, linestyle='--')
            ax.legend(fontsize=10)

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"\n图表已保存到: {output_file}")

    # Summary statistics per file.
    print("\n=== 统计信息 ===")
    for idx, (data, label) in enumerate(zip(all_data, labels)):
        steps = data['steps']
        print(f"\n文件{idx+1} ({label}):")
        print(f"  Steps: {len(steps)}")
        if len(steps) == 0:
            # min()/max() raise on an empty series and mean/std would only
            # yield NaN with a warning, so skip the statistics for this file.
            print("  (no data points)")
            continue
        print(f"  Step range: {min(steps)} - {max(steps)}")
        print(f"  Avg similarity: {np.mean(data['avg_similarities']):.4f} ± {np.std(data['avg_similarities']):.4f}")
        print(f"  Correct similarity: {np.mean(data['avg_similarities_correct']):.4f} ± {np.std(data['avg_similarities_correct']):.4f}")
        print(f"  Incorrect similarity: {np.mean(data['avg_similarities_incorrect']):.4f} ± {np.std(data['avg_similarities_incorrect']):.4f}")

    print("\n完成！")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with the --files/--output/--labels/--smooth options."""
    default_files = [
        "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/baseline-8k-minibsz32/training_data_vllm.json",
        "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/stage1-additive-length-penalty/training_data_vllm.json",
        "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/stage1-additive-skip-right/training_data_vllm.json"
    ]
    default_output = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/compare_training.png"

    parser = argparse.ArgumentParser(description="对比多个训练数据语义重复度分析结果")
    parser.add_argument("--files", type=str, nargs='+', default=default_files,
                        help="JSON文件路径列表（可以指定多个文件）")
    parser.add_argument("--output", type=str, default=default_output,
                        help="输出图片路径")
    parser.add_argument("--labels", type=str, nargs='+',
                        default=["Baseline", "Additive", "Skip Right"],
                        help="文件标签列表（默认使用文件名，数量应与文件数量一致）")
    parser.add_argument("--smooth", type=int, default=3,
                        help="平滑窗口大小（0表示不平滑，默认3）")
    return parser


def main():
    """
    Command-line entry point.

    Parses the arguments, stops with a message at the first input file that
    does not exist, discards the labels (so file names are used instead) when
    their count differs from the number of files, and then calls
    ``plot_comparison``.
    """
    args = _build_arg_parser().parse_args()

    # Report the first missing input file and stop.
    missing = next((path for path in args.files if not os.path.exists(path)), None)
    if missing is not None:
        print(f"错误：文件不存在: {missing}")
        return

    # A label list of the wrong length is dropped in favour of file names.
    labels = args.labels
    if labels is not None and len(labels) != len(args.files):
        print(f"警告：标签数量({len(labels)})与文件数量({len(args.files)})不一致，将使用文件名作为标签")
        labels = None

    plot_comparison(
        file_paths=args.files,
        output_file=args.output,
        labels=labels,
        smooth_window=args.smooth,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

