import json
from collections import Counter, defaultdict
import os
import sys
import glob
import re
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import uniform_filter1d

# Configure matplotlib fonts so CJK text renders, and keep the minus sign
# displayable with those fonts.
plt.rcParams.update({
    'font.sans-serif': ['Arial Unicode MS', 'SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Smoothing helper
def smooth_data(data, window_size=10):
    """Smooth a 1-D sequence with a moving average.

    Args:
        data: Sequence of numbers (list or 1-D array).
        window_size: Width of the averaging window, in samples.

    Returns:
        ``data`` itself, unchanged, when it has fewer than ``window_size``
        elements; otherwise a float ``numpy.ndarray`` of the same length
        holding the averaged values.
    """
    if len(data) < window_size:
        return data
    # uniform_filter1d writes its result in the input's dtype, so a list of
    # integer counts would have every average truncated to a whole number.
    # Converting to float first keeps the fractional part.
    values = np.asarray(data, dtype=float)
    # mode='nearest' pads both ends by repeating the edge value.
    return uniform_filter1d(values, size=window_size, mode='nearest')


# 使用方法:
# 单个目录: python script.py <directory_path> <save_name>
# 多个目录: python script.py <save_name> <dir1:name1> <dir2:name2> <dir3:name3>
# 例如: python script.py comparison /path/dir1:exp1 /path/dir2:exp2 /path/dir3:exp3

if len(sys.argv) < 3:
    print("用法1 (单目录): python script.py <directory_path> <save_name>")
    print("用法2 (多目录): python script.py <save_name> <dir1:name1> <dir2:name2> [dir3:name3]")
    # sys.exit is always available, whereas the bare exit() helper is
    # installed by the site module and is missing when that module is skipped.
    sys.exit(1)

# Parse the command line into a list of {'path': ..., 'name': ...} entries.
directories_info = []
if ':' in sys.argv[2]:
    # Multi-directory mode: argv[1] is the save name, the remaining arguments
    # are "path:name" pairs split at the first colon.
    save_name = sys.argv[1]
    for arg in sys.argv[2:]:
        dir_path, sep, dir_name = arg.partition(':')
        # Reject arguments without a colon, and pairs with an empty path or
        # name (an empty name would later produce file names such as
        # "_accuracy_distribution.png").
        if sep and dir_path and dir_name:
            directories_info.append({'path': dir_path, 'name': dir_name})
        else:
            print(f"警告：参数格式错误，跳过: {arg}")
    if not directories_info:
        print("错误：没有有效的 <dir:name> 参数")
        sys.exit(1)
else:
    # Single-directory mode (backward compatible): the save name doubles as
    # the directory's display name.
    directory_path = sys.argv[1]
    save_name = sys.argv[2]
    directories_info = [{'path': directory_path, 'name': save_name}]

print(f"将分析 {len(directories_info)} 个目录")
for info in directories_info:
    print(f"  - {info['name']}: {info['path']}")

# Statistics for every directory, keyed by its display name:
# {dir_name: {step_num: per-step stats dict}}.
all_stats = {}

# Compiled once; the captured group is the step number in the file name.
step_file_pattern = re.compile(r'step_(\d+)_traindata\.jsonl')

for dir_info in directories_info:
    directory_path = dir_info['path']
    dir_name = dir_info['name']

    print(f"\n处理目录: {dir_name} ({directory_path})")

    # Skip directories that do not exist.
    if not os.path.exists(directory_path):
        print(f"错误：目录不存在: {directory_path}")
        continue

    # Collect every per-step dump file.
    step_files = glob.glob(os.path.join(directory_path, "step_*_traindata.jsonl"))
    if not step_files:
        print(f"错误：在目录 {directory_path} 中没有找到step_*_traindata.jsonl文件")
        continue

    print(f"找到 {len(step_files)} 个step文件")

    # glob returns matches in arbitrary order. Sorting by step number makes
    # the progress log (and which file wins when two names map to the same
    # step) deterministic. Names whose step part is not numeric are dropped.
    numbered_files = []
    for file_path in step_files:
        match = step_file_pattern.search(os.path.basename(file_path))
        if match:
            numbered_files.append((int(match.group(1)), file_path))
    numbered_files.sort()

    # Per-step statistics for this directory.
    step_stats = {}

    for step_num, file_path in numbered_files:
        accuracy_counts = Counter()  # (correct, total) -> number of items
        total_items = 0
        anomaly_cases = []  # responses cut off by length yet marked correct
        skipped_lines = 0   # non-blank lines that could not be processed

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        # Blank lines carry no record; ignore them quietly.
                        continue
                    try:
                        item = json.loads(line)
                        accuracies = item['accuracies']
                        finish_reasons = item.get('finish_reasons', [])

                        # Anomaly: finish_reason is "length" but the response
                        # is marked correct. `== True` is kept on purpose so
                        # numeric 1 / 1.0 flags compare equal as well.
                        line_anomalies = [
                            {
                                'step': step_num,
                                'line': line_num,
                                'index': i,
                                'item': item["responses"][i],
                                'ground_truth': item["ground_truth"]
                            }
                            for i, (acc, reason) in enumerate(zip(accuracies, finish_reasons))
                            if reason == "length" and acc == True
                        ]

                        correct_count = sum(1 for acc in accuracies if acc)
                        total_count = len(accuracies)
                    except Exception:
                        # Best effort: a malformed record must not abort the
                        # file, but it is counted so the loss is visible.
                        skipped_lines += 1
                        continue

                    # Commit only after the whole line was processed, so a
                    # record that fails midway leaves no partial anomalies.
                    anomaly_cases.extend(line_anomalies)
                    accuracy_counts[(correct_count, total_count)] += 1
                    total_items += 1

            # Store the statistics of this step.
            step_stats[step_num] = {
                'total': total_items,
                'all_correct': accuracy_counts.get((8, 8), 0),  # all correct, 8/8
                'all_wrong': accuracy_counts.get((0, 8), 0),     # all wrong, 0/8
                'accuracy_counts': accuracy_counts,
                'anomaly_cases': anomaly_cases
            }

            anomaly_msg = f", 异常={len(anomaly_cases)}" if anomaly_cases else ""
            skipped_msg = f", 跳过={skipped_lines}" if skipped_lines else ""
            print(f"Step {step_num}: 总数={total_items}, 全对={step_stats[step_num]['all_correct']}, 全错={step_stats[step_num]['all_wrong']}{anomaly_msg}{skipped_msg}")

        except Exception as e:
            # The file could not be opened or read; report it and move on.
            print(f"警告：处理文件 {file_path} 失败: {e}")

    # Store the statistics of this directory.
    all_stats[dir_name] = step_stats

# Union of the step numbers seen in any directory.
all_steps_set = set().union(*(per_dir.keys() for per_dir in all_stats.values()))
sorted_steps = sorted(all_steps_set)

# Turn each directory's per-step dict into parallel, step-ordered series.
all_dir_data = {}
for dir_name, step_stats in all_stats.items():
    ordered_steps = sorted(step_stats)
    entries = [step_stats[step] for step in ordered_steps]

    correct_series = [entry['all_correct'] for entry in entries]
    wrong_series = [entry['all_wrong'] for entry in entries]
    # Number of items with exactly 1, 2 and 3 correct responses out of 8.
    one_series = [entry['accuracy_counts'].get((1, 8), 0) for entry in entries]
    two_series = [entry['accuracy_counts'].get((2, 8), 0) for entry in entries]
    three_series = [entry['accuracy_counts'].get((3, 8), 0) for entry in entries]

    all_dir_data[dir_name] = {
        'steps': ordered_steps,
        'all_correct_counts': correct_series,
        'all_wrong_counts': wrong_series,
        # all-correct plus all-wrong items per step
        'all_counts': [c + w for c, w in zip(correct_series, wrong_series)],
        'total_counts': [entry['total'] for entry in entries],
        'anomaly_counts': [len(entry['anomaly_cases']) for entry in entries],
        'one_correct_counts': one_series,
        'two_correct_counts': two_series,
        'three_correct_counts': three_series,
        # 1/8 + 2/8 + 3/8 per step
        'low_accuracy_sum_counts': [
            a + b + c for a, b, c in zip(one_series, two_series, three_series)
        ],
        'step_stats': step_stats
    }

# Create the output directory for the figures.
os.makedirs('train_data_analysis', exist_ok=True)

# One all-correct/all-wrong figure and one anomaly figure per directory.
for dir_name, data in all_dir_data.items():
    steps = data['steps']
    all_correct_counts = data['all_correct_counts']
    all_wrong_counts = data['all_wrong_counts']
    all_counts = data['all_counts']
    total_counts = data['total_counts']
    anomaly_counts = data['anomaly_counts']

    if not steps:
        # No step was recorded for this directory: the mean below would
        # divide by zero and there is nothing to draw.
        print(f"警告：目录 {dir_name} 没有可用的step数据，跳过绘图")
        continue

    # All-correct / all-wrong figure.
    plt.figure(figsize=(12, 6))
    smoothed_correct = smooth_data(all_correct_counts, window_size=3)
    smoothed_wrong = smooth_data(all_wrong_counts, window_size=3)
    smoothed_all = smooth_data(all_counts, window_size=3)
    plt.plot(steps, smoothed_correct, marker='o', linewidth=3, markersize=8, label='All correct (8/8)', color='#2ECC71', alpha=0.9)
    plt.plot(steps, smoothed_wrong, marker='s', linewidth=3, markersize=8, label='All wrong (0/8)', color='#E74C3C', alpha=0.9)
    plt.plot(steps, smoothed_all, marker='^', linewidth=3, markersize=8, label='Total', color='#3498DB', alpha=0.9)
    # Dashed reference line: mean number of items per step. It is labelled so
    # the legend explains it.
    plt.axhline(y=sum(total_counts)/len(total_counts), color='#34495E', linestyle='--', linewidth=2, alpha=0.6,
                label='Mean samples per step')
    plt.xlabel('Training Step', fontsize=12, fontweight='bold')
    plt.ylabel('Sample Count', fontsize=12, fontweight='bold')
    plt.title(f'All wrong and all correct samples - {dir_name} (smoothed)', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11, framealpha=0.9)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    current_output = f'train_data_analysis/{dir_name}_accuracy_distribution.png'
    plt.savefig(current_output, dpi=300, bbox_inches='tight')
    print(f"图片已保存: {current_output}")
    plt.close()

    # Anomaly-count figure.
    plt.figure(figsize=(12, 6))
    smoothed_anomaly = smooth_data(anomaly_counts, window_size=3)
    plt.plot(steps, smoothed_anomaly, marker='D', linewidth=3, markersize=8,
             label='Anomaly cases (finish_reason=length but acc=True)', color='#F39C12', alpha=0.9)
    plt.xlabel('Training Step', fontsize=12, fontweight='bold')
    plt.ylabel('Anomaly Count', fontsize=12, fontweight='bold')
    plt.title(f'Anomaly cases - {dir_name} (smoothed)', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11, framealpha=0.9)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    anomaly_output = f'train_data_analysis/{dir_name}_anomaly_distribution.png'
    plt.savefig(anomaly_output, dpi=300, bbox_inches='tight')
    print(f"异常分布图已保存: {anomaly_output}")
    plt.close()

# Comparison figures for the 1/8, 2/8 and 3/8 counts, with all directories
# drawn on the same axes. High-contrast palette so the curves are easy to
# tell apart; colours and markers cycle when there are more directories.
colors = [
    '#E74C3C',  # bright red
    '#3498DB',  # bright blue
    '#2ECC71',  # bright green
    '#F39C12',  # orange
    '#9B59B6',  # purple
    '#1ABC9C',  # teal
    '#E67E22',  # dark orange
    '#34495E',  # dark blue-grey
]
markers = ['o', 's', '^', 'D', 'v', '<', 'p', 'h']


def _plot_count_comparison(series_key, title, output_path, *, linewidth=3,
                           markersize=8, label_fontsize=12, title_fontsize=14,
                           legend_fontsize=11, legend_loc=None):
    """Draw one smoothed curve per directory and save the figure.

    Args:
        series_key: Key of the per-step series in each ``all_dir_data`` entry.
        title: Figure title.
        output_path: Path of the PNG file to write.
        linewidth: Line width of the curves.
        markersize: Marker size of the curves.
        label_fontsize: Font size of the axis labels.
        title_fontsize: Font size of the title.
        legend_fontsize: Font size of the legend.
        legend_loc: Legend location; ``None`` leaves matplotlib's default.
    """
    plt.figure(figsize=(14, 6))
    for idx, (dir_name, data) in enumerate(all_dir_data.items()):
        smoothed = smooth_data(data[series_key], window_size=3)
        plt.plot(data['steps'], smoothed,
                 marker=markers[idx % len(markers)], linewidth=linewidth, markersize=markersize,
                 label=f'{dir_name}', color=colors[idx % len(colors)], alpha=0.9)
    plt.xlabel('Training Step', fontsize=label_fontsize, fontweight='bold')
    plt.ylabel('Sample Count', fontsize=label_fontsize, fontweight='bold')
    plt.title(title, fontsize=title_fontsize, fontweight='bold')
    legend_kwargs = {'fontsize': legend_fontsize, 'framealpha': 0.9}
    if legend_loc is not None:
        legend_kwargs['loc'] = legend_loc
    plt.legend(**legend_kwargs)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    # Close every figure after saving so open figures do not accumulate.
    plt.close()


# 1/8 comparison
one_output = f'train_data_analysis/{save_name}_1of8_comparison.png'
_plot_count_comparison('one_correct_counts',
                       '1/8 correct samples comparison (smoothed)', one_output)
print(f"\n1/8对比图已保存: {one_output}")

# 2/8 comparison
two_output = f'train_data_analysis/{save_name}_2of8_comparison.png'
_plot_count_comparison('two_correct_counts',
                       '2/8 correct samples comparison (smoothed)', two_output)
print(f"2/8对比图已保存: {two_output}")

# 3/8 comparison
three_output = f'train_data_analysis/{save_name}_3of8_comparison.png'
_plot_count_comparison('three_correct_counts',
                       '3/8 correct samples comparison (smoothed)', three_output)
print(f"3/8对比图已保存: {three_output}")

# Sum of 1/8 + 2/8 + 3/8 (the main figure), drawn slightly bolder.
sum_output = f'train_data_analysis/{save_name}_low_accuracy_sum.png'
_plot_count_comparison('low_accuracy_sum_counts',
                       'Low accuracy samples (1/8 + 2/8 + 3/8) comparison (smoothed)',
                       sum_output, linewidth=3.5, markersize=9, label_fontsize=13,
                       title_fontsize=15, legend_fontsize=12, legend_loc='best')
print(f"\n低准确率样本总和对比图已保存: {sum_output}")

# Combined figure: 1/8, 2/8 and 3/8 on one set of axes. Each directory keeps
# one colour and marker; the three series differ by line style and opacity.
plt.figure(figsize=(16, 7))
for idx, (dir_name, data) in enumerate(all_dir_data.items()):
    color = colors[idx % len(colors)]
    marker = markers[idx % len(markers)]
    smoothed_one = smooth_data(data['one_correct_counts'], window_size=3)
    smoothed_two = smooth_data(data['two_correct_counts'], window_size=3)
    smoothed_three = smooth_data(data['three_correct_counts'], window_size=3)
    plt.plot(data['steps'], smoothed_one,
             marker=marker, linewidth=2.5, markersize=6, linestyle='-',
             label=f'{dir_name} (1/8)', color=color, alpha=0.9)
    plt.plot(data['steps'], smoothed_two,
             marker=marker, linewidth=2.5, markersize=6, linestyle='--',
             label=f'{dir_name} (2/8)', color=color, alpha=0.75)
    plt.plot(data['steps'], smoothed_three,
             marker=marker, linewidth=2.5, markersize=6, linestyle=':',
             label=f'{dir_name} (3/8)', color=color, alpha=0.6)
plt.xlabel('Training Step', fontsize=12, fontweight='bold')
plt.ylabel('Sample Count', fontsize=12, fontweight='bold')
plt.title('Low accuracy samples (1/8, 2/8, 3/8) comparison (smoothed)', fontsize=14, fontweight='bold')
# The legend needs at least one column; len() is 0 when no directory loaded.
plt.legend(fontsize=10, ncol=max(1, len(all_dir_data)), framealpha=0.9)
plt.grid(True, alpha=0.3)
plt.tight_layout()
combined_output = f'train_data_analysis/{save_name}_low_accuracy_combined.png'
plt.savefig(combined_output, dpi=300, bbox_inches='tight')
print(f"综合对比图已保存: {combined_output}")
# The original left this last figure open; close it like the others.
plt.close()

# Detailed report per directory: one table row per step, followed by the
# anomaly summary and a JSON dump of the anomalies of each affected step.
rule = '=' * 120
for dir_name, data in all_dir_data.items():
    per_step = data['step_stats']
    ordered_steps = sorted(per_step)

    print(f"\n{rule}")
    print(f"=== 详细统计 - {dir_name} ===")
    print(rule)
    print(f"{'Step':<8} {'总数':<8} {'全对(8/8)':<10} {'全错(0/8)':<10} {'1/8':<8} {'2/8':<8} {'3/8':<8} {'1-3总和':<10} {'异常数':<8} {'全对率':<10} {'全错率':<10}")
    print("-" * 120)
    for step in ordered_steps:
        row = per_step[step]
        bucket = row['accuracy_counts']
        total = row['total']
        n_right = row['all_correct']
        n_wrong = row['all_wrong']
        n_one, n_two, n_three = (bucket.get((k, 8), 0) for k in (1, 2, 3))
        low_sum = n_one + n_two + n_three
        n_anomaly = len(row['anomaly_cases'])
        # Percentages fall back to 0 for a step without any item.
        if total > 0:
            right_pct = n_right / total * 100
            wrong_pct = n_wrong / total * 100
        else:
            right_pct = wrong_pct = 0
        print(f"{step:<8} {total:<8} {n_right:<10} {n_wrong:<10} {n_one:<8} {n_two:<8} {n_three:<8} {low_sum:<10} {n_anomaly:<8} {right_pct:>7.2f}% {wrong_pct:>9.2f}%")

    # Anomaly report: finish_reason is 'length' yet the response is correct.
    print(f"\n=== 异常情况检查 - {dir_name} (finish_reason='length' 但 accuracy=True) ===")
    flagged = [
        (step, per_step[step]['anomaly_cases'])
        for step in ordered_steps
        if per_step[step]['anomaly_cases']
    ]
    total_anomalies = sum(len(cases) for _, cases in flagged)
    if total_anomalies == 0:
        print("未发现异常情况 ✓")
        continue

    print(f"发现 {total_anomalies} 个异常case!\n")
    for step, cases in flagged:
        print(f"  Step {step}: 发现 {len(cases)} 个异常")
        anomaly_file = f'train_data_analysis/{dir_name}_step{step}_anomaly_cases.json'
        with open(anomaly_file, 'w', encoding='utf-8') as f:
            json.dump(cases, f, ensure_ascii=False, indent=4)

# Closing summary listing the comparison figures written earlier.
print(f"\n{rule}")
print("所有分析完成！")
print(rule)
print("\n生成的主要对比图：")
print(f"  1. 低准确率样本总和(1/8+2/8+3/8): {sum_output}")
print(f"  2. 1/8样本对比: {one_output}")
print(f"  3. 2/8样本对比: {two_output}")
print(f"  4. 3/8样本对比: {three_output}")
print(f"  5. 综合详细对比: {combined_output}")
print("\n每个目录单独的图表也已生成在 train_data_analysis/ 目录中")