from Compress import generate,FormatPrompt,model_compress
from CoTGeneration import readGSM8K
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import os
import argparse
from openai import OpenAI


def JsonLoad(data_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        data_path: Path to a UTF-8 text file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(data_path, 'r', encoding='utf-8') as json_file:
        # Blank lines (typically a trailing newline at the end of the file)
        # carry no record and would make json.loads raise, so skip them.
        return [json.loads(line) for line in json_file if line.strip()]

def ThinkFirstGenerate(dataset_name,data_path,device_map,think_model_path,generate_model_path,output_path,file_type='parquet'):
    """Two-stage generation: one model writes a plan, a second model solves with it.

    For every record the "think" model receives a system prompt asking for a
    purely conceptual analysis of the question. Its output is then injected
    after ``<think>`` in a hand-built prompt for the "generate" model, which
    produces the final answer. One JSON line per record is written to
    ``output_path`` with keys ``input``, ``output`` (gold answer), ``answer``
    (generate-model text) and ``think`` (think-model text).

    Args:
        dataset_name: Selects the record keys; must contain 'math', 'GPQA'
            or 'gsm8k'.
        data_path: Input file, read by JsonLoad or readGSM8K.
        device_map: CUDA index for the think model; the generate model is
            placed on ``int(device_map) + 1``.
        think_model_path: Checkpoint of the planning model.
        generate_model_path: Checkpoint of the answering model.
        output_path: Destination JSON Lines file.
        file_type: 'json' or 'parquet'.

    Raises:
        ValueError: If ``file_type`` or ``dataset_name`` is not supported.
    """
    # Fail fast, before any model is loaded, instead of hitting an unbound
    # local variable further down.
    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'parquet':
        data = readGSM8K(data_path)
    else:
        raise ValueError(f"Unsupported file_type: {file_type!r}")

    if 'math' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'problem'
    elif 'GPQA' in dataset_name or 'gsm8k' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'question'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # Create the directory the output file actually lives in.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    think_device = f"cuda:{device_map}"
    generate_device = f"cuda:{int(device_map) + 1}"

    think_model = AutoModelForCausalLM.from_pretrained(think_model_path, device_map=think_device, torch_dtype=torch.bfloat16)
    think_model_tokenizer = AutoTokenizer.from_pretrained(think_model_path, device_map=think_device)

    generate_model = AutoModelForCausalLM.from_pretrained(generate_model_path, device_map=generate_device, torch_dtype=torch.bfloat16)
    generate_model_tokenizer = AutoTokenizer.from_pretrained(generate_model_path, device_map=generate_device)

    system_prompt = "Analyze the problem below and determine the required formulas or theoretical principles, and potential solution strategies. Present this as a structured narrative in plain English. Focus on identifying foundational concepts and logical frameworks needed to address the problem, explicitly avoiding numerical calculations. The response must outline what knowledge or formulas are necessary and how they might be applied conceptually, while excluding implementation details, examples, or mathematical operations. Ensure the analysis is purely textual, continuous, and adheres strictly to these constraints."

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for i in range(len(data)):
            golden_answer = data[i][golden_answer_key]
            question = data[i][question_key]

            # Stage 1: conceptual plan from the think model.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": 'Question:' + '<' + question + '>'},
            ]
            tokenized_chat = think_model_tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
            tokenized_chat = tokenized_chat.to(device=think_device)
            outputs = think_model.generate(tokenized_chat, pad_token_id=think_model_tokenizer.eos_token_id, max_new_tokens=16384)
            # Decode only the newly generated ids. Cutting the decoded string
            # by the character length of the prompt text is wrong whenever
            # decode(encode(prompt)) does not reproduce the prompt exactly.
            think = think_model_tokenizer.decode(outputs[0][tokenized_chat.shape[-1]:])

            # Stage 2: the plan is pre-filled as the start of the assistant's
            # <think> section for the answering model.
            prompt = '<｜User｜>' + question + 'Please reason step-by-step and put your final answer within \\boxed{} in the end. <｜Assistant｜>' + '''<think>''' + think
            tokenized_chat = generate_model_tokenizer.encode(prompt, return_tensors='pt')
            tokenized_chat = tokenized_chat.to(device=generate_device)
            outputs = generate_model.generate(tokenized_chat, max_new_tokens=16384, output_scores=True, return_dict_in_generate=True)
            output = generate_model_tokenizer.decode(outputs[0][0][tokenized_chat.shape[-1]:])

            print(question)
            print('-'*86)
            f_out.write(json.dumps({"input": question, "output": golden_answer, "answer": output, "think": think}, ensure_ascii=False) + "\n")


