import json
from transformers import AutoTokenizer
from utils import load_eval_data
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import json
from utils import load_eval_data
from deepscaler.rewards.math_utils.utils import extract_answer, grade_answer_sympy as grade_answer
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import seaborn as sns
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch
import random

import json
# from utils import extract_answer
# from grader import grade_answer
from datasets import load_from_disk
from datasets import load_dataset
import random
from transformers import AutoTokenizer
import numpy as np
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import random
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import json
from utils import load_eval_data
from deepscaler.rewards.math_utils.utils import extract_answer, grade_answer_sympy as grade_answer

# Seed Python's global RNG so the random.shuffle of candidate pairs in
# construct_pairwise_json_data is reproducible across runs.
random.seed(0)


import numpy as np

def normalize_gain(lst, baseline=0):
    """Shift gains by ``baseline`` and scale them by their standard deviation.

    Args:
        lst: sequence of numeric gains.
        baseline: value subtracted from every gain before scaling.

    Returns:
        ``float32`` NumPy array ``(lst - baseline) / std(lst)``. When the
        standard deviation is 0 (constant or single-element input), the
        shifted values are returned unscaled instead of producing inf/NaN.
    """
    arr = np.array(lst, dtype=np.float32)
    std = np.std(arr)
    print("std:", std)
    if std == 0:
        # Nothing to normalise by; avoid a division by zero.
        return arr - baseline
    return (arr - baseline) / std

def transform(arr):
    """Push every element one unit further away from zero.

    Non-negative elements (including 0) are increased by 1 and negative
    elements are decreased by 1. A new list is returned.
    """
    shifted = []
    for value in arr:
        step = -1 if value < 0 else 1
        shifted.append(value + step)
    return shifted

def format_dataset(input_path, output_path):
    """Convert a JSON problem/solution file into a chat-format dataset on disk.

    Each row gets a ``messages`` field. It holds a user turn with the problem
    followed by a step-by-step / boxed-final-answer instruction, and an
    assistant turn with the solution. The auxiliary columns in
    ``columns_to_remove`` (when present) are dropped, as are the original
    ``problem`` and ``solution`` columns. Two sample rows are printed, and the
    resulting ``DatasetDict`` is written with ``save_to_disk``.

    Args:
        input_path: value passed as ``data_files`` to ``load_dataset("json")``.
        output_path: directory passed to ``save_to_disk``.
    """

    def _to_chat_messages(example):
        # Named so it does not shadow the builtin ``format``.
        example["messages"] = [{"role":"user","content":f"{example['problem']} Let's think step by step and output the final answer within \\boxed{{}}."},{"role":"assistant","content":example['solution']}]
        return example

    ds = load_dataset("json", data_files=input_path)

    columns_to_remove = ['ground_truth_solution','ground_truth_answer', 'pre_generated_steps', 'pre_generated_answer', 'pre_generated_verifier_score']
    # Read the column names from the schema instead of materialising row 0,
    # then drop every matching column in a single call.
    existing_columns = set(ds['train'].column_names)
    present_columns = [column for column in columns_to_remove if column in existing_columns]
    for column in present_columns:
        print("remove column:", column)
    if present_columns:
        ds = ds.remove_columns(present_columns)

    ds = ds.map(_to_chat_messages)
    ds = ds.remove_columns(['problem', 'solution'])

    print(ds)
    print(ds['train'][0])
    print(ds['train'][1])
    ds.save_to_disk(output_path)

def format_pairwise_sample(chosen_item, rejected_item, weight):
    """Build one preference-pair record.

    Args:
        chosen_item: dict with ``prompt`` and ``solution``; its prompt is used
            as the pair's instruction and its solution as the chosen response.
        rejected_item: dict with ``solution``, used as the rejected response.
        weight: optional per-pair weight; the ``weight`` key is added only
            when this is not ``None``.

    Returns:
        dict with ``instruction``, ``chosen``, ``rejected`` and optionally
        ``weight``.
    """
    sample = {
        "instruction": chosen_item['prompt'],
        "chosen": chosen_item['solution'],
        "rejected": rejected_item['solution'],
    }

    # ``is not None`` rather than ``!= None``: it also works for array-like
    # weights, whose ``!=`` is element-wise.
    if weight is not None:
        sample['weight'] = weight

    return sample

