import os
# Point Hugging Face downloads at the hf-mirror.com endpoint.
# NOTE(review): presumably this has to be set before `datasets` is imported
# further down so the hub client picks it up — confirm.
os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
# Set both proxy variables to the empty string for this process.
os.environ["http_proxy"]=""
os.environ["https_proxy"]=""
# Alternative kept for reference: send traffic through a local proxy instead.
# os.environ["http_proxy"]="http://127.0.0.1:7890"
# os.environ["https_proxy"]="http://127.0.0.1:7890"
import sys
# Directory containing this script; relative paths below are built from it.
pwd_path=os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
# Put the directory two levels above this script first on the import path.
# NOTE(review): presumably so project-local packages resolve — confirm.
sys.path.insert(0, os.path.join(pwd_path, "../../"))


import jsonlines
from datasets import load_dataset

from random import shuffle


def check(evaluator, pred_ans, real_ans):
    """Score predicted answers against reference answers.

    Args:
        evaluator: Any object exposing ``score(pred_ans, real_ans)``.
        pred_ans: Predicted answers, forwarded unchanged to the evaluator.
        real_ans: Reference answers, forwarded unchanged to the evaluator.

    Returns:
        Whatever ``evaluator.score`` returns for the two arguments.
    """
    return evaluator.score(pred_ans, real_ans)

def unique_list(messages_list, answer_list, problem_list):
    """Drop entries whose problem has already been seen, keeping the first.

    The three lists are walked in lockstep; an entry is kept only the first
    time its problem value appears, and the original order is preserved.

    Args:
        messages_list: Per-sample message payloads.
        answer_list: Per-sample answers.
        problem_list: Per-sample problems (must be hashable); used as the
            deduplication key.

    Returns:
        A tuple ``(messages, answers, problems)`` of new, deduplicated lists.
    """
    # Dicts keep insertion order, and setdefault ignores later duplicates,
    # so each key maps to the first (message, answer) pair seen for it.
    first_by_problem = {}
    for message, answer, problem in zip(messages_list, answer_list, problem_list):
        first_by_problem.setdefault(problem, (message, answer))

    kept_messages = [pair[0] for pair in first_by_problem.values()]
    kept_answers = [pair[1] for pair in first_by_problem.values()]
    kept_problems = list(first_by_problem)
    return kept_messages, kept_answers, kept_problems
    

# Dataset identifiers that are passed to `load_dataset`.
dataset_name_1 = "knoveleng/open-rs"
dataset_name_list = [dataset_name_1]

# Destination JSONL file, located relative to this script's directory.
output_path = os.path.join(
    pwd_path,
    "../../data/step0_init_data_open-rs.jsonl",
)



def process_prompt(example):
    """Build the chat-style message list for one dataset example.

    Args:
        example: Mapping with a ``"problem"`` entry holding the question text.

    Returns:
        A dict with a single ``"messages"`` key: a fixed system turn followed
        by a user turn whose content is ``example["problem"]``.
    """
    system_turn = {
        "role": "system",
        "content": "You are a helpful and harmless assistant. You should think step-by-step.",
    }
    user_turn = {"role": "user", "content": example["problem"]}
    return {"messages": [system_turn, user_turn]}

def preprocess_dataset(dataset_name):
    """Load one dataset's train split and extract its "Hard" examples.

    The split is filtered to rows whose ``"level"`` equals ``"Hard"``, then
    each remaining row gets a ``"messages"`` field via ``process_prompt``.

    Args:
        dataset_name: Dataset identifier handed to ``load_dataset``.

    Returns:
        A tuple ``(messages_list, answer_list, problem_list, level_list)`` of
        parallel lists, one entry per kept row.
    """
    hard_rows = load_dataset(dataset_name, split="train").filter(
        lambda row: row["level"] == "Hard"
    )
    prompted_rows = hard_rows.map(process_prompt, num_proc=64, batched=False)

    messages_list, answer_list, problem_list, level_list = [], [], [], []
    # Single pass over the rows, filling the four parallel lists together.
    for row in prompted_rows:
        messages_list.append(row["messages"])
        answer_list.append(row["answer"])
        problem_list.append(row["problem"])
        level_list.append(row["level"])
    return messages_list, answer_list, problem_list, level_list


# Gather messages/answers/problems from every configured dataset.
messages_list, answer_list, problem_list = [], [], []
for dataset_name in dataset_name_list:
    # The per-row level list is returned as well but is not needed here.
    m, a, p, _levels = preprocess_dataset(dataset_name)
    messages_list.extend(m)
    answer_list.extend(a)
    problem_list.extend(p)

# Keep only the first occurrence of each problem across all datasets.
messages_list, answer_list, problem_list = unique_list(messages_list, answer_list, problem_list)
# NOTE(review): the label in the message below looks mis-encoded (possibly
# GBK text read with the wrong codec); it is left unchanged because it is
# runtime output — confirm the file's intended encoding.
print(f"È¥ÖØºóÑù±¾ÊýÁ¿: {len(messages_list)}=={len(answer_list)}=={len(problem_list)}")

# One record per unique problem; only the problem text and answer are saved.
all_data = [
    {"problem": problem, "answer": answer}
    for problem, answer in zip(problem_list, answer_list)
]

# Unseeded shuffle: the output order differs from run to run.
shuffle(all_data)

# Make sure the destination directory exists. abspath() normalizes the ".."
# segments in output_path, which os.makedirs does not handle reliably.
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

# Use the writer as a context manager so the file is flushed and closed
# deterministically instead of whenever the object is garbage-collected.
with jsonlines.open(output_path, mode="w") as writer:
    writer.write_all(all_data)
print(f"all_data: {len(all_data)}")