# import json
# import re
# import pandas as pd
# import numpy as np
# import seaborn as sns
# import matplotlib.pyplot as plt

# # 全局设置字体加粗加大
# plt.rcParams['font.weight'] = 'bold'       # 加粗
# plt.rcParams['font.size'] = 12             # 基础字体大小
# plt.rcParams['axes.titlesize'] = 14        # 标题字体大小
# plt.rcParams['axes.labelsize'] = 12        # 坐标轴标签大小
# plt.rcParams['xtick.labelsize'] = 10       # X轴刻度字体大小
# plt.rcParams['ytick.labelsize'] = 10       # Y轴刻度字体大小
# plt.rcParams['legend.fontsize'] = 10       # 图例字体大小
# plt.rcParams['figure.titlesize'] = 14      # 图形标题大小

# # 文件路径
# master_file = 'AI_School_main_vllm/eval/MATH_EVAL/MMLU_PRO/llama3_MA_Gen_data/LLAMA3_MATHQA_EVAL.json'
# ori_file = 'AI_School_main_vllm/eval/MATH_EVAL/ori_data_MMLU_PRO/llama3_ori_data/LLAMA3_MATHQA_EVAL.json'
# correct_file = 'AI_School_main_vllm/data/eval_datasets/MMLU_PRO/data/math.json'

# # 正则表达式匹配答案
# pattern = re.compile(r'[\(\[\{]\s*(\d{1,2}|[A-Z])', re.IGNORECASE)

# def extract_answer(text):
#     matches = pattern.findall(text)
#     if matches:
#         last = matches[-1].strip()
#         if last.isdigit():
#             num = int(last)
#             if 1 <= num <= 26:
#                 return chr(num + 64)
#         else:
#             return last.upper()
#     return None

# def tokenize(text):
#     return re.findall(r'\b\w+\b', text)

# # 读取数据
# with open(master_file, 'r', encoding='utf-8') as f:
#     master_data = json.load(f)

# with open(ori_file, 'r', encoding='utf-8') as f:
#     ori_data = json.load(f)

# with open(correct_file, 'r', encoding='utf-8') as f:
#     correct_data = json.load(f)

# correct_answers = [entry.get('answer', '').strip().upper() for entry in correct_data]

# # 构建 DataFrame
# records = []
# for i, (m_entry, o_entry, correct) in enumerate(zip(master_data, ori_data, correct_answers)):
#     m_text = m_entry.get('<StudentA>', '')
#     o_text = o_entry.get('answer', '')
#     m_ans = extract_answer(m_text)
#     o_ans = extract_answer(o_text)
#     m_tokens = len(tokenize(m_text))
#     o_tokens = len(tokenize(o_text))
#     m_correct = m_ans == correct[-1].upper() if m_ans else False
#     o_correct = o_ans == correct[-1].upper() if o_ans else False
#     records.append({
#         'Model': 'MASTER',
#         'TokenLength': m_tokens,
#         'Correct': m_correct
#     })
#     records.append({
#         'Model': 'ORI',
#         'TokenLength': o_tokens,
#         'Correct': o_correct
#     })

# df = pd.DataFrame(records)

# # 设置 token 长度上限
# df['TokenLength'] = df['TokenLength'].clip(upper=1200)

# # 分 bin
# df['TokenBin'] = pd.cut(df['TokenLength'], bins=np.arange(0, 750, 50))

# # 计算每个 bin 的准确率
# heatmap_data = df.groupby(['Model', 'TokenBin'])['Correct'].mean().reset_index()
# heatmap_pivot = heatmap_data.pivot(index='Model', columns='TokenBin', values='Correct')

# # 绘制热力图
# plt.figure(figsize=(12, 6))
# sns.heatmap(heatmap_pivot, annot=True, fmt=".2f", cmap='YlGnBu')
# plt.title('Token Length vs Accuracy Heatmap', fontsize=18, fontweight='bold')
# plt.xlabel('Token Length Bin', fontsize=12, fontweight='bold')
# plt.ylabel('Model', fontsize=12, fontweight='bold')
# plt.tight_layout()
# plt.savefig("token_accuracy_heatmap.pdf")
# plt.show()

import json
import re
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Global matplotlib styling: bold text and enlarged font sizes, applied to
# every figure this script draws. rcParams.update() validates each entry
# exactly like individual item assignment.
plt.rcParams.update({
    'font.weight': 'bold',       # bold text everywhere
    'font.size': 12,             # base font size
    'axes.titlesize': 14,        # axes title size
    'axes.labelsize': 12,        # axis label size
    'xtick.labelsize': 10,       # x-axis tick label size
    'ytick.labelsize': 10,       # y-axis tick label size
    'legend.fontsize': 10,       # legend text size
    'figure.titlesize': 14,      # figure (suptitle) size
})

# Input JSON files, resolved relative to the current working directory.
master_file = 'AI_School_main_vllm/eval/MATH_EVAL/MMLU_PRO/llama3_MA_Gen_data/LLAMA3_MATHQA_EVAL.json'
ori_file = 'AI_School_main_vllm/eval/MATH_EVAL/ori_data_MMLU_PRO/llama3_ori_data/LLAMA3_MATHQA_EVAL.json'
correct_file = 'AI_School_main_vllm/data/eval_datasets/MMLU_PRO/data/math.json'