def format_sft_sample(item):
    """Wrap a prompt/solution record as a two-turn chat sample for SFT.

    The user turn carries ``item['prompt']`` and the assistant turn carries
    ``item['solution']``.
    """
    user_turn = {"role": "user", "content": item['prompt']}
    assistant_turn = {"role": "assistant", "content": item['solution']}
    return {"messages": [user_turn, assistant_turn]}

def filter_group(group, ground_truth_answer, filter_type):
    """Keep the responses in ``group`` that are graded correct or wrong.

    Each item's ``solution`` is passed through ``extract_answer`` and compared
    with ``ground_truth_answer`` by ``grade_answer``.

    Args:
        group: list of response dicts, each with a ``solution`` string.
        ground_truth_answer: reference answer passed to ``grade_answer``.
        filter_type: ``"correct"`` keeps items graded ``True``; ``"wrong"``
            keeps items graded ``False``.

    Returns:
        A new list with the matching items, in their original order.

    Raises:
        ValueError: if ``filter_type`` is neither ``"correct"`` nor ``"wrong"``.
    """
    if filter_type == "correct":
        wanted = True
    elif filter_type == "wrong":
        wanted = False
    else:
        raise ValueError(f"unknown filter_type: {filter_type!r}")
    return [
        item for item in group
        if grade_answer(extract_answer(item['solution']), ground_truth_answer) == wanted
    ]
    
def _inner_group_pairs(chosen_group, ground_truth_answer, M=2):
    """Build "inner-group" pairs that prefer the shortest chosen response.

    Response lengths are measured in tokens with the module-level
    ``tokenizer``. If the group has both correct and wrong responses, the
    shortest correct one is paired against (up to) the ``M`` longest wrong
    ones. Otherwise the shortest response is paired against (up to) the ``M``
    longest *other* responses of the same group. Every pair has weight 1.

    Returns:
        list of pairwise samples; empty when ``chosen_group`` has fewer than
        two usable responses.
    """
    correct_group = filter_group(chosen_group, ground_truth_answer, "correct")
    wrong_group = filter_group(chosen_group, ground_truth_answer, "wrong")

    def token_lengths(group):
        return [len(tokenizer(response['solution'])['input_ids']) for response in group]

    correct_lengths = token_lengths(correct_group)
    wrong_lengths = token_lengths(wrong_group)

    # NOTE(review): construct_pairwise_json_data narrows the chosen group to its
    # correct responses whenever any exist, so with deterministic grading this
    # mixed branch is not reached. Confirm whether bi-level pairing was meant
    # to run before that narrowing.
    if correct_group and wrong_group:
        shortest = correct_group[correct_lengths.index(min(correct_lengths))]
        longest_indices = sorted(range(len(wrong_lengths)), key=lambda i: -wrong_lengths[i])[:M]
        longest = [wrong_group[i] for i in longest_indices]
    elif correct_group or wrong_group:
        group, lengths = (correct_group, correct_lengths) if correct_group else (wrong_group, wrong_lengths)
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        shortest = group[order[0]]
        # Skip the shortest response itself so it is never paired with itself
        # (which happens when the group has M or fewer responses).
        longest = [group[i] for i in order[1:][-M:]]
    else:
        # Empty chosen group: nothing to pair.
        return []

    return [format_pairwise_sample(shortest, long_item, 1) for long_item in longest]


def construct_pairwise_json_data(data, max_pairs, with_weight=False, bi_level=False):
    """Build preference pairs from paired long/short CoT response groups.

    For each problem, the group favoured by ``item['gain']`` (long if
    gain > 0, short otherwise) is the chosen side and the other group is the
    rejected side. Correct chosen responses and wrong rejected responses are
    used when any exist. All chosen x rejected combinations are shuffled and
    at most ``max_pairs`` of them are kept per problem (the "inter-group"
    pairs).

    Args:
        data: iterable of dicts with keys ``gain``, ``ground_truth_answer``,
            ``long_group`` and ``short_group``; each group is a list of dicts
            with ``prompt`` and ``solution``.
        max_pairs: maximum number of inter-group pairs kept per problem.
        with_weight: if True, inter-group pairs carry ``abs(gain)`` as their
            weight; otherwise their weight is 1.
        bi_level: if True, also add inner-group pairs from the chosen group
            (shortest response preferred over the longest ones; weight 1).

    Returns:
        list of pairwise sample dicts from ``format_pairwise_sample``.
    """
    output_data = []

    for item in data:
        gain = item['gain']
        ground_truth_answer = item['ground_truth_answer']

        if gain > 0:
            chosen_group, rejected_group = item['long_group'], item['short_group']
        else:
            chosen_group, rejected_group = item['short_group'], item['long_group']

        correct_chosen_group = filter_group(chosen_group, ground_truth_answer, "correct")
        if correct_chosen_group:
            chosen_group = correct_chosen_group

        inner_group_samples = _inner_group_pairs(chosen_group, ground_truth_answer) if bi_level else []

        wrong_rejected_group = filter_group(rejected_group, ground_truth_answer, "wrong")
        if wrong_rejected_group:
            rejected_group = wrong_rejected_group

        # The weight was previously computed and then discarded (1 was always
        # passed); honour ``with_weight`` as its name implies.
        weight = abs(gain) if with_weight else 1
        all_samples = [
            format_pairwise_sample(chosen_item, rejected_item, weight)
            for chosen_item in chosen_group
            for rejected_item in rejected_group
        ]

        random.shuffle(all_samples)
        output_data += all_samples[:max_pairs]  # inter-group
        output_data += inner_group_samples      # inner-group
    return output_data

