import json
from transformers import AutoTokenizer
from utils import load_eval_data
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import json
from utils import load_eval_data
from deepscaler.rewards.math_utils.utils import extract_answer, grade_answer_sympy as grade_answer
import matplotlib.pyplot as plt
import seaborn as sns


# model_path = "Qwen/Qwen2.5-1.5B"
# tokenizer = AutoTokenizer.from_pretrained(model_path)

# data_path = "model_eval/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2/math.json"

# data = load_eval_data(data_path)

# total_accs = []
# total_lengths = []

# for group_index in range(len(data)):
# # for group_index in range(100):
# # 
#     group = data[group_index]

#     ground_truth_answer = group['reward_model']['ground_truth']

#     correctness = group['correctness']

#     solutions = [solution for solution in group['responses']]

#     solution_lengths = [len(tokenizer(solution)['input_ids']) for solution in solutions]

#     total_accs.extend(correctness)
#     total_lengths.extend(solution_lengths)

# # Plot a histogram of `lengths`
# # plt.hist(total_lengths, bins=20, density=True, alpha=0.7, color='blue')
# plt.figure(figsize=(10, 6)) # the figure size can be adjusted
# sns.kdeplot(total_lengths, fill=True, color='skyblue')


# # Add labels and title
# plt.xlabel('Solution Lengths')
# plt.ylabel('Density')
# plt.title('Density Distribution of Solution Lengths')

import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import seaborn as sns
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch
import random

def calculate_thinking_freq(solution):
    """Count reflective keywords in the reasoning part of a solution.

    Only the text before the first ``</think>`` tag is inspected, and the
    comparison is case-insensitive. Keywords are matched as plain substrings,
    so "waiting" contributes one hit for "wait".

    Args:
        solution: Full response text as a string.

    Returns:
        The summed number of non-overlapping occurrences of every keyword.
    """
    reasoning, _, _ = solution.partition("</think>")
    lowered = reasoning.lower()
    markers = ("wait", "hmm", "remember", "recheck")
    return sum(lowered.count(marker) for marker in markers)

def use_think(solution):
    """Tell whether the reasoning part of a solution contains a reflective keyword.

    The text before the first ``</think>`` tag is lower-cased and searched for
    the substrings "wait", "hmm", "remember" and "recheck".

    Args:
        solution: Full response text as a string.

    Returns:
        ``True`` if at least one keyword occurs, ``False`` otherwise.
    """
    reasoning = solution.split("</think>", 1)[0].lower()
    markers = ("wait", "hmm", "remember", "recheck")
    return any(marker in reasoning for marker in markers)

def is_long_cot(prompt, average_length, reference_data, length_ratio=2.5):
    """Tell whether a prompt's responses are much longer than the reference.

    Args:
        prompt: Chat-style message list; the ``'content'`` of the last message
            is used as the lookup key.
        average_length: Average response length measured for this prompt.
        reference_data: Mapping from prompt content to the reference average
            length for that prompt.
        length_ratio: Multiplier applied to the reference length; the response
            counts as long when it strictly exceeds this multiple.

    Returns:
        ``True`` when ``average_length > length_ratio * reference``, otherwise
        ``False`` (always a real bool, never ``None``).

    Raises:
        KeyError: If the prompt content has no entry in ``reference_data``.
    """
    prompt_content = prompt[-1]['content']
    ref_average_length = reference_data[prompt_content]
    return average_length > length_ratio * ref_average_length

# Define the model path.
# NOTE(review): this value is not read before the main loop further down in
# this file rebinds `model_path` as its loop variable; the tokenizer below is
# loaded from `model_paths[0]` instead.
model_path = "Qwen/Qwen2.5-1.5B"

# Define the list of evaluation data files.
# Replace these paths with your actual JSON file paths.
# Entries are paired positionally with `model_paths` below, so both lists must
# keep the same order and the same entries enabled.
data_paths = [
    "model_eval/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2/math.json",
    # "model_eval/Deepseek-Qwen-7B-dpo-epoch-1/math.json",
    # "model_eval/DeepSeek-R1-Distill-Qwen-7B/math.json",
    # "model_eval/7B_long_0.8_short_0.2/math.json"
]

# Model checkpoints, one per entry of `data_paths`.
model_paths = [
    "models/Deepseek-Qwen-7B/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2",
    # "models/Deepseek-Qwen-7B/Deepseek-Qwen-7B-dpo/checkpoint-250",
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    # "models/Deepseek-Qwen-7B/long_0.8_short_0.2"
]

# Tokenizer of the first model, used only to measure reference response lengths.
tokenizer = AutoTokenizer.from_pretrained(model_paths[0])
raw_reference_data = load_eval_data("model_eval/Deepseek-Qwen-7B-Short-COT/math.json")
# Maps the content of the last prompt message to the mean token count of the
# reference responses for that prompt (the lookup table read by `is_long_cot`).
reference_data = {}

for group in raw_reference_data:
    solutions = group['responses']
    # Read but not used inside this loop.
    correctness = group['correctness']
    solution_lengths = [len(tokenizer(solution)['input_ids']) for solution in solutions]
    # Divides by the number of responses, so a group with no responses raises
    # ZeroDivisionError here.
    average_length = sum(solution_lengths) / len(solution_lengths)
    reference_data[group['prompt'][-1]['content']] = average_length



# For every (evaluation file, model) pair:
#   1. embed each prompt with the model (final-layer hidden state of the last
#      prompt token),
#   2. label each prompt by whether most of its responses contain "thinking"
#      keywords (see `use_think`),
#   3. save a t-SNE scatter plot of the embeddings coloured by that label, and
#   4. print how often each label occurs and how accurate its responses are.


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