import pandas as pd
def CSVLoad(data_path):
    """Load multiple-choice questions from every CSV file in a directory.

    Each CSV is read with ``pandas.read_csv`` defaults, so its first line is
    consumed as the header. For every remaining row, columns 0-4 are taken as
    the question stem and options A-D, and column 5 as the correct answer.

    Args:
        data_path: Directory containing the CSV files, with or without a
            trailing path separator. Sub-directories are ignored.

    Returns:
        A list of ``{'question': str, 'answer': value}`` dicts, ordered by
        file name and then by row.

    Raises:
        IndexError: If a row has fewer than six columns.
    """
    data = []
    # os.listdir order is arbitrary; sorting keeps the output reproducible.
    for file_name in sorted(os.listdir(data_path)):
        # Join instead of concatenating so a missing trailing separator on
        # data_path does not produce a wrong path.
        file_path = os.path.join(data_path, file_name)
        if not os.path.isfile(file_path):
            continue
        df = pd.read_csv(file_path)
        # Plain tuples give strictly positional access, independent of the
        # column labels found in the header.
        for row in df.itertuples(index=False, name=None):
            question = (
                f"{row[0]}\n\n\nA. {row[1]}\nB. {row[2]}"
                f"\nC. {row[3]}\nD. {row[4]}\n\n"
            )
            data.append({'question': question, 'answer': row[5]})
    return data


def DirectGenerate(dataset_name,data_path,device_map,model_path,output_path,reference,file_type='parquet'):
    """Answer every question with one model, optionally seeded with a reference prefix.

    Without ``reference`` the prompt is the tokenizer's chat template applied
    to the question plus a "\\boxed{}" instruction. With ``reference`` a
    snippet of a previously generated answer (stored under the dataset's
    'answer' key) is appended after a hand-built assistant header, so the
    model continues from it. One JSON line per record is written with keys
    ``input``, ``output`` (gold answer), ``answer`` (generated text) and
    ``reference`` (the snippet used, or the unchanged ``reference`` argument
    when no snippet was used).

    Args:
        dataset_name: Selects the record keys; must contain one of 'math',
            'aime', 'gsm8k', 'gpqa', 'limo', 'amc' or 'reference'.
        data_path: Input file or directory, read according to ``file_type``.
        device_map: CUDA index the input ids are moved to.
        model_path: Checkpoint of the model and its tokenizer.
        output_path: Destination JSON Lines file.
        reference: Truthy to prepend a reference snippet to the assistant turn.
        file_type: 'json', 'csv' or 'parquet'.

    Raises:
        ValueError: If ``file_type`` or ``dataset_name`` is unsupported, or
            ``reference`` is requested for a dataset without a reference key.
    """
    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'csv':
        data = CSVLoad(data_path)
    elif file_type == 'parquet':
        data = readGSM8K(data_path)
    else:
        raise ValueError(f"Unsupported file_type: {file_type!r}")

    reference_key = None
    if ('math' in dataset_name) or ('aime' in dataset_name):
        golden_answer_key = 'answer'
        question_key = 'problem'
    elif 'gsm8k' in dataset_name or 'gpqa' in dataset_name or 'limo' in dataset_name or 'amc' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'question'
    elif 'reference' in dataset_name:
        golden_answer_key = 'output'
        question_key = 'input'
        reference_key = 'answer'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    if reference and reference_key is None:
        raise ValueError("reference=True requires a dataset_name that selects the 'reference' keys")

    # Create the directory the output file actually lives in.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, device_map="auto")
    # The DeepSeek tokenizer is only needed to cut the reference snippet, so
    # do not require its local checkpoint for plain runs.
    deepseek_tokenize = None
    if reference:
        deepseek_tokenize = AutoTokenizer.from_pretrained('../All_LLM/DeepSeek-R1-Distill-Qwen-32B/', device_map=f"cuda:{device_map}")

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for i in range(len(data)):
            golden_answer = data[i][golden_answer_key]
            question = data[i][question_key]

            # The original dataset-name guard here contained a bare 'gpqa'
            # string and was therefore always true; the chat-template prompt
            # is built for every dataset.
            messages = [{"role": "user", "content": question + 'Please reason step-by-step and put your final answer within \\boxed{}.'}]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            print(messages)

            # Keep the flag and the per-record snippet in separate variables
            # so the flag is not clobbered after the first record.
            reference_text = reference
            if reference:
                reference_ids = deepseek_tokenize.encode(data[i][reference_key])
                try:
                    # 151649 is presumably the id of the `</think>` token in
                    # the DeepSeek-R1-Distill tokenizer — TODO confirm.
                    index = reference_ids.index(151649)
                except ValueError:
                    # Marker not present: fall back to the first 32 tokens.
                    reference_text = deepseek_tokenize.decode(reference_ids[:32])
                else:
                    # Up to 128 tokens following the marker.
                    reference_text = deepseek_tokenize.decode(reference_ids[index+1:index+129])
                # This prompt is what must be encoded; previously it was
                # built but the plain chat-template text was sent instead.
                text = '<|im_start|>user' + question + 'Please reason step-by-step and put your final answer within \\boxed{}. <|im_end|><|im_start|>assistant' + reference_text

            tokenized_chat = tokenizer.encode(text, return_tensors='pt')
            tokenized_chat = tokenized_chat.to(device=f"cuda:{device_map}")
            # Per-step scores were never read; not requesting them avoids
            # keeping a vocabulary-sized tensor for every generated token.
            outputs = model.generate(tokenized_chat, max_new_tokens=8192, return_dict_in_generate=True, do_sample=False)
            output = tokenizer.decode(outputs.sequences[0][tokenized_chat.shape[-1]:])
            f_out.write(json.dumps({"input": question, "output": golden_answer, "answer": output, 'reference': reference_text}, ensure_ascii=False) + "\n")

