import json
from collections import Counter, defaultdict
import os
import sys
import glob
import re
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer
import torch
from tqdm import tqdm

# Validate the command line: a save name followed by at least one directory spec.
if len(sys.argv) < 3:
    print("用法: python analysis_length.py <save_name> <directory_path1:label1> [<directory_path2:label2> ...]")
    print("示例: python analysis_length.py output dir1:baseline dir2:experiment")
    # sys.exit is used rather than the bare `exit` helper: the latter is
    # injected by the `site` module and is absent under `python -S`.
    sys.exit(1)

save_name = sys.argv[1]
dir_configs = sys.argv[2:]

# Parse every "<directory>[:<label>]" spec into two parallel lists.
directories = []
labels = []
for config in dir_configs:
    # Split on the first ':' only, so a label may itself contain colons.
    dir_path, _, label = config.partition(':')
    if not label:
        # No (or empty) label given: fall back to the directory name.
        # normpath strips a trailing separator first; otherwise "runs/exp/"
        # would yield an empty basename and therefore an empty label.
        label = os.path.basename(os.path.normpath(dir_path))

    # isdir (not exists) so that a regular file is rejected here instead of
    # being passed on as a "directory".
    if not os.path.isdir(dir_path):
        print(f"错误：目录不存在: {dir_path}")
        sys.exit(1)

    directories.append(dir_path)
    labels.append(label)

print(f"将处理 {len(directories)} 个目录")

# Load the tokenizer that is used to measure response lengths.
print("正在加载tokenizer...")
# The checkpoint defaults to the original shared-storage location, but can be
# overridden with the ANALYSIS_TOKENIZER_PATH environment variable so the
# script also runs on machines where that mount is not available.
tokenizer = AutoTokenizer.from_pretrained(
    os.environ.get(
        "ANALYSIS_TOKENIZER_PATH",
        "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base",
    )
)
# NOTE(review): `device` is only reported below; none of the code visible in
# this file moves data to it or passes it to the tokenizer.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# Batched tokenization helper.
def batch_tokenize(texts, batch_size=128):
    """Return the token count of every entry in ``texts``.

    The texts are handed to the module-level ``tokenizer`` in slices of
    ``batch_size`` with padding, truncation and special tokens all disabled,
    so each count reflects the text content only. A transient tqdm progress
    bar tracks the batches.

    Args:
        texts: Sliceable sequence of texts accepted by the tokenizer.
        batch_size: Number of texts passed to the tokenizer per call.

    Returns:
        A list with one token count per input text, in input order.
    """
    token_counts = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="Tokenizing", leave=False):
        chunk = texts[start:start + batch_size]
        encoded = tokenizer(chunk, padding=False, truncation=False, add_special_tokens=False)
        # One id list per text; its length is the token count.
        token_counts += map(len, encoded['input_ids'])
    return token_counts

# Extracts the numeric step from a "step_<N>_traindata.jsonl" file name.
_STEP_FILE_PATTERN = re.compile(r'step_(\d+)_traindata\.jsonl')


def _read_step_responses(file_path):
    """Read one step file and split its responses by correctness.

    Each non-blank line is parsed as a JSON object with an ``accuracies`` and
    a ``responses`` entry, which are paired element-wise; a truthy accuracy
    marks the paired response as correct. Lines that cannot be parsed or lack
    those keys are reported and skipped.

    Returns:
        A tuple ``(correct_texts, incorrect_texts, total_items)`` where
        ``total_items`` is the number of successfully processed lines.
    """
    correct_texts = []
    incorrect_texts = []
    total_items = 0

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records and
                # should not be reported as parse failures.
                continue
            try:
                item = json.loads(line)
                accuracies = item['accuracies']
                responses = item['responses']
                if len(accuracies) != len(responses):
                    # zip() below stops at the shorter list; make that visible.
                    print(f"    警告：行 {line_num} 的 accuracies 与 responses 数量不一致 "
                          f"({len(accuracies)} vs {len(responses)})，多余部分将被忽略")
            except (ValueError, KeyError, TypeError) as e:
                # ValueError covers json.JSONDecodeError; KeyError/TypeError
                # cover records that do not have the expected structure.
                print(f"    警告：处理行 {line_num} 失败: {e}")
                continue

            for acc, response in zip(accuracies, responses):
                if acc:
                    correct_texts.append(response)
                else:
                    incorrect_texts.append(response)
            total_items += 1

    return correct_texts, incorrect_texts, total_items


def _length_summary(lengths, prefix):
    """Return mean/median/std of ``lengths`` keyed as ``<prefix>_<stat>`` (0 when empty)."""
    return {
        f'{prefix}_mean': np.mean(lengths) if lengths else 0,
        f'{prefix}_median': np.median(lengths) if lengths else 0,
        f'{prefix}_std': np.std(lengths) if lengths else 0,
    }


