import json
from transformers import AutoTokenizer
from utils import load_eval_data
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import json
from utils import load_eval_data
from deepscaler.rewards.math_utils.utils import extract_answer, grade_answer_sympy as grade_answer
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import seaborn as sns
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch
import random

def calculate_thinking_freq(solution):
    """Count reflection keywords in the reasoning part of *solution*.

    Only the text before the first ``</think>`` tag is inspected, after
    lower-casing. Matching is plain substring counting (non-overlapping), so
    a keyword embedded in a longer word is counted as well.

    Args:
        solution: Full model response as a string.

    Returns:
        Total number of occurrences of "wait", "hmm", "remember" and
        "recheck" in the inspected text.
    """
    reasoning = solution.split("</think>")[0].lower()
    markers = ("wait", "hmm", "remember", "recheck")
    return sum(reasoning.count(marker) for marker in markers)

def use_think(solution):
    """Tell whether the reasoning part of *solution* contains a reflection keyword.

    Only the text before the first ``</think>`` tag is inspected, after
    lower-casing. Matching is by substring.

    Args:
        solution: Full model response as a string.

    Returns:
        True if any of "wait", "hmm", "remember" or "recheck" occurs in the
        inspected text, otherwise False.
    """
    reasoning = solution.split("</think>")[0].lower()
    markers = ("wait", "hmm", "remember", "recheck")
    # A positive total count is equivalent to at least one marker being present.
    return any(marker in reasoning for marker in markers)

def is_long_cot(prompt, average_length, reference_data):
    """Tell whether a response is much longer than the reference for its prompt.

    Args:
        prompt: Sequence of chat messages; the ``'content'`` of the last
            message is used as the lookup key.
        average_length: Average response length to test.
        reference_data: Mapping from prompt content to the reference average
            length.

    Returns:
        True if ``average_length`` exceeds 2.5 times the reference length,
        otherwise False (the comparison result is returned directly, so the
        function never falls through to an implicit ``None``).

    Raises:
        KeyError: If the prompt content is not a key of ``reference_data``.
    """
    prompt_content = prompt[-1]['content']
    ref_average_length = reference_data[prompt_content]
    return average_length > 2.5 * ref_average_length

# Model path. In the visible code it is only referenced by a commented-out
# model load inside the main loop.
model_path = "Qwen/Qwen2.5-1.5B"

# List of paths to the data files to analyse.
# Replace these with the paths of your actual JSON files.
data_paths = [
    "model_eval/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2/math.json",
    "model_eval/Deepseek-Qwen-7B-dpo-epoch-1/math.json",
    "model_eval/DeepSeek-R1-Distill-Qwen-7B/math.json",
    "model_eval/Deepseek-Qwen-7B-o1pruner-alpha-5-MIX_MATH/math.json",
    "model_eval/Deepseek-Qwen-7B-Short-COT/math.json",
    # "model_eval/7B_long_0.8_short_0.2/math.json"
]

# Model locations. NOTE(review): this list has 3 active entries while
# data_paths has 5, so the two lists are not index-aligned; the visible code
# only reads model_paths[0] (to load the tokenizer).
model_paths = [
    "models/Deepseek-Qwen-7B/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2",
    # "models/Deepseek-Qwen-7B/Deepseek-Qwen-7B-dpo/checkpoint-250",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    # "models/Deepseek-Qwen-7B/long_0.8_short_0.2"
]

# Not read anywhere in the visible code.
# Alternative value kept for reference: "models/Deepseek-Qwen-7B/long_0.8_short_0.2"
base_model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

# A single tokenizer, loaded from the first entry of model_paths, is used to
# measure response lengths for every data file.
tokenizer = AutoTokenizer.from_pretrained(model_paths[0])
# Test split of MATH-500; each row supplies a 'problem' text and its 'level'.
raw_reference_data = load_dataset("HuggingFaceH4/MATH-500")['test']
# Maps problem text -> level.
reference_data = {}

# NOTE: `item`, `problem` and `level` stay bound at module level after this
# loop; the per-group level lookup further down reuses the names `problem`
# and `level`.
for item in raw_reference_data:
    problem = item['problem']
    level = item['level']
    reference_data[problem] = level

# Index of the sampled response analysed in every group (i.e. the third one).
# Groups with fewer responses raise IndexError, as before.
choose_index = 2
# Loop-invariant tokenizer setting, applied once instead of once per file.
tokenizer.padding_side = "right"

for data_path in data_paths:
    print(data_path)
    data = load_eval_data(data_path)

    # One record per group: correctness, token length, keyword flag and level
    # of the chosen response.
    level_analysis_data = []
    for group in data:
        prompt = group['prompt'][-1]['content']
        # Only the chosen response is tokenized and scanned; the other
        # responses of the group are not needed for this analysis.
        solution = group['responses'][choose_index]
        correctness = group['correctness'][choose_index]

        # Resolve the level afresh for every group. Without the reset, a
        # prompt matching no reference problem would silently inherit the
        # level of the previously processed group.
        level = "unknown"
        for problem, problem_level in reference_data.items():
            if problem in prompt:
                level = problem_level
                break

        level_analysis_data.append(
            {
                "correctness": correctness,
                # Length is measured on the full response, in tokens.
                "length": len(tokenizer(solution)['input_ids']),
                # use_think() itself only looks at the text before "</think>".
                "is_thinking_cot": int(use_think(solution)),
                "level": level,
            }
        )

    # Aggregate counts and sums per level.
    level_analysis_summary = {}
    for item in level_analysis_data:
        stats = level_analysis_summary.setdefault(
            item['level'],
            {
                'thinking_cot': 0,
                'non_thinking_cot': 0,
                'total_entries': 0,
                'sum_correctness': 0,
                'sum_tokens': 0,
            },
        )
        cot_key = 'thinking_cot' if item['is_thinking_cot'] == 1 else 'non_thinking_cot'
        stats[cot_key] += 1
        stats['total_entries'] += 1
        stats['sum_correctness'] += item['correctness']
        stats['sum_tokens'] += item['length']

    # Print average correctness and average token count for each level.
    print("不同 level 中 thinking_cot/non_thinking_cot 比例及平均 correctness:")
    # Sorted by string form so the order is stable across data files and the
    # "unknown" bucket (if any) can be sorted alongside the reference levels.
    for level in sorted(level_analysis_summary, key=str):
        stats = level_analysis_summary[level]
        # A level only exists in the summary once an entry was counted for
        # it, so total_count is always >= 1 and the division is safe.
        total_count = stats['total_entries']
        sum_correctness = stats['sum_correctness']
        sum_tokens = stats['sum_tokens']

        print(f"\nLevel '{level}':")
        print(f"  平均 Correctness: {sum_correctness / total_count:.2f} ({sum_correctness}/{total_count})", end="")
        print(f"  平均 Token 数量: {sum_tokens / total_count:.2f} ({sum_tokens}/{total_count})")