def json_to_dataset(data, output_path):
    """Store a list of JSON-serialisable records as a dataset on disk.

    The records are dumped to a private temporary JSON file, which is loaded
    with ``load_dataset("json", ...)`` and then deleted. The resulting
    ``DatasetDict`` is saved with ``save_to_disk(output_path)`` and returned.
    For inspection it prints the dataset, the first row, and (when a
    ``weight`` column exists) the weights of up to the first five rows.

    Args:
        data: list of dicts, e.g. the output of construct_pairwise_json_data.
        output_path: directory passed to ``save_to_disk``.

    Returns:
        The loaded ``DatasetDict``.
    """
    import tempfile

    # A unique temp file avoids clobbering / leaving "temp.json" in the CWD.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(data, f)
        temp_path = f.name
    try:
        ds = load_dataset("json", data_files=temp_path)
    finally:
        os.remove(temp_path)

    train = ds['train']
    print(ds)
    if len(train) > 0:
        print(train[0])
    if "weight" in train.column_names:
        # Print at most five weights; fewer rows must not raise IndexError.
        for index in range(min(5, len(train))):
            print(train[index]['weight'])
    ds.save_to_disk(output_path)
    print("dataset saved to", output_path)
    return ds

# Tokenizer bound to the module-level name ``tokenizer``, which
# construct_pairwise_json_data uses to measure response lengths in tokens.
# NOTE(review): ``tokenizer`` is rebound further down (before the embedding
# loop), so this one is only in effect until then.
model_path = "Qwen/Qwen2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Evaluation outputs of a long-CoT model and a short-CoT model on the same
# problem set; rows of the two files are matched by index below.
long_cot_data_path = "model_eval/DeepSeek-R1-Distill-Qwen-7B/mix_mathematic_problems.json"
short_cot_data_path = "model_eval/Deepseek-Qwen-7B-Short-COT/mix_mathematic_problems.json"

# Alternative long/short pairs (1.5B models):
# long_cot_data_path = "model_eval/DeepSeek-R1-Distill-Qwen-1.5B/mix_mathematic_problems.json"
# short_cot_data_path = "model_eval/Deepseek-Qwen-1.5B-Short-COT/mix_mathematic_problems.json"

# Alternative pair: QwQ-32B-Preview vs Qwen2.5-32B-Instruct generations (numina):
# long_cot_data_path = "length_control/data/model_generated/QwQ-32B-Preview_numina_16384_K-8.json"
# short_cot_data_path = "length_control/data/model_generated/Qwen2.5-32B-Instruct_numina_4096_K-8.json"

# Alternative pair: QwQ-32B-Preview vs Qwen2.5-32B-Instruct generations (math_train):
# long_cot_data_path = "length_control/data/model_generated/QwQ-32B-Preview_math_train_16384_K-8.json"
# short_cot_data_path = "length_control/data/model_generated/Qwen2.5-32B-Instruct_math_train_4096_K-8.json"

long_cot_data = load_eval_data(long_cot_data_path)
short_cot_data = load_eval_data(short_cot_data_path)

# Number of sampled responses per problem, read from the first long-CoT row.
K = len(long_cot_data[0]['responses'])
print("K:", K)

# Explicit check instead of ``assert`` so it also runs under ``python -O``.
if len(long_cot_data) != len(short_cot_data):
    raise ValueError(
        f"long/short CoT data differ in length: {len(long_cot_data)} vs {len(short_cot_data)}"
    )

print("num problems:",len(long_cot_data))