# Answer-option matcher: an opening bracket '(', '[' or '{', optional
# whitespace, then either a 1-2 digit number or a single letter
# (case-insensitive). findall() returns only the captured number/letter.
# NOTE(review): the letter branch has no trailing boundary, so "(The ..."
# yields 'T', and LaTeX such as "\frac{1}{2}" yields '1' and '2' — confirm
# this heuristic is acceptable for the evaluated outputs.
pattern = re.compile(r'[\(\[\{]\s*(\d{1,2}|[A-Z])', re.IGNORECASE)

def extract_answer(text):
    """Return the option letter of the last bracketed choice found in *text*.

    All matches of the module-level ``pattern`` are collected and only the
    last one is considered. A captured letter is returned upper-cased; a
    captured number 1-26 is mapped to the letter at that position
    (1 -> 'A', 26 -> 'Z').

    Args:
        text: String to scan.

    Returns:
        A single upper-case letter, or None when nothing matches or the last
        match is a number outside 1-26 (e.g. "(0" or "(27").
    """
    hits = pattern.findall(text)
    if not hits:
        return None

    token = hits[-1].strip()
    if not token.isdigit():
        return token.upper()

    position = int(token)
    # 64 == ord('A') - 1, so position 1 maps to 'A'.
    return chr(position + 64) if 1 <= position <= 26 else None

def tokenize(text):
    """Split *text* into word tokens.

    A token is a maximal run of regex word characters (letters, digits and
    underscore, Unicode-aware); every other character is a separator and is
    discarded. Returns the tokens as a list of strings in order of appearance.
    """
    return [hit.group(0) for hit in re.finditer(r'\b\w+\b', text)]

# Load the inputs: MASTER generations, ORI generations, and the reference set.
with open(master_file, 'r', encoding='utf-8') as f:
    master_data = json.load(f)

with open(ori_file, 'r', encoding='utf-8') as f:
    ori_data = json.load(f)

with open(correct_file, 'r', encoding='utf-8') as f:
    correct_data = json.load(f)

# Upper-cased reference answers. The count heatmap does not score them, but
# they are still zipped below, so they bound how many samples are compared.
correct_answers = [entry.get('answer', '').strip().upper() for entry in correct_data]

# zip() silently stops at the shortest input; warn so a truncated or
# misaligned file does not quietly shrink the statistics.
input_lengths = {
    'master': len(master_data),
    'ori': len(ori_data),
    'correct': len(correct_answers),
}
if len(set(input_lengths.values())) > 1:
    print(
        f"Warning: input lengths differ {input_lengths}; "
        f"only the first {min(input_lengths.values())} samples are used."
    )

# Build a long-format DataFrame: one row per (sample, model) pair.
# MASTER text is read from '<StudentA>', ORI text from 'answer'.
records = []
for m_entry, o_entry, _ in zip(master_data, ori_data, correct_answers):
    # A JSON null would make tokenize() raise TypeError; treat it as empty.
    m_text = m_entry.get('<StudentA>', '') or ''
    o_text = o_entry.get('answer', '') or ''
    records.append({'Model': 'MASTER', 'TokenLength': len(tokenize(m_text))})
    records.append({'Model': 'ORI', 'TokenLength': len(tokenize(o_text))})

if not records:
    raise ValueError('No samples to plot: at least one input file is empty.')

df = pd.DataFrame(records)

# Token-length bins: [0, 50), [50, 100), ..., [650, 700) plus an open-ended
# overflow bin "700+". Previously the bins stopped at 700, so every response
# with >= 700 tokens fell outside all bins (pd.cut -> NaN) and was silently
# dropped from the counts; the old clip(upper=1200) could not prevent that.
bin_width = 50
overflow_start = 700
bin_edges = list(range(0, overflow_start + bin_width, bin_width)) + [np.inf]
bin_labels = [
    f'[{lo}, {hi})' for lo, hi in zip(bin_edges[:-2], bin_edges[1:-1])
] + [f'{overflow_start}+']
df['TokenBin'] = pd.cut(
    df['TokenLength'], bins=bin_edges, labels=bin_labels, right=False
)

# Number of samples per (model, bin). observed=False keeps empty bins so they
# appear as 0 instead of vanishing from the x-axis.
count_data = (
    df.groupby(['Model', 'TokenBin'], observed=False)
    .size()
    .reset_index(name='Count')
)
count_pivot = count_data.pivot(index='Model', columns='TokenBin', values='Count').fillna(0)

# Draw the heatmap and save it as a PDF.
plt.figure(figsize=(12, 6))
sns.heatmap(count_pivot, annot=True, fmt=".0f", cmap='YlGnBu')
plt.title('Token Length vs Sample Count Heatmap', fontsize=18, fontweight='bold')
plt.xlabel('Token Length Bin', fontsize=12, fontweight='bold')
plt.ylabel('Model', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.savefig("token_count_heatmap.pdf")
plt.show()
