import json
from collections import Counter, defaultdict
import sys
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import torch
from tqdm import tqdm
import os

# Configure matplotlib so CJK labels render (fonts are tried in the listed
# order) and the minus sign is drawn with a plain ASCII hyphen instead of a
# Unicode minus glyph that these fonts may lack.
plt.rcParams.update({
    'font.sans-serif': ['Arial Unicode MS', 'SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Parse command-line arguments: a required JSONL path and an optional
# positive integer batch size (default 128).
if len(sys.argv) < 2:
    print("用法: python analysis_models.py <jsonl_file_path> [batch_size]")
    print("示例: python analysis_models.py /path/to/file.jsonl 128")
    # sys.exit is always available; the bare `exit` builtin is injected by the
    # `site` module and can be missing (e.g. `python -S`, frozen executables).
    sys.exit(1)

file_path = sys.argv[1]

try:
    batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 128
except ValueError:
    print(f"错误：batch_size 必须是整数: {sys.argv[2]}")
    sys.exit(1)

# A zero batch size would make the tokenization loop's range() step invalid,
# and a negative one would yield no batches at all (silently measuring nothing).
if batch_size < 1:
    print(f"错误：batch_size 必须是正整数: {batch_size}")
    sys.exit(1)

# isfile (not exists) so that a directory path is rejected here rather than
# failing later inside open().
if not os.path.isfile(file_path):
    print(f"错误：文件不存在: {file_path}")
    sys.exit(1)

print(f"正在分析文件: {file_path}")
print(f"批处理大小: {batch_size}")

# Load the tokenizer. The default path is a shared Qwen3-4B-Base checkpoint
# directory; set the TOKENIZER_PATH environment variable to use another one.
print("正在加载tokenizer...")
tokenizer_path = os.environ.get(
    "TOKENIZER_PATH",
    "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base",
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# NOTE: `device` is only reported; it is never passed to the tokenizer, so
# tokenization does not run on the GPU even when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# Batch tokenization helper: returns the token count of each text.
def batch_tokenize(texts, batch_size=128, tok=None):
    """Return the number of tokens in each string of ``texts``.

    Texts are encoded in chunks of ``batch_size`` with no padding, no
    truncation and no special tokens, so each count reflects only the text's
    own content. The module-level ``device`` is not involved; encoding does
    not use the GPU.

    Args:
        texts: Sliceable sequence (e.g. a list) of strings to measure.
        batch_size: Number of texts handed to the tokenizer per call; must be
            a positive integer.
        tok: Optional tokenizer to use instead of the module-level
            ``tokenizer``. It is called like a Hugging Face tokenizer and must
            return a mapping with an ``'input_ids'`` list of id sequences.

    Returns:
        A list of ints, one per input text, in input order. Empty input
        returns ``[]``.

    Raises:
        ValueError: If ``texts`` is non-empty and ``batch_size`` < 1.
    """
    if not texts:
        return []
    if batch_size < 1:
        # range() with step 0 raises an obscure error and a negative step
        # yields no batches, which would silently return an empty list.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if tok is None:
        tok = tokenizer

    lengths = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Tokenizing", leave=False):
        batch = texts[start:start + batch_size]
        encodings = tok(batch, padding=False, truncation=False, add_special_tokens=False)
        lengths.extend(len(ids) for ids in encodings['input_ids'])
    return lengths

print("\n" + "=" * 80)
print("第一步：读取数据并分组...")
print("=" * 80)

# Parse the JSONL file line by line: blank lines are ignored, and lines that
# are not valid JSON are reported and skipped.
# NOTE(review): a skipped line shifts every later record by one, so the
# subsequent groups of 32 no longer line up with the file's original layout
# -- confirm this is acceptable for the input files in use.
all_records = []
with open(file_path, 'r', encoding='utf-8') as fh:
    for lineno, raw in enumerate(tqdm(fh, desc="读取文件"), start=1):
        if not raw.strip():
            continue
        try:
            all_records.append(json.loads(raw))
        except json.JSONDecodeError as err:
            print(f"警告：第 {lineno} 行JSON解析错误: {err}")

print(f"总共读取了 {len(all_records)} 条记录")

# Only complete groups of 32 consecutive records are analysed; a trailing
# partial group is dropped with a warning.
full_len = len(all_records) - len(all_records) % 32
if full_len != len(all_records):
    print(f"警告：记录总数 {len(all_records)} 不是32的倍数")
    print(f"将处理前 {full_len} 条记录")
    all_records = all_records[:full_len]

num_groups = len(all_records) // 32
print(f"将数据分为 {num_groups} 组，每组32条记录")

print("\n" + "=" * 80)
print("第二步：统计n/32分布...")
print("=" * 80)

# distribution_32[k] counts the groups of 32 consecutive records in which
# exactly k records have score == 1. Each record's 'output' (when present) is
# also collected into the correct or incorrect list for length measurement.
distribution_32 = Counter()
correct_texts = []
incorrect_texts = []

for g in tqdm(range(num_groups), desc="处理分组"):
    n_correct = 0
    for rec in all_records[g * 32:(g + 1) * 32]:
        is_correct = 'score' in rec and rec['score'] == 1
        if is_correct:
            n_correct += 1
        if 'output' in rec:
            target = correct_texts if is_correct else incorrect_texts
            target.append(rec['output'])
    distribution_32[n_correct] += 1

print(f"\n收集到 {len(correct_texts)} 个正确输出")
print(f"收集到 {len(incorrect_texts)} 个错误输出")

print("\n" + "=" * 80)
print("第三步：计算token长度...")
print("=" * 80)

# Measure token lengths for each group; an empty group simply yields an empty
# list without calling the tokenizer.
print("正在计算正确输出的token长度...")
correct_lengths = []
if correct_texts:
    correct_lengths = batch_tokenize(correct_texts, batch_size)

print("正在计算错误输出的token长度...")
incorrect_lengths = []
if incorrect_texts:
    incorrect_lengths = batch_tokenize(incorrect_texts, batch_size)

def _print_length_stats(title, lengths):
    """Print summary statistics (in tokens) for one list of output lengths.

    ``title`` labels the section; an empty ``lengths`` prints a "no data"
    line instead of statistics.
    """
    if not lengths:
        print(f"\n{title}: 无数据")
        return
    print(f"\n{title} (总数: {len(lengths)}):")
    print(f"  平均长度: {np.mean(lengths):.2f} tokens")
    print(f"  中位数长度: {np.median(lengths):.2f} tokens")
    print(f"  标准差: {np.std(lengths):.2f} tokens")
    print(f"  最小长度: {np.min(lengths)} tokens")
    print(f"  最大长度: {np.max(lengths)} tokens")
    print(f"  25分位数: {np.percentile(lengths, 25):.2f} tokens")
    print(f"  75分位数: {np.percentile(lengths, 75):.2f} tokens")


print("\n" + "="*80)
print("分析结果")
print("="*80)

# n/32 table: one row per correct-count value that occurred at least once.
print("\n【n/32 分布统计】")
print(f"{'n/32':<10} {'数量':<10} {'百分比':<10}")
print("-" * 40)
for i in range(33):
    count = distribution_32.get(i, 0)
    if count == 0:  # only show rows that have data
        continue
    percentage = (count / num_groups * 100) if num_groups > 0 else 0
    # Pad the whole "i/32" label so rows stay aligned with the header for
    # both one- and two-digit values of i.
    label = f"{i}/32"
    print(f"{label:<10} {count:<10} {percentage:>6.2f}%")

# Length statistics for correct and incorrect outputs.
print("\n【长度统计】")
_print_length_stats("正确输出", correct_lengths)
_print_length_stats("错误输出", incorrect_lengths)

print("\n" + "="*80)
print("第四步：生成可视化...")
print("="*80)

# All figures and the statistics file are written next to the input file
# (no directory is created), named after the input's basename with a
# trailing ".jsonl" removed.
output_dir = os.path.dirname(file_path)
file_basename = os.path.basename(file_path)
# Strip only the suffix; str.replace would also remove ".jsonl" appearing
# elsewhere in the name (e.g. "run.jsonl.bak.jsonl" -> "run.bak").
if file_basename.endswith('.jsonl'):
    file_basename = file_basename[:-len('.jsonl')]

# 1. Bar chart of the n/32 distribution (only values that occurred).
plt.figure(figsize=(14, 6))
x_values = sorted(distribution_32)
y_values = [distribution_32[k] for k in x_values]

plt.bar(x_values, y_values, color='steelblue', alpha=0.7, edgecolor='black')
plt.xlabel('n/32 (正确答案数/总数)', fontsize=12)
plt.ylabel('数量', fontsize=12)
plt.title(f'n/32 分布 - {file_basename}', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')
plt.xticks(range(0, 33, 2))

# Annotate each bar with its count.
for x, y in zip(x_values, y_values):
    if y > 0:
        plt.text(x, y, str(y), ha='center', va='bottom', fontsize=9)

plt.tight_layout()
dist_output = os.path.join(output_dir, f'{file_basename}_n32_distribution.png')
plt.savefig(dist_output, dpi=300, bbox_inches='tight')
print(f"n/32分布图已保存到: {dist_output}")
plt.close()

def _plot_length_hist(ax, lengths, color, mean_color, title, empty_msg):
    """Draw a 50-bin histogram of ``lengths`` on ``ax`` with mean/median lines.

    When ``lengths`` is empty, ``empty_msg`` is shown centred in the axes and
    only the plain ``title`` is set; otherwise the title also shows the
    sample count.
    """
    if not lengths:
        # Axes coordinates keep the message centred regardless of data limits.
        ax.text(0.5, 0.5, empty_msg, ha='center', va='center', fontsize=14,
                transform=ax.transAxes)
        ax.set_title(title, fontsize=13, fontweight='bold')
        return
    mean_len = np.mean(lengths)
    median_len = np.median(lengths)
    ax.hist(lengths, bins=50, color=color, alpha=0.6, edgecolor='black')
    ax.axvline(mean_len, color=mean_color, linestyle='--', linewidth=2, label=f'平均值: {mean_len:.1f}')
    ax.axvline(median_len, color='blue', linestyle='--', linewidth=2, label=f'中位数: {median_len:.1f}')
    ax.set_xlabel('Token长度', fontsize=12)
    ax.set_ylabel('数量', fontsize=12)
    ax.set_title(f'{title} (样本数={len(lengths)})', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')


# 2. Side-by-side histograms of token lengths: correct (left) vs incorrect (right).
if correct_lengths or incorrect_lengths:
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    _plot_length_hist(axes[0], correct_lengths, 'green', 'red',
                      '正确输出长度分布', '无正确输出数据')
    _plot_length_hist(axes[1], incorrect_lengths, 'red', 'darkred',
                      '错误输出长度分布', '无错误输出数据')

    plt.tight_layout()
    length_output = os.path.join(output_dir, f'{file_basename}_length_distribution.png')
    plt.savefig(length_output, dpi=300, bbox_inches='tight')
    print(f"长度分布图已保存到: {length_output}")
    plt.close(fig)

# 3. Box plot comparing token lengths of correct vs incorrect outputs
# (notched boxes; showmeans/meanline draw each mean as a line in its box).
if correct_lengths or incorrect_lengths:
    plt.figure(figsize=(10, 6))
    data_to_plot = []
    labels = []
    # Keep each dataset's colour paired with it, so a lone "incorrect" box is
    # still drawn light-coral rather than taking the first palette colour.
    colors = []

    if correct_lengths:
        data_to_plot.append(correct_lengths)
        labels.append(f'正确输出\n(样本数={len(correct_lengths)})')
        colors.append('lightgreen')

    if incorrect_lengths:
        data_to_plot.append(incorrect_lengths)
        labels.append(f'错误输出\n(样本数={len(incorrect_lengths)})')
        colors.append('lightcoral')

    # Boxes are placed at x = 1..n by default. Tick labels are set afterwards
    # instead of via boxplot's `labels=` keyword, which Matplotlib 3.9
    # renamed to `tick_labels`; plt.xticks works on old and new versions.
    bp = plt.boxplot(data_to_plot, patch_artist=True,
                     notch=True, showmeans=True, meanline=True)
    plt.xticks(range(1, len(labels) + 1), labels)

    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    plt.ylabel('Token长度', fontsize=12)
    plt.title(f'正确输出 vs 错误输出 长度对比 - {file_basename}', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()

    boxplot_output = os.path.join(output_dir, f'{file_basename}_length_boxplot.png')
    plt.savefig(boxplot_output, dpi=300, bbox_inches='tight')
    print(f"箱线图已保存到: {boxplot_output}")
    plt.close()

def _summarize_lengths(lengths):
    """Return JSON-serialisable summary statistics for a list of token lengths.

    Every statistic is 0 when ``lengths`` is empty, so the output has the same
    keys whether or not there was data. Standard deviation is the population
    value (numpy's default ddof=0).
    """
    if not lengths:
        return {
            'count': 0,
            'mean': 0,
            'median': 0,
            'std': 0,
            'min': 0,
            'max': 0,
            'percentile_25': 0,
            'percentile_75': 0,
        }
    arr = np.asarray(lengths)
    # Cast numpy scalars to builtin types so json.dump can serialise them.
    return {
        'count': len(lengths),
        'mean': float(arr.mean()),
        'median': float(np.median(arr)),
        'std': float(arr.std()),
        'min': int(arr.min()),
        'max': int(arr.max()),
        'percentile_25': float(np.percentile(arr, 25)),
        'percentile_75': float(np.percentile(arr, 75)),
    }


print("\n" + "="*80)
print("第五步：保存详细统计数据...")
print("="*80)

# Write the n/32 distribution and length statistics to a JSON file next to
# the input file.
stats_data = {
    'file_info': {
        'file_path': file_path,
        'total_records': len(all_records),
        'num_groups': num_groups,
    },
    'n32_distribution': {str(k): v for k, v in sorted(distribution_32.items())},
    'correct_outputs': _summarize_lengths(correct_lengths),
    'incorrect_outputs': _summarize_lengths(incorrect_lengths),
}

stats_output = os.path.join(output_dir, f'{file_basename}_statistics.json')
with open(stats_output, 'w', encoding='utf-8') as f:
    json.dump(stats_data, f, ensure_ascii=False, indent=4)
print(f"统计数据已保存到: {stats_output}")

print("\n" + "="*80)
print("分析完成！")
print("="*80)
print("\n生成的文件:")

# Number the generated files consecutively. The two length plots exist only
# when at least one output length was collected, so the list is built
# conditionally rather than printing fixed numbers that could skip "2." and "3.".
generated = [("n/32分布图", dist_output)]
if correct_lengths or incorrect_lengths:
    generated.append(("长度分布图", length_output))
    generated.append(("箱线图", boxplot_output))
generated.append(("统计数据", stats_output))

for idx, (desc, path) in enumerate(generated, 1):
    print(f"  {idx}. {desc}: {path}")

