"""
get_data_mp-qwq_32b_sampling-scaling_online
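Sampling scaling: draw a large number of samples per problem from an
OpenAI-compatible endpoint and record how many of them are correct.
(The original note on this line was damaged by an encoding error; it appears
to say that this is very important for RL training - TODO confirm.)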
"""

import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"]="spawn"
os.environ["TOKENIZERS_PARALLELISM"]="false"
import sys
pwd_path = os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
sys.path.insert(0, os.path.join(pwd_path, "../../"))
import jsonlines
import torch
from tqdm import tqdm
from openai import OpenAI
from multiprocessing import Process, Manager
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
from concurrent.futures.process import ProcessPoolExecutor
from concurrent.futures.thread import ThreadPoolExecutor

from utils.MATH_evaluator_list import MATHEvaluator


def process_prompt(example):
    """Build the chat-style prompt for a single dataset row.

    Args:
        example: Mapping that contains at least a ``"problem"`` entry.

    Returns:
        A dict with a single ``"messages"`` key holding a two-turn
        conversation: a fixed system instruction followed by the problem
        text as the user turn.
    """
    system_turn = {
        "role": "system",
        "content": "You are a helpful and harmless assistant. You should think step-by-step.",
    }
    user_turn = {"role": "user", "content": example["problem"]}
    return {"messages": [system_turn, user_turn]}

def preprocess_dataset(fp):
    """Load a JSON/JSONL dataset and split it into parallel lists.

    Each row gets a ``"messages"`` column via ``process_prompt``; the rows
    are then shuffled twice with fixed seeds (4037, then 1).

    Args:
        fp: Path (or paths) handed to ``load_dataset`` as ``data_files``.

    Returns:
        A 3-tuple ``(messages, answers, problems)`` of equal-length lists,
        taken from the ``"messages"``, ``"answer"`` and ``"problem"`` columns
        of the shuffled dataset.
    """
    dataset = load_dataset("json", data_files=fp, split="train")
    dataset = dataset.map(process_prompt, num_proc=16, batched=False)

    # Two consecutive fixed-seed shuffles, applied in this exact order.
    dataset = dataset.shuffle(seed=4037)
    dataset = dataset.shuffle(seed=1)

    prompt_list, answer_list, problem_list = [], [], []
    for row in dataset:
        prompt_list.append(row["messages"])
        answer_list.append(row["answer"])
        problem_list.append(row["problem"])
    return prompt_list, answer_list, problem_list

def worker_process_dynamic(proc_id, task_queue, progress_queue, config):
    """Worker loop: pull batches, sample completions, score them, append to a shard.

    The worker keeps taking ``(messages, answers, problems)`` batches from
    ``task_queue`` until a ``get`` times out after 10 seconds (the queue is
    treated as exhausted) or the queue becomes unreachable. For every batch
    it sends one chat-completion request per problem, scores all returned
    samples with ``MATHEvaluator``, appends one JSON record per problem to
    ``<output_dir>/partial_<proc_id>.jsonl`` and finally puts ``1`` on
    ``progress_queue``.

    Args:
        proc_id: Identifier used in log messages and in the shard file name.
        task_queue: Queue yielding 3-tuples of equal-length lists
            ``(batch_messages, batch_answers, batch_problems)``.
        progress_queue: Queue that receives ``1`` after each finished batch.
        config: Dict providing ``api_key``, ``base_url``, ``base_model_name``,
            ``output_dir``, ``batch_size``, ``max_tokens`` and a
            ``sampling_params`` dict with ``temperature``, ``top_p``,
            ``max_tokens`` and ``n``.

    Returns:
        None. Results are written to the shard file.
    """
    # Local import so this block does not depend on file-level import changes.
    import queue

    evaluator = MATHEvaluator()
    client = OpenAI(
        api_key=config["api_key"],
        base_url=config["base_url"],
        # Very large timeout, passed through unchanged; presumably meant to
        # never cut off long generations - TODO confirm intent.
        timeout=259200000,
    )
    sampling = config["sampling_params"]
    n_samples = sampling["n"]

    def send_request(messages_entry):
        """Request ``n_samples`` completions for one prompt; empty strings on failure."""
        try:
            resp = client.chat.completions.create(
                model=config["base_model_name"],
                messages=messages_entry,
                temperature=sampling["temperature"],
                top_p=sampling["top_p"],
                max_tokens=sampling["max_tokens"],
                n=n_samples,
            )
            return [c.message.content for c in resp.choices]
        except Exception as e:
            # Best effort: a failed request is recorded as n empty predictions
            # so the problem still gets a row in the output.
            print(f"Worker {proc_id} API error: {e}")
            return [""] * n_samples

    output_path = os.path.join(config["output_dir"], f"partial_{proc_id}.jsonl")
    with jsonlines.open(output_path, "a") as writer:
        while True:
            try:
                batch_messages, batch_answers, batch_problems = task_queue.get(timeout=10)
            except queue.Empty:
                # No task arrived within the timeout: treat the queue as drained.
                print(f"Worker {proc_id} finished.")
                break
            except (EOFError, OSError) as e:
                # The queue is no longer reachable (e.g. its manager went
                # away); stop cleanly instead of looping on a dead connection.
                print(f"Worker {proc_id} lost the task queue: {e}")
                break

            # One request per problem, issued concurrently within the batch.
            # ``send_request`` handles its own errors, so ``map`` yields one
            # prediction list per prompt, in input order.
            with ThreadPoolExecutor(max_workers=config["batch_size"]) as exe:
                responses = list(exe.map(send_request, batch_messages))

            # Flatten predictions so the evaluator scores the batch in one call.
            all_preds = []
            all_ans = []
            for ans, preds in zip(batch_answers, responses):
                all_preds.extend(preds)
                all_ans.extend([ans] * len(preds))

            try:
                correctness = evaluator.score(all_preds, all_ans)
                correctness = [1 if ok else 0 for ok in correctness]
            except Exception as e:
                # The exception must be bound here; without ``as e`` the log
                # line itself raised NameError and killed the worker.
                print(f"Worker {proc_id} evaluator error: {e}")
                correctness = [0] * len(all_preds)

            # Slice the flat correctness list back into per-problem chunks.
            idx = 0
            for prob, ans, preds in zip(batch_problems, batch_answers, responses):
                n = len(preds)
                corr_slice = correctness[idx:idx + n]
                correct_num = sum(corr_slice)
                writer.write({
                    "problem": prob,
                    "pred_ans_list": preds,
                    "answer": ans,
                    "correct_num": correct_num,
                    "sample_num": n_samples,
                    "correct_ratio": correct_num / n_samples,
                    "pred_model": config["base_model_name"],
                    "max_tokens": config["max_tokens"],
                    "pred_max_tokens": sampling["max_tokens"],
                    "correctness": corr_slice,
                })
                idx += n

            progress_queue.put(1)

def merge_results(output_dir, final_path, filtered_path):
    """Merge all ``partial_*`` shards into one file plus a filtered copy.

    Records are streamed one at a time, so memory use does not grow with the
    number or size of records. Shards are processed in sorted file-name order
    so the merged output is the same from run to run.

    Args:
        output_dir: Directory containing the ``partial_*`` JSONL shards.
        final_path: Destination for every record from every shard
            (overwritten).
        filtered_path: Destination for the records whose ``"correct_ratio"``
            is strictly below 1 (overwritten).

    Returns:
        None.
    """
    # List the shards before opening any output file, and sort them because
    # ``os.listdir`` returns entries in arbitrary order.
    shard_names = sorted(
        fname for fname in os.listdir(output_dir) if fname.startswith("partial_")
    )

    with jsonlines.open(final_path, "w") as final_writer:
        with jsonlines.open(filtered_path, "w") as filtered_writer:
            for fname in shard_names:
                with jsonlines.open(os.path.join(output_dir, fname)) as reader:
                    for record in reader:
                        final_writer.write(record)
                        # Keep only problems that were not solved by every sample.
                        if record["correct_ratio"] < 1:
                            filtered_writer.write(record)

if __name__ == "__main__":
    # Script entry point: load the dataset, fan batches out to worker
    # processes, track progress, then merge the per-worker shards.
    import queue  # stdlib; provides the ``queue.Empty`` raised by timed-out gets

    pwd_path = os.path.dirname(__file__)
    config = {
        "input_data_path":    os.path.join(pwd_path, "../../data/step0_init_data_open-rs.jsonl"),
        "output_dir":         os.path.join(pwd_path, "../../data/step1/partials-r1-1.5b-sampling-scaling"),
        "final_output_path":  os.path.join(pwd_path, "../../data/step1/step1_init_data-r1-1.5b-sampling-scaling.jsonl"),
        "filtered_output_path": os.path.join(pwd_path, "../../data/step1/step1_filtered_init_data-r1-1.5b-sampling-scaling.jsonl"),
        "api_key":            os.environ.get("OPENAI_API_KEY", "asd"),
        "base_url":           "http://127.0.0.1:41320/v1",
        "base_model_name":    "r1-1.5b",
        "max_tokens":         32*1024,
        "batch_size":         1,
        "sampling_params": {
            "top_p":      1,
            "temperature": 1,
            "max_tokens":32*1024,
            "n":          4096
        }
    }

    # Load and preprocess the input problems.
    messages, answers, problems = preprocess_dataset(config["input_data_path"])
    os.makedirs(config["output_dir"], exist_ok=True)

    # Enqueue every batch before any worker starts, so a worker can treat a
    # timed-out ``get`` as "no work left".
    manager = Manager()
    task_q     = manager.Queue()
    progress_q = manager.Queue()
    batch_size = config["batch_size"]
    for i in range(0, len(messages), batch_size):
        task_q.put((messages[i:i + batch_size],
                    answers[i:i + batch_size],
                    problems[i:i + batch_size]))

    # Start the worker processes. The original (encoding-damaged) comment
    # appears to suggest that a larger value improves GPU utilisation -
    # TODO confirm.
    num_procs = 6
    procs = []
    for pid in range(num_procs):
        p = Process(target=worker_process_dynamic, args=(pid, task_q, progress_q, config))
        p.start()
        procs.append(p)

    # Global progress bar: one tick per finished batch.
    total_tasks = (len(messages) + batch_size - 1) // batch_size
    with tqdm(total=total_tasks, desc="Global Progress") as pbar:
        done = 0
        while done < total_tasks:
            try:
                done += progress_q.get(timeout=1)
                pbar.update(1)
            except queue.Empty:
                if any(p.is_alive() for p in procs):
                    continue
                # Every worker has exited. Once the queue is drained no more
                # progress can arrive, so stop instead of spinning forever.
                if progress_q.empty():
                    break

    if done < total_tasks:
        print(f"Warning: only {done}/{total_tasks} batches completed before all workers exited.")

    for p in procs:
        p.join()

    # Merge the per-worker shards into the final and filtered outputs.
    merge_results(config["output_dir"],
                  config["final_output_path"],
                  config["filtered_output_path"])
