import torch
import pickle
# from time import time
import time
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
from vllm.sampling_params import GuidedDecodingParams
from trl.trainer.utils import pad
# Assumes the accuracy_reward function is already defined.
import random
import numpy as np
import os

from open_r1.rewards_gsm import extract_answer_from_model_output, extract_answer_from_dataset, extract_last_number, extract_single_number

from trl.data_utils import maybe_apply_chat_template


from accelerate.utils import set_seed

def _to_float(value):
    """Return ``float(value)``, or ``None`` when *value* cannot be converted."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _answer_matches(candidate, solution):
    """Compare one candidate answer string with the reference solution.

    The checks are tried in order and the first success wins:
    1. exact equality of the two values;
    2. the result of ``extract_single_number(candidate)`` equals
       ``float(solution)``;
    3. ``extract_last_number`` yields the same non-``None`` value for both.
    """
    if candidate == solution:  # Exact match
        return True

    # Single-number matching. A reference that is not float-parsable
    # (e.g. "1,000") no longer raises; it just skips this stage.
    pred_num = extract_single_number(candidate)
    exp_num = _to_float(solution)
    if pred_num is not None and exp_num is not None and pred_num == exp_num:
        return True

    # Last-number matching.
    pred_num = extract_last_number(candidate)
    exp_num = extract_last_number(solution)
    return pred_num is not None and exp_num is not None and pred_num == exp_num


def is_ok(predicted, solution):
    """Decide whether a model completion agrees with the reference answer.

    Args:
        predicted: Raw completion text produced by the model.
        solution: Reference answer to compare against. ``None`` (no
            reference available) always yields ``False``.

    Returns:
        bool: ``True`` when the completion is judged correct.

    When ``extract_answer_from_model_output(predicted)`` returns a value,
    the verdict is based on that extracted answer alone; otherwise the raw
    completion text is compared instead.
    """
    if solution is None:
        # Nothing to compare against; previously this crashed in float().
        return False

    extracted = extract_answer_from_model_output(predicted)
    candidate = predicted if extracted is None else extracted
    return _answer_matches(candidate, solution)
def hyper_parameters():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace: with attributes ``model_dir``, ``seed``,
        ``dataset_dir``, ``batch_size`` and ``output``.
    """
    # (flag, value type, default) for every supported option.
    option_specs = (
        ("--model_dir", str, "./Qwen2.5-0.5B-Instruct"),
        ("--seed", int, 42),
        ("--dataset_dir", str, "./gsm8k"),
        ("--batch_size", int, 8),
        ("--output", str, "./output/Qwen-7B"),
    )

    cli = ArgumentParser(description="test")
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    return cli.parse_args()

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    # Order matches the original: stdlib, NumPy, torch CPU, then torch CUDA
    # (current device, followed by all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

if __name__ == '__main__':
    # Script entry point: sample several completions per question with vLLM,
    # score each against the reference answer, and pickle the per-question
    # fraction of correct completions.
    hps = hyper_parameters()
    print(hps)
    set_seed(hps.seed)
    tokenizer = AutoTokenizer.from_pretrained(hps.model_dir)
    device = "cuda:0"
    llm = LLM(
        model=hps.model_dir,
        device=device,
        gpu_memory_utilization=0.9,
        dtype="auto",
        enable_prefix_caching=False,
        max_model_len=None,
    )

    # Set the generation parameters.
    guided_decoding = GuidedDecodingParams(backend="outlines", regex=None)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1024,
        guided_decoding=guided_decoding,
        n=16,  # generate 16 candidates per prompt
    )
    datas = load_dataset(hps.dataset_dir)['train']
    system_prompt = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>\n<answer> answer here </answer>."

    def make_conversation(example):
        """Wrap one dataset row into a chat-style ``prompt`` message list."""
        prompt = []
        if system_prompt is not None:
            prompt.append({"role": "system", "content": system_prompt})

        prompt.append({"role": "user", "content": example["question"]})
        return {"prompt": prompt}

    datas = datas.map(make_conversation)
    solution = datas['solution']
    question = datas['question']
    # Render every conversation to a plain prompt string with the chat template.
    prompt = [maybe_apply_chat_template(example, tokenizer)['prompt'] for example in datas]
    data = [{"question": q, "solution": s, "prompt": p} for q, s, p in zip(question, solution, prompt)]
    train_dataloader = DataLoader(data, batch_size=hps.batch_size, shuffle=False)

    acc = []  # one entry per question: fraction of sampled completions judged correct
    begin = time.time()
    for batch in tqdm(train_dataloader):
        batch_solutions = batch["solution"]
        batch_prompts = batch["prompt"]
        all_outputs = llm.generate(batch_prompts, sampling_params=sampling_params, use_tqdm=False)

        for request_output, solu in zip(all_outputs, batch_solutions):
            # Decode straight from the token-id lists; copying them to GPU
            # tensors first only added a device round trip.
            token_id_lists = [list(candidate.token_ids) for candidate in request_output.outputs]
            completions = tokenizer.batch_decode(token_id_lists, skip_special_tokens=True)

            # The reference answer is the same for every candidate, so
            # extract it once per question rather than once per completion.
            expected = extract_answer_from_dataset(solu)
            count = sum(1 for text in completions if is_ok(predicted=text, solution=expected))
            acc.append(count / len(completions))
    end = time.time()
    print(end - begin)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(hps.output, exist_ok=True)
    with open(os.path.join(hps.output, "acc.pkl"), "wb") as f:
        pickle.dump(acc, f)