def process_directory(directory_path, label):
    """Collect per-step response-length statistics for one directory.

    Looks for ``step_<N>_traindata.jsonl`` files directly inside
    ``directory_path``, tokenizes the correct and incorrect responses of each
    file with ``batch_tokenize`` and summarises their token lengths.

    Args:
        directory_path: Directory that contains the step files.
        label: Display name of the run; only used in progress output.

    Returns:
        ``None`` when the directory contains no step files. Otherwise a dict
        mapping the integer step number to a dict with the keys
        ``total_items``, ``correct_count``, ``incorrect_count``,
        ``correct_lengths``, ``incorrect_lengths`` and
        ``{correct,incorrect}_{mean,median,std}``. Steps are inserted in
        ascending numeric order. A file that fails to process is reported and
        left out, so the dict may be empty.
    """
    print(f"\n{'='*60}")
    print(f"处理目录: {directory_path} (标签: {label})")
    print(f"{'='*60}")

    # Escape the directory part so characters such as '[' in the path are not
    # interpreted as glob wildcards.
    step_files = glob.glob(os.path.join(glob.escape(directory_path), "step_*_traindata.jsonl"))
    if not step_files:
        print(f"警告：在目录 {directory_path} 中没有找到step_*_traindata.jsonl文件")
        return None

    print(f"找到 {len(step_files)} 个step文件")

    # Pair each file with its numeric step and sort on that number, so files
    # are visited in training order (a plain string sort would put step_10
    # before step_2).
    numbered_files = []
    for file_path in step_files:
        match = _STEP_FILE_PATTERN.search(os.path.basename(file_path))
        if match:
            numbered_files.append((int(match.group(1)), file_path))
    numbered_files.sort()

    step_stats = {}
    for step_num, file_path in numbered_files:
        print(f"  处理 Step {step_num}...")

        # Best effort per file: any failure is reported and the step skipped,
        # so one bad file does not abort the whole directory.
        try:
            correct_texts, incorrect_texts, total_items = _read_step_responses(file_path)
            print(f"    收集到 {len(correct_texts)} 个正确响应, {len(incorrect_texts)} 个错误响应")

            # Skip the tokenizer entirely for an empty group.
            correct_lengths = batch_tokenize(correct_texts) if correct_texts else []
            incorrect_lengths = batch_tokenize(incorrect_texts) if incorrect_texts else []

            entry = {
                'total_items': total_items,
                'correct_count': len(correct_texts),
                'incorrect_count': len(incorrect_texts),
                'correct_lengths': correct_lengths,
                'incorrect_lengths': incorrect_lengths,
            }
            entry.update(_length_summary(correct_lengths, 'correct'))
            entry.update(_length_summary(incorrect_lengths, 'incorrect'))
            step_stats[step_num] = entry

            print(f"    正确响应平均长度: {entry['correct_mean']:.1f} tokens")
            print(f"    错误响应平均长度: {entry['incorrect_mean']:.1f} tokens")

        except Exception as e:
            print(f"    警告：处理文件 {file_path} 失败: {e}")

    return step_stats

# Process every directory and store its per-step statistics under its label.
all_dir_stats = {}
for directory, label in zip(directories, labels):
    stats = process_directory(directory, label)
    if not stats:
        # None (no step files) or {} (every file failed): nothing to keep.
        continue

    # Two directories can end up with the same label (e.g. identical
    # basenames). Assigning blindly would silently discard the earlier run,
    # so later duplicates get a numeric suffix instead.
    unique_label = label
    duplicate_index = 2
    while unique_label in all_dir_stats:
        unique_label = f"{label}_{duplicate_index}"
        duplicate_index += 1
    if unique_label != label:
        print(f"警告：标签 '{label}' 重复，已重命名为 '{unique_label}'")

    all_dir_stats[unique_label] = stats

# Abort when no directory produced usable data.
if not all_dir_stats:
    print("错误：没有找到有效的数据")
    # sys.exit rather than the site-provided `exit` helper.
    sys.exit(1)

# Make sure the output directory exists.
os.makedirs('train_data_analysis', exist_ok=True)

# Colors, markers and line styles; each run picks one of each by its index
# (wrapping around when there are more runs than entries).
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p']
linestyles = ['-', '--', '-.', ':']