def SimpleGenerate(data_path,device_map,model_path,output_path,file_type='parquet'):
    """Run a single model over the dataset with a fixed math system prompt.

    Records are read with JsonLoad ('json') or readGSM8K ('parquet'). The
    user turn is the record's 'question' followed by two '<|eot_id|>'
    markers. Each result is written to ``output_path`` as one JSON line with
    keys ``input``, ``output`` (the record's 'answer') and ``answer`` (the
    decoded text after the prompt's character length has been cut off).
    """
    out_subdir = output_path.split("/")[-2]
    os.makedirs(f"./KoTLlmOutput/{out_subdir}", exist_ok=True)

    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'parquet':
        data = readGSM8K(data_path)

    target_device = f"cuda:{device_map}"
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=target_device, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, device_map=target_device)

    system_prompt = "You are an expert in math.\nBelow is a math question.\nPlease think step by step, Write a response that appropriately answers the question.\nYour final answer should be an integer at the end of your response, formatted as: The answer is {answer}."

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for idx in range(len(data)):
            gold = data[idx]['answer']
            user_turn = data[idx]['question'] + '<|eot_id|><|eot_id|>'
            chat = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_turn},
            ]

            # The rendered prompt text is only needed for its length, which
            # is used to cut the prompt off the decoded output.
            prompt_text = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
            input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")
            input_ids = input_ids.to(device=target_device)

            generated = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=2048)
            decoded = tokenizer.decode(generated[0])

            print(user_turn)
            print('-'*86)
            record = {"input": user_turn, "output": gold, "answer": decoded[len(prompt_text):]}
            f_out.write(json.dumps(record, ensure_ascii=False) + "\n")

def KoTgenerate(data_path,device_map,model_path,output_path,file_type='parquet'):
    """Generate answers with an extra hint appended to every question.

    For each record a hint is produced by passing the record's gold 'answer'
    through FormatPrompt (imported from Compress) and then Qwen_generate.
    The hint is appended to the question between '<|eot_id|>' markers and the
    result is sent to the model together with a fixed math system prompt.
    One JSON line per record is written to ``output_path`` with keys
    ``input`` (question plus hint), ``output`` (gold answer) and ``answer``
    (decoded text after the prompt's character length has been cut off).

    Args:
        data_path: Input file, read by JsonLoad ('json') or readGSM8K ('parquet').
        device_map: CUDA index used for the model and the input ids.
        model_path: Checkpoint of the model and its tokenizer.
        output_path: Destination JSON Lines file; its second-to-last
            '/'-separated component names the directory created under
            ./KoTLlmOutput.
        file_type: 'json' or 'parquet'; any other value leaves ``data`` unbound.
    """
    dataset = output_path.split("/")[-2]
    os.makedirs(f"./KoTLlmOutput/{dataset}", exist_ok=True)
    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'parquet':    
        data = readGSM8K(data_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=f"cuda:{device_map}", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path,device_map=f"cuda:{device_map}") 
    with open(output_path, 'w', encoding="utf-8") as f_out:
        for i in range(len(data)):
            golden_answer = data[i]['answer']
            # Only referenced by the disabled retry loop further down.
            len_matches = 0 
            # while True: 
            # The hint prompt is built from the gold answer, not the question.
            messages = FormatPrompt(golden_answer)
            # NOTE(review): Qwen_generate is neither defined nor imported in
            # the visible part of this file — confirm where it comes from.
            answer = Qwen_generate(messages)
            # Disabled earlier variants of the hint extraction, kept for reference:
            # while True:
            #     answer = model_compress(model,tokenizer,messages,device_map)
            #     print(answer)
            #     match = re.findall(r"\[([^\]]+)\]", answer)  
            #     if match:
            #         items = [item.strip(" '\"") for item in match[0].split(",")]
            #         break
            # answer = generate(messages)

            #     matches = re.findall(r"\*\*(.*?)\*\*", answer)
            #     len_matches = len(matches)
            #     if 3 < len_matches < 10:
            #         break
            # key_str = '['
            # for k in items:
            #     key_str +=  k+','
            # key_str += ']'
            # print(answer)
            # Question and hint are separated and terminated by '<|eot_id|>'.
            question = data[i]['question'] + '<|eot_id|>' + answer + '<|eot_id|>'
            messages = [
                    {"role": "system", "content": "You are an expert in math.\nBelow is a math question.\nPlease think step by step, Write a response that appropriately answers the question.\nYour final answer should be an integer at the end of your response, formatted as: The answer is {answer}."},
                    {"role": "user", "content": question},
                ]
            tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
            tokenized_chat = tokenized_chat.to(device=f"cuda:{device_map}")
            # Character length of the rendered prompt, used below to cut the
            # prompt off the decoded output.
            prompt_len = len(tokenizer.apply_chat_template(messages,add_generation_prompt=True,tokenize=False))
            outputs = model.generate(tokenized_chat, pad_token_id=tokenizer.eos_token_id,max_new_tokens=2048) 
            output = tokenizer.decode(outputs[0])
            print(question)
            # print(output[prompt_len:])
            print('-'*86)
            f_out.write(json.dumps({"input": question, "output": golden_answer, "answer": output[prompt_len:]},ensure_ascii=False) + "\n")