# NOTE(review): these counters/accumulators are not updated anywhere in the
# visible code; kept because other code may reference the module-level names.
negative = 0
postive = 0

long_acc_random = 0
short_acc_random = 0

valid_count = 0

all_gain = []

selected_data = []

total_acc_long = 0
total_length_long = 0

total_acc_short = 0
total_length_short = 0

total_acc_optimal = 0
total_length_optimal = 0

acc_counters = [0,0,0] # >0 =0 <0
long_lengths = []
accuracy_diffs = []

# Ground truths that look like multiple-choice letters; such problems are
# skipped. A tuple (not a set) so unhashable ground truths still compare.
MULTIPLE_CHOICE_ANSWERS = (
    "A", "B", "C", "D", "E", "F",
    "\\text{A}", "\\text{B}", "\\text{C}", "\\text{D}", "\\text{E}", "\\text{F}",
    "\\boxed{A}", "\\boxed{B}", "\\boxed{C}", "\\boxed{D}", "\\boxed{E}", "\\boxed{F}",
)
# Cap on the number of problems examined; clipped to the dataset size below so
# a smaller dataset does not raise IndexError.
MAX_PROBLEMS = 500

data = []

for group_index in range(min(MAX_PROBLEMS, len(long_cot_data))):
    long_group = long_cot_data[group_index]
    short_group = short_cot_data[group_index]

    ground_truth_answer = long_group['reward_model']['ground_truth']

    if ground_truth_answer in MULTIPLE_CHOICE_ANSWERS:
        continue
    if ground_truth_answer in ("None", ""):
        continue

    long_correctness = [int(x) for x in long_group['correctness']]
    short_correctness = [int(x) for x in short_group['correctness']]

    long_accuracy = sum(long_correctness) / len(long_correctness)
    short_accuracy = sum(short_correctness) / len(short_correctness)

    # Neither side ever answered correctly: no preference signal.
    if long_accuracy == 0 and short_accuracy == 0:
        continue

    # Long CoT is preferred only if it beats short CoT by more than 1/(2K),
    # i.e. half of one response's share of accuracy (assuming
    # len(correctness) == K).
    relative_accuracy_gain = long_accuracy - short_accuracy - 1/(2*K)
    # Not named ``is_long_cot`` to avoid colliding with the function of that name.
    prefers_long = bool(relative_accuracy_gain > 0)

    data.append(
        {
            "prompt": long_group['prompt'][-1]['content'],
            "is_long_cot": prefers_long
        }
    )

if not data:
    raise ValueError("no problems left after filtering; nothing to visualise")
print(data[0])
        
def calculate_thinking_freq(solution):
    """Count reflection keywords in the reasoning part of ``solution``.

    Only the text before the first ``</think>`` tag is examined (the whole
    string if there is no tag). Matching is case-insensitive, and the result
    is the total number of non-overlapping occurrences of "wait", "hmm",
    "remember" and "recheck".
    """
    reasoning = solution.split("</think>")[0].lower()
    return sum(reasoning.count(word) for word in ("wait", "hmm", "remember", "recheck"))

def use_think(solution):
    """Return True if the reasoning before ``</think>`` contains a reflection keyword.

    The keywords are "wait", "hmm", "remember" and "recheck", matched
    case-insensitively against the text before the first ``</think>`` tag
    (or against the whole string if the tag is absent).
    """
    head = solution.partition("</think>")[0].lower()
    return any(keyword in head for keyword in ("wait", "hmm", "remember", "recheck"))

def is_long_cot(prompt, average_length, reference_data, ratio=2.5):
    """Decide whether a response length is "long" relative to a reference.

    Args:
        prompt: chat message list; the last message's ``content`` is the key
            looked up in ``reference_data``.
        average_length: average response length for this prompt.
        reference_data: mapping from prompt content to a reference average
            length.
        ratio: multiple of the reference length above which the response
            counts as long CoT (default 2.5).

    Returns:
        ``True`` if ``average_length > ratio * reference_length``, else
        ``False``. Previously the negative case fell through and returned
        ``None``.

    Raises:
        KeyError: if the prompt content is missing from ``reference_data``.
    """
    prompt_content = prompt[-1]['content']
    ref_average_length = reference_data[prompt_content]
    return bool(average_length > ratio * ref_average_length)

# Checkpoints whose prompt representations are visualised, and the display
# names used as plot titles / file names (paired by position).
model_paths = [
    "models/Deepseek-Qwen-7B/Deepseek-Qwen-7B-merge-0.8-dpo-beta-0.1-no-ln-bilevel-fulldata-M1-4-M2-2",
    "models/Deepseek-Qwen-7B/Deepseek-Qwen-7B-dpo/checkpoint-250",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "models/Deepseek-Qwen-7B/long_0.8_short_0.2"
]

