import json
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# 替换为您的 JSON 文件路径
file1_path = 'AI_School_main_vllm/eval/MATH_EVAL/MMLU_PRO/mistral_MA_Gen/LLAMA3_MATHQA_EVAL.json'  # 包含 <StudentA> 字段的 JSON 文件
file2_path = 'AI_School_main_vllm/eval/MATH_EVAL/ori_data_MMLU_PRO/mistral_ori_data/LLAMA3_MATHQA_EVAL.json'  # 包含 answer 字段的 JSON 文件

# 读取第一个 JSON 文件并提取 <StudentA> 字段内容
with open(file1_path, 'r', encoding='utf-8') as f1:
    data1 = json.load(f1)
    master_texts = [entry.get("<StudentA>", "") for entry in data1]

# 读取第二个 JSON 文件并提取 answer 字段内容
with open(file2_path, 'r', encoding='utf-8') as f2:
    data2 = json.load(f2)
    ori_texts = [entry.get("answer", "") for entry in data2]

# 使用正则表达式进行分词
def tokenize(text):
    return re.findall(r'\b\w+\b', text)

# 计算每个文本的 token 数量
master_tokens = [len(tokenize(text)) for text in master_texts]
ori_tokens = [len(tokenize(text)) for text in ori_texts]

# 限制 token 数最大为1000
max_token_limit = 1200
master_tokens_limited = [t if t <= max_token_limit else max_token_limit for t in master_tokens]
ori_tokens_limited = [t if t <= max_token_limit else max_token_limit for t in ori_tokens]

# 创建 DataFrame 以便于绘图
df = pd.DataFrame({
    'MASTER_Tokens': master_tokens_limited,
    'ORI_Tokens': ori_tokens_limited
})

# 计算平均 token 数（原始数据，不限制最大值）
avg_master = sum(master_tokens) / len(master_tokens)
avg_ori = sum(ori_tokens) / len(ori_tokens)

# 设置 Seaborn 样式
sns.set(style="whitegrid")

# 创建 KDE 图
plt.figure(figsize=(10, 6))
plt.rcParams['font.weight'] = 'bold'      # 加粗
plt.rcParams['font.size'] = 14            # 基础字体大小
plt.rcParams['legend.fontsize'] = 14      # 图例字体大小
plt.rcParams['axes.labelsize'] = 14       # 坐标轴标签大小
sns.kdeplot(df['MASTER_Tokens'], label='MASTER', fill=True, color='red', alpha=0.5)
sns.kdeplot(df['ORI_Tokens'], label='ORI', fill=True, color='orange', alpha=0.5)
plt.axvline(avg_master, color='red', linestyle='--', label=f'MASTER Avg: {avg_master:.2f}')
plt.axvline(avg_ori, color='orange', linestyle='--', label=f'ORI Avg: {avg_ori:.2f}')
plt.title('Token Count Distribution: MASTER vs ORI (Tokens capped at 1200)', fontsize=18, fontweight='bold')
plt.xlabel('Token Count', fontsize=16, fontweight='bold')
plt.ylabel('Density', fontsize=16, fontweight='bold')
plt.legend(fontsize=16, title='Legend', title_fontsize=16, frameon=True, shadow=True, fancybox=True)
plt.xticks(fontsize=14, fontweight='bold')
plt.yticks(fontsize=14, fontweight='bold')
plt.xlim(0, 1200)  # 限制x轴显示范围
plt.legend()
plt.tight_layout()

# 保存图像为 PDF 文件
plt.savefig("token_density_plot_capped.pdf", format="pdf", bbox_inches="tight")

# 显示图像
plt.show()

# 输出平均 token 数
print(f'平均 token 数 - MASTER: {avg_master:.2f}')
print(f'平均 token 数 - ORI: {avg_ori:.2f}')