import json
from transformers import AutoTokenizer
from utils import load_eval_data
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import json
from utils import load_eval_data
from deepscaler.rewards.math_utils.utils import extract_answer, grade_answer_sympy as grade_answer
import matplotlib.pyplot as plt
import seaborn as sns


# model_path = "Qwen/Qwen2.5-1.5B"
# tokenizer = AutoTokenizer.from_pretrained(model_path)

# data_path = "model_eval/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2/math.json"

# data = load_eval_data(data_path)

# total_accs = []
# total_lengths = []

# for group_index in range(len(data)):
# # for group_index in range(100):
# # 
#     group = data[group_index]

#     ground_truth_answer = group['reward_model']['ground_truth']

#     correctness = group['correctness']

#     solutions = [solution for solution in group['responses']]

#     solution_lengths = [len(tokenizer(solution)['input_ids']) for solution in solutions]

#     total_accs.extend(correctness)
#     total_lengths.extend(solution_lengths)

# # Plot a histogram of `lengths`
# # plt.hist(total_lengths, bins=20, density=True, alpha=0.7, color='blue')
# plt.figure(figsize=(10, 6))  # the figure size can be adjusted
# sns.kdeplot(total_lengths, fill=True, color='skyblue')


# # Add labels and title
# plt.xlabel('Solution Lengths')
# plt.ylabel('Density')
# plt.title('Density Distribution of Solution Lengths')


import matplotlib.pyplot as plt
import seaborn as sns
import json
from transformers import AutoTokenizer
import os

def calculate_thinking_freq(solution):
    """Count reflection keywords in the reasoning part of a response.

    Only the text before the first ``</think>`` marker is inspected, and it
    is lower-cased first, so matching is case-insensitive. Matching is plain
    substring counting (``str.count``), so a keyword embedded in a longer
    word (e.g. "await") is counted as well.

    Args:
        solution: Full response text.

    Returns:
        Total number of occurrences of "wait", "hmm", "remember" and
        "recheck" in the inspected text.
    """
    reasoning, _, _ = solution.partition("</think>")
    reasoning = reasoning.lower()
    markers = ("wait", "hmm", "remember", "recheck")
    return sum(reasoning.count(marker) for marker in markers)

def use_think(solution):
    """Tell whether the reasoning part of a response contains a reflection keyword.

    Only the text before the first ``</think>`` marker is inspected, and it
    is lower-cased first, so matching is case-insensitive. Matching is by
    substring, so a keyword embedded in a longer word also counts.

    Args:
        solution: Full response text.

    Returns:
        True if at least one of "wait", "hmm", "remember" or "recheck"
        occurs in the inspected text, otherwise False.
    """
    reasoning = solution.partition("</think>")[0].lower()
    markers = ("wait", "hmm", "remember", "recheck")
    # A keyword has a positive occurrence count exactly when it is a substring.
    return any(marker in reasoning for marker in markers)

# Model whose tokenizer is loaded below.
model_path = "Qwen/Qwen2.5-1.5B"

# Evaluation result files to analyse.
# Replace these paths with the paths of your own JSON files.
data_paths = [
    "model_eval/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2/math.json",
    "model_eval/Deepseek-Qwen-7B-dpo-epoch-1/math.json",
    "model_eval/DeepSeek-R1-Distill-Qwen-7B/math.json",
    "model_eval/7B_long_0.8_short_0.2/math.json",
    # "model_eval/Deepseek-Qwen-7B-dpo/math.json"
]

# Initialise the tokenizer; without it the script stops here.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Please ensure you have the 'transformers' library installed and the model path is correct.")
    # The bare `exit()` helper is injected by the `site` module for interactive
    # sessions and is not guaranteed to exist; called without an argument it
    # also reports success (status 0). Raise SystemExit with a failure status.
    raise SystemExit(1)

# Create the figure (the size can be adjusted) and take its current axes so
# that everything is drawn onto the same axes.
ax = plt.figure(figsize=(12, 8)).gca()

def _format_ratio(numerator, denominator):
    """Return numerator / denominator with two decimals, or "n/a" if the denominator is 0."""
    if denominator == 0:
        return "n/a"
    return f"{numerator / denominator:.2f}"


# For every evaluation file: print how often responses use thinking keywords,
# the accuracy with and without them, and add the file's histogram to the plot.
for data_path in data_paths:
    thinking_cot_correct = 0
    thinking_cot_count = 0
    non_thinking_cot_correct = 0
    non_thinking_cot_count = 0
    data = load_eval_data(data_path)

    # Per-response thinking flags (0/1) and correctness values for this dataset.
    current_is_thinking_cot = []
    current_correctness = []

    count = 0
    for group in data:
        # Keep only the text before the first "</think>" marker of each response.
        solutions = [solution.split("</think>")[0] for solution in group['responses']]
        correctness = group['correctness']
        is_thinking_cot = [int(use_think(solution)) for solution in solutions]
        current_is_thinking_cot.extend(is_thinking_cot)
        current_correctness.extend(correctness)

        for is_think, correct in zip(is_thinking_cot, correctness):
            count += 1
            if is_think == 1:
                thinking_cot_count += 1
                thinking_cot_correct += correct
            else:
                non_thinking_cot_count += 1
                non_thinking_cot_correct += correct

    # A dataset may be empty or contain only one of the two categories (for
    # example, every response uses a thinking keyword); _format_ratio reports
    # "n/a" in that case instead of raising ZeroDivisionError.
    print(f"Dataset: {data_path}")
    print(f"Thinking COT ratio: {_format_ratio(thinking_cot_count, count)}")
    print(f"Non-thinking COT ratio: {_format_ratio(non_thinking_cot_count, count)}")
    print(f"Thinking COT accuracy: {_format_ratio(thinking_cot_correct, thinking_cot_count)}")
    print(f"Non-thinking COT accuracy: {_format_ratio(non_thinking_cot_correct, non_thinking_cot_count)}")
    print("-"*20)

    # Draw this dataset's histogram, if there is anything to draw.
    if current_is_thinking_cot:
        # The file name is the same for every dataset, so label the histogram
        # with the parent directory name; fall back to the file name when the
        # path has no directory component.
        label = os.path.basename(os.path.dirname(data_path)) or os.path.basename(data_path)
        # Fix the range so all datasets share the same two bins even when a
        # dataset holds a single flag value; pass the label so the legend has
        # an entry; use transparency so overlaid datasets stay visible.
        plt.hist(current_is_thinking_cot, bins=2, range=(0, 1), alpha=0.5, label=label)

# Title and axis labels. The data drawn by the histogram calls above is the
# per-response 0/1 thinking-keyword flag and the bars are raw counts, so the
# texts describe that rather than a density of solution lengths.
plt.title('Thinking-Keyword Usage per Response for Multiple Datasets')
plt.xlabel('Uses thinking keywords (0 = no, 1 = yes)')
plt.ylabel('Number of responses')

# Add the legend.
plt.legend(title="Datasets")

# Save before showing: with an interactive backend the figure is released once
# the window opened by show() is closed, so saving afterwards writes a blank
# image. Also make sure the output directory exists.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/lengths.png")

# Display the figure.
plt.show()