def Regenerate(data_path,device_map,model_path,output_path,file_type):
    """Re-run generation over a previously written results file.

    Every record must provide 'input' (the prompt to resend) and 'output'
    (carried over unchanged). The model is queried with a fixed math system
    prompt, and one JSON line per record is written to ``output_path`` with
    keys ``input``, ``output`` and ``answer`` (the decoded text after the
    prompt's character length has been cut off).
    """
    out_subdir = output_path.split("/")[-2]
    os.makedirs(f"./KoTLlmOutput/{out_subdir}", exist_ok=True)

    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'parquet':
        data = readGSM8K(data_path)

    target_device = f"cuda:{device_map}"
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=target_device, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, device_map=target_device)

    system_prompt = "You are an expert in math.\nBelow is a math question.\nPlease think step by step, Write a response that appropriately answers the question.\nYour final answer should be an integer at the end of your response, formatted as: The answer is {answer}."

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for idx in range(len(data)):
            prompt_source = data[idx]['input']
            carried_output = data[idx]['output']
            chat = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt_source},
            ]

            # The rendered prompt text is only needed for its length, which
            # is used to cut the prompt off the decoded output.
            prompt_text = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
            input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")
            input_ids = input_ids.to(device=target_device)

            generated = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=2048)
            completion = tokenizer.decode(generated[0])[len(prompt_text):]

            print(prompt_source)
            print(completion)
            print('-'*86)
            record = {"input": prompt_source, "output": carried_output, "answer": completion}
            f_out.write(json.dumps(record, ensure_ascii=False) + "\n")


def GreedyGenerate(dataset_name,data_path,device_map,model_path,output_path,file_type='parquet'):
    """Answer every question with a hand-built ChatML-style prompt.

    The prompt is '<|im_start|>user' + question + a "\\boxed{}" instruction +
    '<|im_end|><|im_start|>assistant'. One JSON line per record is written to
    ``output_path`` with keys ``input``, ``output`` (gold answer) and
    ``answer`` (text decoded from the newly generated tokens only).

    Args:
        dataset_name: Selects the record keys; must contain one of 'math',
            'gsm8k', 'gpqa', 'limo' or 'reference'.
        data_path: Input file, read by JsonLoad or readGSM8K.
        device_map: CUDA index used for the model and the input ids.
        model_path: Checkpoint of the model and its tokenizer.
        output_path: Destination JSON Lines file.
        file_type: 'json' or 'parquet'.

    Raises:
        ValueError: If ``file_type`` or ``dataset_name`` is not supported.
    """
    # Fail fast, before the model is loaded, instead of hitting an unbound
    # local variable further down.
    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'parquet':
        data = readGSM8K(data_path)
    else:
        raise ValueError(f"Unsupported file_type: {file_type!r}")

    if 'math' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'problem'
    elif 'gsm8k' in dataset_name or 'gpqa' in dataset_name or 'limo' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'question'
    elif 'reference' in dataset_name:
        golden_answer_key = 'output'
        question_key = 'input'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # Create the directory the output file actually lives in.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    device = f"cuda:{device_map}"
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, device_map=device)
    # A DeepSeek tokenizer used to be loaded here from a hard-coded local
    # path but was never used; it only added a way for the run to fail.

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for i in range(len(data)):
            golden_answer = data[i][golden_answer_key]
            question = data[i][question_key]
            messages = '<|im_start|>user' + question + 'Please reason step-by-step and put your final answer within \\boxed{}. <|im_end|><|im_start|>assistant'
            tokenized_chat = tokenizer.encode(messages, return_tensors='pt')
            tokenized_chat = tokenized_chat.to(device=device)
            # NOTE(review): no do_sample=False is passed, so the decoding
            # strategy is whatever the model's generation config selects —
            # confirm this matches the "greedy" in the function name.
            # Per-step scores were never read, so they are not requested.
            outputs = model.generate(tokenized_chat, max_new_tokens=16384, return_dict_in_generate=True)
            output = tokenizer.decode(outputs.sequences[0][tokenized_chat.shape[-1]:])
            f_out.write(json.dumps({"input": question, "output": golden_answer, "answer": output}, ensure_ascii=False) + "\n")

