﻿import os
import sys
pwd_path = os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
sys.path.insert(0, os.path.join(pwd_path, "../../"))
import jsonlines
import torch
from multiprocessing import Process
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
from utils.MATH_evaluator_list import MATHEvaluator

# Turn off the HF `tokenizers` library's internal parallelism via its
# TOKENIZERS_PARALLELISM switch; presumably to avoid fork-related warnings or
# deadlocks since this module also imports `multiprocessing` (TODO confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")

def process_prompt(example):
    """Wrap a row's ``problem`` text in a system + user chat message list.

    Args:
        example: a mapping that has at least a ``"problem"`` key.

    Returns:
        ``{"messages": [system_turn, user_turn]}``, where the user turn carries
        ``example["problem"]`` verbatim.
    """
    system_turn = dict(
        role="system",
        content="You are a helpful and harmless assistant. You should think step-by-step.",
    )
    user_turn = dict(role="user", content=example["problem"])
    return dict(messages=[system_turn, user_turn])

def preprocess_dataset(fp, seed=None):
    """Load a JSON/JSONL problem file and build chat prompts for every row.

    Args:
        fp: value passed as ``data_files`` to ``load_dataset("json", ...)``;
            every record must provide ``"problem"`` and ``"answer"`` fields.
        seed: optional seed for the shuffle. ``None`` (the default) yields a
            different order on each call.

    Returns:
        Three parallel lists ``(messages, answers, problems)`` sharing the same
        shuffled row order: ``messages[i]`` is the chat prompt built by
        ``process_prompt`` for ``problems[i]`` and ``answers[i]`` is that row's
        ``"answer"`` field.
    """
    dataset = load_dataset("json", data_files=fp, split="train")
    dataset = dataset.map(process_prompt, num_proc=1, batched=False)
    # Dataset.shuffle() returns a NEW dataset instead of shuffling in place;
    # the result used to be discarded, so no shuffling ever took effect.
    dataset = dataset.shuffle(seed=seed)
    messages, answers, problems = [], [], []
    # One pass over the rows keeps the three lists aligned and avoids
    # decoding the whole dataset three times.
    for row in dataset:
        messages.append(row["messages"])
        answers.append(row["answer"])
        problems.append(row["problem"])
    return messages, answers, problems

def check(evaluator, pred_ans, real_ans):
    """Score a predicted answer against the reference answer.

    Thin pass-through: calls ``evaluator.score(pred_ans, real_ans)`` and hands
    back whatever that call returns, unchanged.
    """
    return evaluator.score(pred_ans, real_ans)


def merge_results(output_dir, final_path, filtered_path):
    """Concatenate ``partial_*`` JSONL shards and write a full and a filtered file.

    Args:
        output_dir: directory holding the shard files. Only regular files whose
            name starts with ``"partial_"`` are read, in lexicographic name order.
        final_path: JSONL path that receives every merged record.
        filtered_path: JSONL path that receives only the records whose
            ``"correct_ratio"`` is below 1.

    Raises:
        KeyError: if a merged record lacks a ``"correct_ratio"`` field (raised
            after ``final_path`` has already been written).
    """
    all_data = []
    # os.listdir() returns entries in arbitrary, filesystem-dependent order;
    # sorting makes the merged record order reproducible across runs.
    for fname in sorted(os.listdir(output_dir)):
        shard_path = os.path.join(output_dir, fname)
        # Skip non-shard entries and directories that merely match the prefix.
        if not fname.startswith("partial_") or not os.path.isfile(shard_path):
            continue
        with jsonlines.open(shard_path) as reader:
            partial_data = list(reader)
        all_data.extend(partial_data)
        print(f"{fname}: {len(partial_data)}")

    # Write final output
    with jsonlines.open(final_path, "w") as writer:
        writer.write_all(all_data)
    print(f"all_data: {len(all_data)}")

    # Write filtered output: records with correct_ratio < 1 (presumably problems
    # that were not answered correctly in every sample — confirm upstream).
    filtered_data = [d for d in all_data if d["correct_ratio"] < 1]
    print(f"filtered_data: {len(filtered_data)}")
    with jsonlines.open(filtered_path, "w") as writer:
        writer.write_all(filtered_data)

if __name__ == "__main__":
    # Configuration parameters.
    # Name of the sampling run whose partial shards are merged; all three paths
    # are derived from it so they cannot drift apart.
    run_name = "partials-r1-1.5b-sampling-scaling"
    step1_dir = os.path.join(pwd_path, "../../data/step1")
    config = {
        "output_dir": os.path.join(step1_dir, run_name),
        "final_output_path": os.path.join(step1_dir, f"{run_name}.jsonl"),
        "filtered_output_path": os.path.join(step1_dir, f"{run_name}_filtered.jsonl"),
    }
    merge_results(
        config["output_dir"],
        config["final_output_path"],
        config["filtered_output_path"],
    )