# Dedicated seeded generator so tie-broken labels (and therefore the plot
# colours) are reproducible, matching the fixed t-SNE seed below.
tie_breaker = random.Random(42)

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)

for data_path, model_path in zip(data_paths, model_paths):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()
    # Right padding keeps the real tokens at positions 0..length-1, which the
    # last-token indexing below relies on.
    tokenizer.padding_side = "right"
    thinking_cot_correct = 0
    thinking_cot_count = 0
    non_thinking_cot_correct = 0
    non_thinking_cot_count = 0
    data = load_eval_data(data_path)
    current_is_thinking_cot = []  # one 0/1 label per prompt, in data order
    current_correctness = []      # every response's correctness, flattened
    count = 0                     # total number of scored responses

    # Render each chat prompt to the text the model would actually see.
    all_prompts = [
        tokenizer.apply_chat_template(group['prompt'], tokenize=False)
        for group in data
    ]

    print(all_prompts[0])

    batch_size = 10
    all_selected_hiddens = []

    for start in range(0, len(all_prompts), batch_size):
        batch_prompts = all_prompts[start:start + batch_size]
        # add_special_tokens=False: presumably the chat template has already
        # inserted the special tokens -- TODO confirm for this tokenizer.
        input_encodings = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to("cuda")
        input_lengths = input_encodings['attention_mask'].sum(dim=1)

        with torch.no_grad():
            outputs = model(**input_encodings, output_hidden_states=True)

        # Final-layer hidden state at the last non-padding token of each prompt.
        hiddens = outputs.hidden_states[-1]
        last_token_indices = input_lengths - 1
        batch_indices = torch.arange(hiddens.shape[0], device=hiddens.device)  # same device as hiddens
        selected_hiddens = hiddens[batch_indices, last_token_indices, :]
        # bfloat16 has no NumPy equivalent, so cast to float32 before leaving torch.
        selected_hiddens = selected_hiddens.float().cpu().numpy()
        all_selected_hiddens.extend(selected_hiddens.tolist())

    # The model is only needed for the embeddings. Drop it now so the next
    # iteration does not keep two models on the GPU while loading.
    del model
    torch.cuda.empty_cache()

    for group in data:
        responses = group['responses']
        correctness = group['correctness']
        thinking_votes = sum(int(use_think(response)) for response in responses)

        # Strict majority vote over the responses. Comparing 2*votes with the
        # group size avoids the truncation of int(len/2), which used to send a
        # clear minority (e.g. 2 of 5, or 0 of 1) to the random tie-break.
        if 2 * thinking_votes > len(responses):
            group_label = 1
        elif 2 * thinking_votes < len(responses):
            group_label = 0
        else:
            group_label = tie_breaker.choice([0, 1])

        current_is_thinking_cot.append(group_label)
        current_correctness.extend(correctness)

        # Score every response under its group's label. Zipping the
        # single-element label list with `correctness` only ever looked at the
        # first response of each group.
        for correct in correctness:
            count += 1
            if group_label == 1:
                thinking_cot_count += 1
                thinking_cot_correct += correct
            else:
                non_thinking_cot_count += 1
                non_thinking_cot_correct += correct

    point_colors = ['red' if label == 1 else 'blue' for label in current_is_thinking_cot]
    X = np.stack(all_selected_hiddens, axis=0)

    # NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter`;
    # switch the keyword when the environment is upgraded.
    tsne = TSNE(n_components=2,
                perplexity=min(200, len(X) - 1),  # must stay below the sample count
                n_iter=2000,                      # can increase if optimization doesn't converge
                init='pca',                       # often a good starting point
                learning_rate='auto',
                random_state=42)                  # fixed seed for reproducibility
    results = tsne.fit_transform(X)

    plt.figure(figsize=(12, 10))
    plt.scatter(
        results[:, 0],  # x coordinates from t-SNE
        results[:, 1],  # y coordinates from t-SNE
        alpha=0.7,      # point transparency
        s=40,
        c=point_colors
    )
    plt.savefig(f"figs/{os.path.basename(model_path)}_tsne.png")
    # Close the figure so figures do not accumulate across models.
    plt.close()

    print(len(all_selected_hiddens))

    print(f"Dataset: {data_path}")
    print("count:", count)
    print(f"Thinking COT ratio: {_safe_ratio(thinking_cot_count, count):.2f}")
    print(f"Non-thinking COT ratio: {_safe_ratio(non_thinking_cot_count, count):.2f}")
    print(f"Thinking COT accuracy: {_safe_ratio(thinking_cot_correct, thinking_cot_count):.2f}")
    print(f"Non-thinking COT accuracy: {_safe_ratio(non_thinking_cot_correct, non_thinking_cot_count):.2f}")
    print("-"*20)



    # # Plot the KDE for the current dataset
    # if current_is_thinking_cot: # make sure there is data to plot
    #     # Use the file name extracted from the path as the label
    #     label = os.path.basename(data_path)
    #     # sns.kdeplot(current_lengths, ax=ax, fill=False, label=label, bw_adjust=0.2) # fill=False leaves the curve unfilled
    #     plt.hist(current_is_thinking_cot, bins=2)

# Add the title and axis labels
# plt.title('Kernel Density Estimate of Solution Lengths for Multiple Datasets')
# plt.xlabel('Solution Lengths')
# plt.ylabel('Density')

# # Add the legend
# plt.legend(title="Datasets")

# # Show the figure
# # plt.grid(True, linestyle='--', alpha=0.6) # add grid lines
# plt.show()
# plt.savefig("figs/lengths.png")