def _plot_stat_over_steps(stat_key, curve_suffix, ylabel, title, file_suffix,
                          saved_message, count_key=None):
    """Draw one statistic against the training step for every run and save it.

    Args:
        stat_key: Key of the per-step statistic to plot (e.g. 'correct_mean').
        curve_suffix: Text appended to the run label in the legend.
        ylabel: Y-axis label.
        title: Figure title.
        file_suffix: File name part that follows ``<save_name>_``.
        saved_message: Text printed in front of ": <path>" after saving.
        count_key: Optional key of the sample count behind ``stat_key``. When
            given, steps whose count is 0 are drawn as gaps (NaN) instead of
            the placeholder value 0 stored for empty groups, which would
            otherwise look like a real drop to zero length.

    Returns:
        The path of the written PNG file.
    """
    plt.figure(figsize=(14, 7))

    for idx, (label, step_stats) in enumerate(all_dir_stats.items()):
        steps = sorted(step_stats)
        values = []
        for step in steps:
            entry = step_stats[step]
            if count_key is not None and not entry[count_key]:
                values.append(np.nan)
            else:
                values.append(entry[stat_key])

        plt.plot(steps, values,
                 marker=markers[idx % len(markers)], linewidth=2, markersize=6,
                 label=f'{label} - {curve_suffix}',
                 color=colors[idx % len(colors)],
                 linestyle=linestyles[idx % len(linestyles)], alpha=0.8)

    plt.xlabel('Training Step', fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    output_path = f'train_data_analysis/{save_name}_{file_suffix}'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"{saved_message}: {output_path}")
    plt.close()
    return output_path


# One entry per figure:
# (stat key, count key or None, legend suffix, y label, title, file suffix, message)
_FIGURE_SPECS = [
    # Figure 1: mean length of correct responses
    ('correct_mean', 'correct_count', 'Correct', 'Average Token Length',
     'Average correct response length during training',
     'length_mean_correct.png', '\n正确响应平均长度图已保存到'),
    # Figure 2: mean length of incorrect responses
    ('incorrect_mean', 'incorrect_count', 'Incorrect', 'Average Token Length',
     'Average incorrect response length during training',
     'length_mean_incorrect.png', '错误响应平均长度图已保存到'),
    # Figure 3: median length of correct responses
    ('correct_median', 'correct_count', 'Correct', 'Median Token Length',
     'Median correct response length during training',
     'length_median_correct.png', '正确响应中位数长度图已保存到'),
    # Figure 4: median length of incorrect responses
    ('incorrect_median', 'incorrect_count', 'Incorrect', 'Median Token Length',
     'Median incorrect response length during training',
     'length_median_incorrect.png', '错误响应中位数长度图已保存到'),
    # Figure 5: number of correct responses (a count of 0 is a real value)
    ('correct_count', None, 'Correct', 'Response Count',
     'Number of correct responses during training',
     'response_counts_correct.png', '正确响应数量图已保存到'),
    # Figure 6: number of incorrect responses
    ('incorrect_count', None, 'Incorrect', 'Response Count',
     'Number of incorrect responses during training',
     'response_counts_incorrect.png', '错误响应数量图已保存到'),
]

for (_stat_key, _count_key, _curve_suffix, _ylabel,
     _title, _file_suffix, _saved_message) in _FIGURE_SPECS:
    _plot_stat_over_steps(_stat_key, _curve_suffix, _ylabel, _title,
                          _file_suffix, _saved_message, count_key=_count_key)

# Print a detailed per-step table for every run.
banner = '=' * 80
print(f"\n{banner}")
print("详细统计")
print(f"{banner}")

for run_label, per_step in all_dir_stats.items():
    print(f"\n【{run_label}】")
    print(f"{'Step':<8} {'总样本':<10} {'正确数':<10} {'错误数':<10} {'正确均长':<12} {'错误均长':<12} {'正确中位':<12} {'错误中位':<12}")
    print("-" * 110)

    # Rows are listed in ascending step order.
    for step_id in sorted(per_step):
        row = per_step[step_id]
        count_part = (f"{step_id:<8} {row['total_items']:<10} "
                      f"{row['correct_count']:<10} {row['incorrect_count']:<10} ")
        length_part = (f"{row['correct_mean']:>10.1f} {row['incorrect_mean']:>12.1f} "
                       f"{row['correct_median']:>12.1f} {row['incorrect_median']:>12.1f}")
        print(count_part + length_part)

# Save the summary statistics as JSON.
stats_output = f'train_data_analysis/{save_name}_length_stats.json'


def _json_ready_entry(entry):
    """Return the summary fields of one step as plain Python values.

    The raw length lists are left out; the numeric summaries are passed
    through ``float`` so numpy scalars become JSON-serialisable.
    """
    plain = {key: entry[key] for key in ('total_items', 'correct_count', 'incorrect_count')}
    for key in ('correct_mean', 'correct_median', 'correct_std',
                'incorrect_mean', 'incorrect_median', 'incorrect_std'):
        plain[key] = float(entry[key])
    return plain


save_data = {}
for run_label, per_step in all_dir_stats.items():
    # JSON object keys must be strings, so step numbers are stringified.
    save_data[run_label] = {
        str(step_id): _json_ready_entry(entry) for step_id, entry in per_step.items()
    }

with open(stats_output, 'w', encoding='utf-8') as out_file:
    json.dump(save_data, out_file, ensure_ascii=False, indent=4)
print(f"\n统计数据已保存到: {stats_output}")

closing_banner = '=' * 80
print(f"\n{closing_banner}")
print("分析完成！")
print(f"{closing_banner}")