model_names = [
    "Ada-R1-7B",
    "DPO-7B",
    "R1-7B",
    "Merge-7B",
]

# The plotting loop zips these lists, and zip() would silently drop unmatched
# entries; fail loudly instead.
if len(model_paths) != len(model_names):
    raise ValueError(
        f"model_paths ({len(model_paths)}) and model_names ({len(model_names)}) must have the same length"
    )

# NOTE(review): rebinds the module-level ``tokenizer``; the plotting loop
# loads a tokenizer for every model, so this load looks redundant.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
# Hidden-state layer to read, indexed from the end of ``hidden_states``
# (1 = the last entry).
layer = 1
def _last_token_hiddens(model, tokenizer, prompts, layer, batch_size=10):
    """Return the hidden state of each prompt's last non-padding token.

    Prompts are run through ``model`` in batches of ``batch_size`` with
    ``output_hidden_states=True``. For every row, the vector of
    ``outputs.hidden_states[-layer]`` at position ``attention_mask.sum() - 1``
    is taken. That index is the last real token only under right padding, so
    ``tokenizer.padding_side`` must be ``"right"``.

    Returns:
        list of per-prompt hidden vectors (lists of floats), in prompt order.
    """
    selected = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start + batch_size]
        input_encodings = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False).to("cuda")
        input_lengths = input_encodings['attention_mask'].sum(dim=1)

        with torch.no_grad():
            outputs = model(**input_encodings, output_hidden_states=True)

        hiddens = outputs.hidden_states[-layer]
        last_token_indices = input_lengths - 1
        # Row indices must live on the same device as ``hiddens``.
        batch_indices = torch.arange(hiddens.shape[0], device=hiddens.device)
        selected_hiddens = hiddens[batch_indices, last_token_indices, :]
        selected.extend(selected_hiddens.float().cpu().numpy().tolist())
    return selected


# For every model: embed each prompt, project the embeddings to 2-D with
# t-SNE and save one scatter plot coloured by the long/short preference label.
for model_path, model_name in zip(model_paths, model_names):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()
    # Right padding is required by the last-token gather in _last_token_hiddens.
    tokenizer.padding_side = "right"

    print(data[0])
    # Render each problem with this model's chat template, ending with the
    # assistant generation prompt.
    all_prompts = [
        tokenizer.apply_chat_template([{"role":"user","content":group['prompt']}], tokenize=False, add_generation_prompt=True)
        for group in data
    ]

    print("all_prompts[0]:")
    print(all_prompts[0])

    all_selected_hiddens = _last_token_hiddens(model, tokenizer, all_prompts, layer, batch_size=10)

    # Release this model's GPU memory before the next one is loaded;
    # otherwise two models are resident while ``from_pretrained`` runs.
    del model
    torch.cuda.empty_cache()

    # One boolean label per prompt: True when long CoT was preferred.
    labels = np.array([group['is_long_cot'] for group in data])
    X = np.stack(all_selected_hiddens, axis=0)

    tsne = TSNE(n_components=2,
                # scikit-learn requires perplexity < n_samples.
                perplexity=min(100, len(X) - 1),
                # NOTE(review): ``n_iter`` was renamed ``max_iter`` in
                # scikit-learn 1.5 and removed in 1.7; update for newer versions.
                n_iter=2000,
                init='pca',
                learning_rate='auto',
                random_state=42)  # fixed seed for reproducible layouts
    X_2d = tsne.fit_transform(X)

    # Open exactly one figure per model and close it after saving, so figures
    # do not accumulate across iterations.
    fig = plt.figure(figsize=(12, 10))

    for label, color, name in zip([True, False], ['red', 'aqua'], ['Long Preferred (gain>0)', 'Short Preferred (gain<0)']):
        mask = labels == label
        plt.scatter(
            X_2d[mask, 0],
            X_2d[mask, 1],
            alpha=0.7,
            s=55,
            c=color,
            edgecolors='k',
            label=name
        )

    plt.legend(fontsize=18, loc='upper left', bbox_to_anchor=(0, 1))
    plt.title(model_name, fontsize=24)
    plt.tight_layout()

    output_dir = f"figs/hidden-{layer}"
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f"{output_dir}/{model_name}_tsne.png")
    plt.close(fig)