from torch import nn
class SimpleMLP(nn.Module):
    """Small feed-forward classifier head.

    Two Linear -> BatchNorm1d -> ReLU stages (256 and 128 units) followed by
    a linear layer producing 2 outputs per sample.
    """

    def __init__(self, input_dim):
        """Build the layers for inputs with ``input_dim`` features."""
        super().__init__()
        wide, narrow, n_outputs = 256, 128, 2
        # Attribute names and creation order are unchanged so parameter names
        # and initialisation order stay the same.
        self.fc1 = nn.Linear(input_dim, wide)
        self.bn1 = nn.BatchNorm1d(wide)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(wide, narrow)
        self.bn2 = nn.BatchNorm1d(narrow)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(narrow, n_outputs)

    def forward(self, x):
        """Return the 2-dimensional output for each row of ``x``."""
        hidden = self.fc1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu1(hidden)

        hidden = self.fc2(hidden)
        hidden = self.bn2(hidden)
        hidden = self.relu2(hidden)

        return self.fc3(hidden)




from torch import nn
from torch.nn import functional as F
def compute_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Return the entropy, in bits, of the softmax over the last axis.

    The last dimension of ``logits`` is reduced; all leading dimensions are
    kept in the result.
    """
    distribution = logits.softmax(dim=-1)
    # The 1e-8 offset keeps log2 finite where a probability is exactly zero.
    weighted_bits = distribution * (distribution + 1e-8).log2()
    return weighted_bits.sum(dim=-1).neg()

import numpy as np
def SpeculativeGenerate(dataset_name,data_path,device_map,model,model_2,output_path,reference,chosen_entropy,file_type='parquet'):
    """Run assisted generation over a dataset and log per-step diagnostics.

    ``model`` generates with ``model_2`` passed as ``assistant_model``. For
    every record one JSON line is written to ``output_path`` holding the
    rendered prompt, the gold answer, the decoded continuation, the entropy
    of every entry in ``outputs.scores`` and several extra fields read from
    the generate result.

    Args:
        dataset_name: Selects the record keys and the prompt wording.
        data_path: Input file or directory, read according to ``file_type``.
        device_map: Integer; input ids are moved to ``cuda:{device_map-1}``.
        model: Already-loaded main model.
        model_2: Already-loaded assistant model.
        output_path: Destination JSON Lines file.
        reference: Not used inside this function.
        chosen_entropy: Forwarded to ``model.generate`` and stored in every
            output record.
        file_type: 'json', 'csv' or 'parquet'; any other value leaves
            ``data`` unbound.
    """
    # Hard-coded local checkpoints; only used to load tokenizers here, since
    # the models themselves are passed in by the caller.
    model_path='../models/Qwen2.5-72B-Instruct/'
    model_path_2='../Qwen2.5-7B-Instruct/'
    # NOTE(review): my_generate is not referenced by name below — presumably
    # imported for its side effects on generation; confirm.
    from Spec.hf_generation import my_generate
    from torch import nn
    from transformers import modeling_utils
    # Provide ALL_PARALLEL_STYLES when the installed transformers module
    # lacks it or has it set to None.
    if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
        modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none","colwise",'rowwise']
    
        
    # original_logits, new_logits and device are assigned but not used below.
    original_logits = []
    new_logits = []

    device = torch.device(f"cuda:{device_map}" if torch.cuda.is_available() else "cpu")
    
    import time

    # Wall-clock timer for the whole run; reported after the loop.
    start_time = time.time()  

    dataset = output_path.split("/")[-2]
    os.makedirs(f"./KoTLlmOutput/{dataset}", exist_ok=True)
    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'csv':
        data = CSVLoad(data_path)
    elif file_type == 'parquet':    
        data = readGSM8K(data_path)

    # Record keys per dataset family; reference_key is assigned but not used
    # in this function.
    if ('math' in dataset_name) or ('aime' in dataset_name):
        golden_answer_key = 'answer'
        question_key = 'problem'
    elif 'gsm8k' in dataset_name or 'gpqa' in dataset_name or 'limo' in dataset_name or 'amc' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'question'
    elif 'reference' in dataset_name:
        golden_answer_key = 'output'
        question_key = 'input'
        reference_key = 'answer'

    # tokenizer_2 is loaded but not used below.
    tokenizer = AutoTokenizer.from_pretrained(model_path,device_map='auto') 
    tokenizer_2 = AutoTokenizer.from_pretrained(model_path_2,device_map=f"cuda:{device_map-1}") 

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for i in range(len(data)):
            entropy = []
            golden_answer = data[i][golden_answer_key]
            question = data[i][question_key] 
            # Free-form math datasets get the final-answer instruction; every
            # other dataset gets the choice-letter instruction.
            if ('math' in dataset_name) or ('gsm8k' in dataset_name) or ('limo' in dataset_name) or ('amc' in dataset_name) or ('aime' in dataset_name):
                format_messages = [{"role": "user", "content": question + 'Please reason step-by-step and put your final answer within \\boxed{}.'}]
                messages = tokenizer.apply_chat_template(format_messages,tokenize=False,add_generation_prompt=True)
            else:
                format_messages = [{"role": "user", "content": question + 'Please reason step-by-step and put your choice letter without any other text with \\boxed{} in the end.'}]
                messages = tokenizer.apply_chat_template(format_messages,tokenize=False,add_generation_prompt=True)
                print(messages)

            tokenized_chat = tokenizer.encode(messages,return_tensors='pt')
            tokenized_chat = tokenized_chat.to(device=f"cuda:{device_map-1}")
            

            # Self-assignment; has no effect.
            chosen_entropy = chosen_entropy

            # This setting here is to ensure reproducibility
            model_2.generation_config.assistant_confidence_threshold = 0.4
            # NOTE(review): chosen_entropy is presumably consumed by a
            # customised generate implementation — confirm.
            outputs = model.generate(tokenized_chat,max_new_tokens=4096,output_scores=True,return_dict_in_generate = True,output_hidden_states=True,assistant_model=model_2,do_sample=False,chosen_entropy=chosen_entropy) 

            # Decode only the tokens after the prompt.
            output = tokenizer.decode(outputs[0][0][tokenized_chat.shape[-1]:])
            # float() needs a single-element tensor, so each scores entry
            # must reduce to one entropy value.
            for tensor in outputs.scores:
                single_entropy = float(compute_entropy(tensor))
                entropy.append(single_entropy)
            # NOTE(review): the attribute is spelled 'caandidate_list' here —
            # confirm it matches the name set by the code producing `outputs`.
            f_out.write(json.dumps({"input": messages, "output": golden_answer, "answer": output,'reference':outputs.caandidate_list,'valid_result':outputs.valid_result,'matches':outputs.matches,'entropy':entropy,'punish_list':outputs.whole_punish_list,'chosen_entropy':chosen_entropy},ensure_ascii=False)+ "\n")

            # Drop the reference to the generate result before the next record.
            del outputs

        end_time = time.time()  
        elapsed_time = end_time - start_time  
        print(f"Running Time: {elapsed_time:.4f} 秒")


def BeamSearch(dataset_name,data_path,device_map,model,output_path,file_type='parquet'):
    """Answer every question with 16-beam search and store the best beam.

    One JSON line per record is written to ``output_path`` with keys
    ``input`` (the rendered chat prompt), ``output`` (gold answer) and
    ``answer`` (the decoded top beam).

    Args:
        dataset_name: Selects the record keys and the prompt wording; must
            contain one of 'math', 'aime', 'gsm8k', 'gpqa', 'limo', 'amc'
            or 'reference'.
        data_path: Input file or directory, read according to ``file_type``.
        device_map: Integer; input ids are moved to ``cuda:{device_map-1}``.
        model: Already-loaded causal LM used for generation.
        output_path: Destination JSON Lines file.
        file_type: 'json', 'csv' or 'parquet'.

    Raises:
        ValueError: If ``file_type`` or ``dataset_name`` is not supported.
    """
    from transformers import modeling_utils
    # Provide ALL_PARALLEL_STYLES when the installed transformers module
    # lacks it or has it set to None.
    if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
        modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none","colwise",'rowwise']

    # Fail fast instead of hitting an unbound local variable further down.
    if file_type == 'json':
        data = JsonLoad(data_path)
    elif file_type == 'csv':
        data = CSVLoad(data_path)
    elif file_type == 'parquet':
        data = readGSM8K(data_path)
    else:
        raise ValueError(f"Unsupported file_type: {file_type!r}")

    if ('math' in dataset_name) or ('aime' in dataset_name):
        golden_answer_key = 'answer'
        question_key = 'problem'
    elif 'gsm8k' in dataset_name or 'gpqa' in dataset_name or 'limo' in dataset_name or 'amc' in dataset_name:
        golden_answer_key = 'answer'
        question_key = 'question'
    elif 'reference' in dataset_name:
        golden_answer_key = 'output'
        question_key = 'input'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # Create the directory the output file actually lives in.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # The tokenizer comes from a hard-coded local checkpoint because only the
    # loaded model, not its path, is passed in.
    model_path = '../Qwen2.5-7B-Instruct/'
    tokenizer = AutoTokenizer.from_pretrained(model_path, device_map='auto')

    # Free-form math datasets get the final-answer instruction; every other
    # dataset gets the choice-letter instruction.
    is_free_form = ('math' in dataset_name) or ('gsm8k' in dataset_name) or ('limo' in dataset_name) or ('amc' in dataset_name) or ('aime' in dataset_name)
    if is_free_form:
        instruction = 'Please reason step-by-step and put your final answer within \\boxed{}.'
    else:
        instruction = 'Please reason step-by-step and put your choice letter without any other text with \\boxed{} in the end.'

    with open(output_path, 'w', encoding="utf-8") as f_out:
        for i in range(len(data)):
            golden_answer = data[i][golden_answer_key]
            question = data[i][question_key]
            format_messages = [{"role": "user", "content": question + instruction}]
            messages = tokenizer.apply_chat_template(format_messages, tokenize=False, add_generation_prompt=True)

            tokenized_chat = tokenizer.encode(messages, return_tensors='pt')
            tokenized_chat = tokenized_chat.to(device=f"cuda:{device_map-1}")

            outputs = model.generate(
                tokenized_chat,
                num_beams=16,
                num_return_sequences=1,
                max_new_tokens=4096,
            )

            # num_return_sequences=1 yields exactly one sequence, so decode it
            # directly; the previous inner loop reused the outer index `i`.
            # NOTE(review): the full sequence is decoded, so the stored
            # answer still begins with the prompt text — confirm downstream
            # answer extraction expects that.
            generate_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
            f_out.write(json.dumps({"input": messages, "output": golden_answer, "answer": generate_output}, ensure_ascii=False) + "\n")
   
import numpy as np
def DeepseekMathSpeculativeGenerate(device_map,model,model_2,output_path):
    """Run assisted generation over the DeepMath-103K parquet data.

    Rows are read from the hard-coded path '../KoT/Dataset/DeepMath-103K/'.
    For each row ``model`` generates with ``model_2`` as ``assistant_model``
    and a fixed ``chosen_entropy`` of 1.5. One JSON line per row is written
    to ``output_path`` with the rendered prompt, gold answer, decoded
    continuation, per-step entropies, the row's difficulty and several extra
    fields read from the generate result.

    Args:
        device_map: Integer; input ids are moved to ``cuda:{device_map-1}``.
        model: Already-loaded main model.
        model_2: Already-loaded assistant model.
        output_path: Destination JSON Lines file.
    """
    # Hard-coded local checkpoints; only used to load tokenizers here.
    model_path='../Qwen2.5-72B-Instruct/'
    model_path_2='../Qwen2.5-7B-Instruct/'
    # NOTE(review): my_generate is not referenced by name below — presumably
    # imported for its side effects on generation; confirm.
    from Spec.hf_generation import my_generate
    from torch import nn
    from transformers import modeling_utils
    # Provide ALL_PARALLEL_STYLES when the installed transformers module
    # lacks it or has it set to None.
    if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
        modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none","colwise",'rowwise']

     
    # tokenizer_2 is loaded but not used below.
    tokenizer = AutoTokenizer.from_pretrained(model_path,device_map='auto') 
    tokenizer_2 = AutoTokenizer.from_pretrained(model_path_2,device_map=f"cuda:{device_map-1}") 

    df = pd.read_parquet('../KoT/Dataset/DeepMath-103K/')
    with open(output_path, 'w', encoding="utf-8") as f_out:
        for idx, row in df.iterrows():
            entropy = []
            question = row['question']
            golden_answer = row['final_answer']
            difficult = row['difficulty']

            format_messages = [{"role": "user", "content": question + 'Please reason step-by-step and put your final answer within \\boxed{}.'}]
            messages = tokenizer.apply_chat_template(format_messages,tokenize=False,add_generation_prompt=True)
            tokenized_chat = tokenizer.encode(messages,return_tensors='pt')
            tokenized_chat = tokenized_chat.to(device=f"cuda:{device_map-1}")
            
            # Fixed threshold for this run; also stored in every record.
            chosen_entropy = 1.5
            model_2.generation_config.assistant_confidence_threshold = 0.4
            # NOTE(review): chosen_entropy is presumably consumed by a
            # customised generate implementation — confirm.
            outputs = model.generate(tokenized_chat,max_new_tokens=4096,output_scores=True,return_dict_in_generate = True,output_hidden_states=True,assistant_model=model_2,do_sample=False,chosen_entropy=chosen_entropy) 
            # Decode only the tokens after the prompt.
            output = tokenizer.decode(outputs[0][0][tokenized_chat.shape[-1]:])
            # float() needs a single-element tensor, so each scores entry
            # must reduce to one entropy value.
            for tensor in outputs.scores:
                single_entropy = float(compute_entropy(tensor))
                entropy.append(single_entropy)
            # NOTE(review): the attribute is spelled 'caandidate_list' here —
            # confirm it matches the name set by the code producing `outputs`.
            f_out.write(json.dumps({"input": messages, "output": golden_answer, "answer": output,'reference':outputs.caandidate_list,'valid_result':outputs.valid_result,'matches':outputs.matches,'entropy':entropy,'punish_list':outputs.whole_punish_list,'chosen_entropy':chosen_entropy,'difficulty':difficult},ensure_ascii=False)+ "\n")

            # Drop the reference to the generate result before the next row.
            del outputs



# Script entry point: loads the 72B main model and the 7B assistant model
# once, then runs SpeculativeGenerate over several benchmarks for every
# value listed in `chosen_entropy`.
if __name__ == "__main__":
    # NOTE(review): these options are parsed, but `args` is not read anywhere
    # below; the paths and devices actually used are hard-coded.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--data_path", type=str, default=None)
    parser.add_argument("--output_path", type=str, default=None)
    parser.add_argument("--device_map", type=str, default=None)
    args = parser.parse_args()

    # Set before the models are loaded below.
    os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6'

    model_path = '../Qwen2.5-72B-Instruct/'
    model_path_2='../Qwen2.5-7B-Instruct/'

    # The main model uses device_map='auto'; the assistant is pinned to "cuda:2".
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', torch_dtype=torch.bfloat16)
    model_2 = AutoModelForCausalLM.from_pretrained(model_path_2, device_map=f"cuda:2", torch_dtype=torch.bfloat16)

    # NOTE(review): this list is empty, so the loop below never executes and
    # nothing is generated — fill in the entropy thresholds to sweep.
    chosen_entropy = []
    # Every run passes device_map=1 and writes under ./KoTLlmOutput/GSM8K/,
    # with the threshold value appended to the file name.
    for k in chosen_entropy:
        SpeculativeGenerate(dataset_name='amc',model=model,model_2=model_2,data_path='../KoT/Dataset/AMC/test.jsonl', output_path='./KoTLlmOutput/GSM8K/9.20-AMC-qwen72-overlap'+ str(k) +'.json',device_map=1,reference=False,file_type='json',chosen_entropy=k)
        
        SpeculativeGenerate(dataset_name='aime',model=model,model_2=model_2,data_path='../KoT/Dataset/Aime-2024/test.jsonl', output_path='./KoTLlmOutput/GSM8K/9.20-AIME24-qwen72-overlap'+ str(k) +'.json',device_map=1,reference=False,file_type='json',chosen_entropy=k)

        # The Minerva data is run with dataset_name='amc'.
        SpeculativeGenerate(dataset_name='amc',model=model,model_2=model_2,data_path='../KoT/Dataset/Minerva/test.jsonl', output_path='./KoTLlmOutput/GSM8K/9.20-Minerva-qwen72-overlap'+ str(k) +'.json',device_map=1,reference=False,file_type='json',chosen_entropy=k)

        SpeculativeGenerate(dataset_name='math',model=model,model_2=model_2,data_path='../KoT/Dataset/Math500/test.jsonl', output_path='./KoTLlmOutput/GSM8K/9.20-Math-qwen72-overlap'+ str(k) +'.json',device_map=1,reference=False,file_type='json',chosen_entropy=k)
        
        # The Olympia data is run with dataset_name='amc'.
        SpeculativeGenerate(dataset_name='amc',model=model,model_2=model_2,data_path='../KoT/Dataset/Olympia/test.jsonl', output_path='./KoTLlmOutput/GSM8K/9.20-Olympia-qwen72-overlap'+ str(k) +'.json',device_map=1,reference=False,file_type='json',chosen_entropy=k)

        SpeculativeGenerate(dataset_name='gpqa',model=model,model_2=model_2,data_path='../KoT/Dataset/GPQA/test.jsonl', output_path='./KoTLlmOutput/GSM8K/9.20-GPQA-qwen72-overlap'+ str(k) +'.json',device_map=1,reference=False,file_type='json',chosen_entropy=k)
