import os
import sys
# Directory that contains this script, derived from __file__.
pwd_path=os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
# Prepend the directory two levels above this script to the import search
# path. Presumably this is the project root that holds the `utils` package
# imported further down — verify against the repository layout.
sys.path.insert(0, os.path.join(pwd_path, "../../"))


import torch
import random
import jsonlines
from datasets import load_dataset, concatenate_datasets
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

from utils.MATH_evaluator_list import MATHEvaluator



def check(evaluator, pred_ans, real_ans):
    """Score predictions against reference answers with ``evaluator``.

    Args:
        evaluator: Object exposing ``score(pred_ans, real_ans)``.
        pred_ans: Predicted answers, forwarded to ``score`` unchanged.
        real_ans: Reference answers, forwarded to ``score`` unchanged.

    Returns:
        Whatever ``evaluator.score`` returns; this function adds no
        processing of its own.
    """
    return evaluator.score(pred_ans, real_ans)


"""
"/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-1.5B"
"/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-7B"
"/data2/cuiwenyao/LLM/QwQ-32B"
"""
# Input problems and the two result files, all resolved relative to this
# script's directory.
input_data_path=os.path.join(pwd_path, "../../data/step0_init_data.jsonl")
# Receives one record per evaluated problem.
output_data_path=os.path.join(pwd_path, "../../data/step1_init_data.jsonl")
# Receives only the records whose correct_ratio is below 1.
filtered_output_data_path=os.path.join(pwd_path, "../../data/step1_filtered_init_data.jsonl")
# Checkpoint path used for both the tokenizer and the vLLM engine; also stored
# in every output record as "pred_model".
base_model_name="/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-7B"
# Passed to the vLLM engine as max_model_len.
max_tokens=32*1024
# Passed to SamplingParams as max_tokens (per-completion generation cap).
pred_max_tokens=2*1024
# Passed to SamplingParams as n (completions sampled per problem).
sample_num=3
# Keep at most this many problems; a value <= 0 disables the truncation.
max_data_num=1000



def process_prompt(example):
    """Build the two-turn chat prompt for one dataset row.

    Args:
        example: Mapping with a ``"problem"`` key holding the question text.

    Returns:
        ``{"messages": [system_turn, user_turn]}`` where the system turn is a
        fixed instruction and the user turn carries ``example["problem"]``.

    Raises:
        KeyError: If ``example`` has no ``"problem"`` key.
    """
    system_turn = {
        "role": "system",
        "content": "You are a helpful and harmless assistant. You should think step-by-step.",
    }
    user_turn = {"role": "user", "content": example["problem"]}
    return {"messages": [system_turn, user_turn]}

def preprocess_dataset(fp, seed=None):
    """Load a JSON dataset, attach chat prompts, shuffle it, and split columns.

    Args:
        fp: Value forwarded to ``load_dataset`` as ``data_files``.
        seed: Optional seed forwarded to ``Dataset.shuffle``. ``None`` (the
            default) leaves the shuffle order unseeded.

    Returns:
        A tuple ``(messages_list, answer_list, problem_list)`` of three lists
        aligned by index: the chat messages built by ``process_prompt``, and
        each row's ``"answer"`` and ``"problem"`` fields.
    """
    dataset = load_dataset("json", data_files=fp, split="train")
    dataset = dataset.map(process_prompt, num_proc=64, batched=False)
    # Dataset.shuffle() returns a new dataset instead of shuffling in place;
    # the result must be rebound or the rows keep their on-disk order.
    dataset = dataset.shuffle(seed=seed)

    # Walk the rows once rather than once per column, so each row is decoded
    # a single time and the three lists stay aligned.
    messages_list, answer_list, problem_list = [], [], []
    for row in dataset:
        messages_list.append(row["messages"])
        answer_list.append(row["answer"])
        problem_list.append(row["problem"])
    return messages_list, answer_list, problem_list


# Load the chat prompts together with their reference answers and raw problems.
messages_list, answer_list, problem_list = preprocess_dataset(input_data_path)
print(f"样本数量: {len(messages_list)}=={len(answer_list)}=={len(problem_list)}")

# A positive max_data_num caps the number of problems. All three lists are cut
# to the same prefix so they remain aligned by index.
if max_data_num > 0:
    messages_list, answer_list, problem_list = (
        seq[:max_data_num]
        for seq in (messages_list, answer_list, problem_list)
    )


# Tokenizer for the same checkpoint the engine serves. It is not referenced
# again in the visible part of this script.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# vLLM engine used for all generation below.
model = LLM(
    model=base_model_name,
    trust_remote_code=True,
    dtype="auto",
    # Shard the model across every CUDA device torch can see.
    tensor_parallel_size=torch.cuda.device_count(),
    max_model_len=max_tokens,
    # NOTE(review): per the vLLM docs this is CPU swap space in GiB per GPU —
    # confirm the host has enough RAM for 32.
    swap_space=32,
    # NOTE(review): presumably the fraction of GPU memory vLLM may claim — confirm.
    gpu_memory_utilization=0.9,
)
# Sample `sample_num` completions per prompt at temperature 1, each capped at
# `pred_max_tokens` generated tokens.
sampling_params = SamplingParams(
    temperature=1,
    max_tokens=pred_max_tokens,
    n=sample_num,
)
# Scorer whose `score` method is called through `check` for every problem.
evaluator=MATHEvaluator()


# Generate completions for every prompt in a single batched chat call.
responses = model.chat(messages=messages_list, sampling_params=sampling_params)

# all_data keeps one record per problem; all_data_filtered keeps only the
# problems that at least one sampled completion got wrong.
all_data = []
all_data_filtered = []
for problem, response, answer in zip(problem_list, responses, answer_list):
    pred_ans_list = [completion.text for completion in response.outputs]
    # The evaluator is called with parallel lists, so repeat the reference
    # answer once per sampled completion.
    real_ans_list = [answer for _ in pred_ans_list]
    # Normalise the evaluator's verdicts to 1/0 flags.
    correctness_list = [
        1 if verdict == True else 0
        for verdict in check(evaluator, pred_ans_list, real_ans_list)
    ]
    correct_num = correctness_list.count(1)
    correct_ratio = correct_num / len(correctness_list)
    temp_data = dict(
        problem=problem,
        pred_ans_list=pred_ans_list,
        answer=answer,
        correct_num=correct_num,
        sample_num=sample_num,
        correct_ratio=correct_ratio,
        pred_model=base_model_name,
        max_tokens=max_tokens,
        pred_max_tokens=pred_max_tokens,
        correctness=correctness_list,
    )
    all_data.append(temp_data)
    # Anything short of a perfect score also goes into the filtered set.
    if correct_ratio < 1:
        all_data_filtered.append(temp_data)



# Persist the results. Each writer is opened in a `with` block so the file is
# flushed and closed deterministically instead of being left to garbage
# collection, which could drop buffered records if the process exits abnormally.
with jsonlines.open(output_data_path, mode="w") as writer:
    writer.write_all(all_data)
with jsonlines.open(filtered_output_data_path, mode="w") as writer:
    writer.write_all(all_data_filtered)
print(f"all_data: {len(all_data)}")
print(f"all_data_filtered: {len(all_data_filtered)}")