"""Step 0 data preparation: build the initial pool of math problems.

Loads the Hugging Face datasets named in ``dataset_name_list``, drops
duplicate problems, shuffles the ``{"problem", "answer"}`` records, and
writes them as JSON Lines to ``../../data/step0_init_data.jsonl``
(relative to this file).
"""
import os
# Download Hugging Face Hub files through the hf-mirror.com mirror. This is
# deliberately set before `datasets` is imported below; huggingface_hub is
# understood to read HF_ENDPOINT when it is first imported, so keep this
# ordering (NOTE(review): verify against the installed huggingface_hub).
os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
# Blank the lowercase proxy variables; the intent appears to be a direct
# connection to the mirror rather than any proxy inherited from the shell.
os.environ["http_proxy"]=""
os.environ["https_proxy"]=""
# Alternative: route traffic through a local proxy instead.
# os.environ["http_proxy"]="http://127.0.0.1:7890"
# os.environ["https_proxy"]="http://127.0.0.1:7890"
import sys
# Directory containing this script; later paths are built relative to it.
pwd_path=os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
# Put the directory two levels up at the front of the import path
# (presumably the project root, so project-local packages resolve).
sys.path.insert(0, os.path.join(pwd_path, "../../"))


import jsonlines
from datasets import load_dataset

from random import shuffle


def check(evaluator, pred_ans, real_ans):
    """Score a predicted answer against the reference answer.

    Thin wrapper: delegates to ``evaluator.score(pred_ans, real_ans)`` and
    hands back whatever that call returns, unchanged.
    """
    return evaluator.score(pred_ans, real_ans)

def unique_list(messages_list, answer_list, problem_list):
    """Drop rows whose problem text has already been seen.

    The three inputs are parallel lists (row ``i`` is
    ``messages_list[i]``, ``answer_list[i]``, ``problem_list[i]``). Rows are
    walked in order and only the first row for each distinct problem is kept.
    Like ``zip``, iteration stops at the shortest input.

    Returns:
        A tuple ``(messages, answers, problems)`` of new lists, still aligned
        row-by-row.
    """
    kept_rows = []
    seen_problems = set()
    for row in zip(messages_list, answer_list, problem_list):
        problem = row[2]
        if problem in seen_problems:
            continue  # a duplicate of an earlier row: skip it
        seen_problems.add(problem)
        kept_rows.append(row)

    if not kept_rows:
        return [], [], []
    # Transpose the kept rows back into three column lists.
    messages, answers, problems = (list(column) for column in zip(*kept_rows))
    return messages, answers, problems
    


# Local model checkpoints kept for reference (not referenced below):
#   "/data2/cuiwenyao/LLM/DeepSeek-R1-Distill-Qwen-1.5B"
#   "/data2/cuiwenyao/LLM/DeepSeek-R1-Distill-Qwen-7B"
#   "/data2/cuiwenyao/LLM/QwQ-32B"

# Hugging Face Hub datasets to merge, in order; when duplicate problems are
# removed, the copy from the earlier dataset is the one that survives.
dataset_name_1, dataset_name_2 = (
    "SynthLabsAI/Big-Math-RL-Verified",
    "open-r1/OpenR1-Math-220k",
)
dataset_name_list = [dataset_name_1, dataset_name_2]

# Destination JSONL file, resolved relative to this script's directory.
output_path = os.path.join(pwd_path, "../../data/step0_init_data.jsonl")



def process_prompt(example):
    """Wrap one dataset row's problem text into a two-turn chat prompt.

    Args:
        example: a single dataset row (mapping); only ``example["problem"]``
            is read.

    Returns:
        ``{"messages": [system_turn, user_turn]}`` -- the dict that
        ``dataset.map`` writes into each row's ``messages`` column.
    """
    system_turn = dict(
        role="system",
        content="You are a helpful and harmless assistant. You should think step-by-step.",
    )
    user_turn = dict(role="user", content=example["problem"])
    return dict(messages=[system_turn, user_turn])

def preprocess_dataset(dataset_name, split="train", num_proc=64):
    """Load one Hugging Face dataset and extract its messages/answer/problem columns.

    Args:
        dataset_name: Hub repository id passed to ``load_dataset``.
        split: dataset split to load (default ``"train"``, as before).
        num_proc: worker processes used by ``Dataset.map`` (default 64, as
            before); lower it on machines with fewer cores.

    Returns:
        ``(messages_list, answer_list, problem_list)`` -- three parallel lists
        with one entry per row; ``messages`` is built by ``process_prompt``.

    NOTE(review): assumes every dataset in the list has ``problem`` and
    ``answer`` columns -- confirm before adding new datasets.
    """
    dataset = load_dataset(dataset_name, split=split)
    dataset = dataset.map(process_prompt, num_proc=num_proc, batched=False)
    # Read each needed column in a single call instead of iterating over full
    # rows three times: row iteration materialises every column of every row
    # (including large ones this script never uses) on each pass.
    messages_list = list(dataset["messages"])
    answer_list = list(dataset["answer"])
    problem_list = list(dataset["problem"])
    return messages_list, answer_list, problem_list


# Load every configured dataset and concatenate their columns in list order.
messages_list, answer_list, problem_list = [], [], []
for dataset_name in dataset_name_list:
    m, a, p = preprocess_dataset(dataset_name)
    messages_list.extend(m)
    answer_list.extend(a)
    problem_list.extend(p)
# Drop repeated problems (within and across datasets); the first occurrence wins.
messages_list, answer_list, problem_list = unique_list(messages_list, answer_list, problem_list)
# "去重后" = "after deduplication". The previous literal was mojibake (these
# GBK bytes decoded as UTF-8), so the log line printed garbage.
print(f"去重后: {len(messages_list)}=={len(answer_list)}=={len(problem_list)}")


# Final records: only "problem" and "answer" are persisted; messages_list is
# not written out.
all_data = [
    {"problem": problem, "answer": answer}
    for problem, answer in zip(problem_list, answer_list)
]

# NOTE: unseeded shuffle, so row order differs from run to run.
shuffle(all_data)
# Create ../../data if missing (abspath collapses the ".." segments first), and
# close the writer deterministically so the file is flushed instead of relying
# on garbage collection of an unreferenced writer.
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
with jsonlines.open(output_path, mode="w") as writer:
    writer.write_all(all_data)
print(f"all_data: {len(all_data)}")