"""Step 1: sample solutions for each problem with vLLM and score them.

Splits the problems from ``step0_init_data.jsonl`` into batches, runs one
worker process per ``tensor_parallel`` GPUs, samples ``sample_num``
completions per problem, scores them with ``MATHEvaluator`` and merges the
per-worker results into a full output file plus a filtered file that keeps
only problems with ``correct_ratio < 1``.
"""
import os
# These environment variables are set before vllm / transformers are imported
# below — presumably so those libraries see them at import time (confirm).
os.environ["VLLM_WORKER_MULTIPROC_METHOD"]="spawn"
os.environ["TOKENIZERS_PARALLELISM"]="false"
import sys
# Put the directory two levels above this file on sys.path so that
# `utils.MATH_evaluator_list` (imported below) can be resolved.
pwd_path = os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
sys.path.insert(0, os.path.join(pwd_path, "../../"))
import jsonlines
import torch
from tqdm import tqdm
from multiprocessing import Process, Manager
# NOTE(review): AutoTokenizer is not referenced anywhere in this file's code.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset

from utils.MATH_evaluator_list import MATHEvaluator

def process_prompt(example):
    """Build the chat-format prompt for a single dataset row.

    Args:
        example: A row mapping with a ``"problem"`` key holding the question.

    Returns:
        ``{"messages": [system_message, user_message]}``, the shape consumed by
        ``Dataset.map`` (adds a ``messages`` column) and later by ``LLM.chat``.
    """
    system_message = {
        "role": "system",
        "content": "You are a helpful and harmless assistant. You should think step-by-step.",
    }
    user_message = {"role": "user", "content": example["problem"]}
    return dict(messages=[system_message, user_message])

def preprocess_dataset(fp, seed=None):
    """Load a JSON-lines file and turn it into parallel prompt/answer/problem lists.

    Args:
        fp: Path(s) passed to ``load_dataset("json", data_files=...)``; each row
            must have ``"problem"`` and ``"answer"`` fields.
        seed: If not ``None``, shuffle the rows with this seed before returning.
            The default ``None`` keeps file order, which is what the previous
            implementation effectively did (its shuffle calls were no-ops).

    Returns:
        ``(messages, answers, problems)``: three equally long lists where
        ``messages[i]`` is the chat prompt built by ``process_prompt`` for
        ``problems[i]`` and ``answers[i]`` is its reference answer.
    """
    dataset = load_dataset("json", data_files=fp, split="train")
    dataset = dataset.map(process_prompt, num_proc=16, batched=False)
    # ``Dataset.shuffle`` returns a NEW dataset rather than shuffling in place,
    # so its result must be assigned; un-assigned calls silently do nothing.
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    # Return messages, answers, problems as parallel lists.
    return (
        [row["messages"] for row in dataset],
        [row["answer"] for row in dataset],
        [row["problem"] for row in dataset],
    )

def check(evaluator, pred_ans, real_ans):
    """Score predictions against reference answers with ``evaluator``.

    Thin wrapper: forwards both lists to ``evaluator.score`` and returns its
    result unchanged. The worker treats that result as one truthy/falsy value
    per prediction.
    """
    return evaluator.score(pred_ans, real_ans)

def _assign_gpus(proc_id, tensor_parallel):
    """Return the GPU ids (as strings) that worker ``proc_id`` should use.

    The pool comes from the inherited ``CUDA_VISIBLE_DEVICES`` when it is set,
    otherwise ``0 .. torch.cuda.device_count() - 1``. Worker ``proc_id`` takes
    ``tensor_parallel`` consecutive entries starting at
    ``proc_id * tensor_parallel``, wrapping around the pool if it is smaller
    than ``num_workers * tensor_parallel``.

    Raises:
        RuntimeError: If no GPU is visible (instead of a ZeroDivisionError
            from the modulo below).
    """
    global_cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if global_cuda_visible is not None:
        available_gpus = [x.strip() for x in global_cuda_visible.split(",") if x.strip()]
    else:
        available_gpus = [str(i) for i in range(torch.cuda.device_count())]
    if not available_gpus:
        raise RuntimeError(f"worker {proc_id}: no visible GPUs to assign")
    start = proc_id * tensor_parallel
    return [available_gpus[(start + i) % len(available_gpus)] for i in range(tensor_parallel)]


def worker_process_dynamic(proc_id, task_queue, progress_queue, config):
    """Run one generation worker until the shared task queue is drained.

    The worker pins itself to its share of the GPUs and loads
    ``config["base_model_name"]`` with vLLM. It then repeatedly takes a
    ``(batch_messages, batch_answers, batch_problems)`` task from
    ``task_queue``, samples ``config["sample_num"]`` completions per prompt,
    scores them with ``MATHEvaluator`` and appends one JSON record per problem
    to ``<output_dir>/partial_<proc_id>.jsonl``.

    Args:
        proc_id: Worker index; selects the GPUs and the partial-file name.
        task_queue: Shared queue of batch tuples. The parent fills it before
            starting workers, so a 10-second ``get`` timeout means "drained".
        progress_queue: Receives ``1`` after each batch is handled (written or
            discarded) so the parent can track progress.
        config: Run settings; reads ``tensor_parallel``, ``base_model_name``,
            ``max_tokens``, ``pred_max_tokens``, ``sample_num``, ``output_dir``.
    """
    import queue  # only needed for the queue.Empty exception type

    # Pick this process's GPUs from the global CUDA_VISIBLE_DEVICES and
    # tensor_parallel, and expose only those before the model is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(_assign_gpus(proc_id, config["tensor_parallel"]))

    # Initialize the model.
    model = LLM(
        model=config["base_model_name"],
        trust_remote_code=True,
        dtype="auto",
        tensor_parallel_size=config["tensor_parallel"],
        max_model_len=config["max_tokens"],
        gpu_memory_utilization=0.95,
        max_num_seqs=64,
    )
    sample_num = config["sample_num"]
    sampling_params = SamplingParams(
        temperature=1,
        max_tokens=config["pred_max_tokens"],
        n=sample_num,
    )
    evaluator = MATHEvaluator()

    output_path = os.path.join(config["output_dir"], f"partial_{proc_id}.jsonl")
    with jsonlines.open(output_path, "a") as writer:
        while True:
            try:
                # Take one batch task from the shared queue.
                batch_task = task_queue.get(timeout=10)
            except queue.Empty:
                break  # timed out: the queue is drained, stop working
            batch_messages, batch_answers, batch_problems = batch_task

            # Generate completions for the whole batch in one vLLM call.
            try:
                responses = model.chat(batch_messages, sampling_params=sampling_params, use_tqdm=False)
            except Exception as e:  # not bare: let KeyboardInterrupt/SystemExit through
                print(f"ÈÎÎñ{proc_id}Óöµ½À§ÄÑ£¬ÉáÆú{len(batch_messages)}ÌõÊý¾Ý")
                print(f"worker {proc_id}: model.chat failed: {e!r}")
                # Really discard the batch, as the message above says. Writing
                # all-zero records instead would make generation failures look
                # like unsolved problems and leak them into the filtered output.
                responses = []

            # Score each problem's samples and write one record per problem.
            for ans, prob, resp in zip(batch_answers, batch_problems, responses):
                try:
                    preds = [output.text for output in resp.outputs]
                    if len(preds) != sample_num:
                        raise ValueError(f"expected {sample_num} samples, got {len(preds)}")
                    correctness = [1 if x else 0 for x in check(evaluator, preds, [ans] * len(preds))]
                    if len(correctness) != sample_num:
                        raise ValueError(f"expected {sample_num} scores, got {len(correctness)}")
                except Exception as e:
                    # Best-effort: record this problem as all-wrong rather than
                    # losing the rest of the batch, but say so.
                    print(f"worker {proc_id}: scoring failed for one problem: {e!r}")
                    preds = [""] * sample_num
                    correctness = [0] * sample_num
                correct_num = sum(correctness)
                writer.write({
                    "problem": prob,
                    "pred_ans_list": preds,
                    "answer": ans,
                    "correct_num": correct_num,
                    "sample_num": sample_num,
                    "correct_ratio": correct_num / sample_num,
                    "pred_model": config["base_model_name"],
                    "max_tokens": config["max_tokens"],
                    "pred_max_tokens": config["pred_max_tokens"],
                    "correctness": correctness,
                })
            progress_queue.put(1)
    # Printed once on exit (previously printed twice: before the break and here).
    print(f"ÈÎÎñ{proc_id}½áÊø")

def merge_results(output_dir, final_path, filtered_path):
    """Concatenate all worker partial files and write full + filtered outputs.

    Reads every file in ``output_dir`` whose name starts with ``partial_``,
    in sorted filename order so the merged output is reproducible
    (``os.listdir`` order is arbitrary). All records go to ``final_path``.
    Only records with ``correct_ratio < 1`` go to ``filtered_path``; these are
    problems where at least one sample was scored wrong.
    """
    all_data = []
    for fname in sorted(os.listdir(output_dir)):
        if not fname.startswith("partial_"):
            continue
        with jsonlines.open(os.path.join(output_dir, fname)) as reader:
            partial_data = list(reader)
        all_data.extend(partial_data)
        print(f"{fname}: {len(partial_data)}")

    # Write final output.
    with jsonlines.open(final_path, "w") as writer:
        writer.write_all(all_data)
    print(f"all_data: {len(all_data)}")

    # Write filtered output: keep problems the model did not always solve.
    filtered_data = [d for d in all_data if d["correct_ratio"] < 1]
    print(f"filtered_data: {len(filtered_data)}")
    with jsonlines.open(filtered_path, "w") as writer:
        writer.write_all(filtered_data)

if __name__ == "__main__":
    import glob
    import queue

    model_name = "r1_1.5b"
    # Alternative model paths (another machine's layout); uncomment to use.
    # model_dict={
    #     "r1_7b": "/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-7B",
    #     "r1_1.5b": "/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-1.5B",
    #     "qwq_32b": "/home/cwy/LLM/QwQ-32B",
    # }
    model_dict = {
        "r1_7b": "/data2/cuiwenyao/LLM/DeepSeek-R1-Distill-Qwen-7B",
        "r1_1.5b": "/data2/cuiwenyao/LLM/DeepSeek-R1-Distill-Qwen-1.5B",
        "qwq_32b": "/data2/cuiwenyao/LLM/QwQ-32B",
    }
    config = {
        "input_data_path": os.path.join(pwd_path, "../../data/step0_init_data.jsonl"),
        "output_dir": os.path.join(pwd_path, f"../../data/step1/partials-{model_name}"),
        "final_output_path": os.path.join(pwd_path, f"../../data/step1/step1_init_data-{model_name}.jsonl"),
        "filtered_output_path": os.path.join(pwd_path, f"../../data/step1/step1_filtered_init_data-{model_name}.jsonl"),
        "base_model_name": model_dict[model_name],
        "max_tokens": 32 * 1024,  # passed to vLLM as max_model_len
        "pred_max_tokens": 16 * 1024,  # SamplingParams.max_tokens per sample
        "sample_num": 1,  # completions sampled per problem
        "max_data_num": -1,  # -1 means use all data
        "tensor_parallel": 1,  # GPUs per worker process
        # Problems per queued task; tune to the workload. This is NOT the batch
        # size vLLM actually runs — that is capped by max_num_seqs in the worker.
        "batch_size": 512,
    }

    # Load data.
    messages, answers, problems = preprocess_dataset(config["input_data_path"])
    if config["max_data_num"] > 0:
        messages = messages[:config["max_data_num"]]
        answers = answers[:config["max_data_num"]]
        problems = problems[:config["max_data_num"]]

    # Context manager shuts the Manager's server process down when done.
    with Manager() as manager:
        progress_queue = manager.Queue()
        task_queue = manager.Queue()

        # Split the data into batches and enqueue them; each task is
        # (batch_messages, batch_answers, batch_problems).
        batch_size = config["batch_size"]
        num_batches = 0
        for i in tqdm(range(0, len(messages), batch_size)):
            task_queue.put((
                messages[i:i + batch_size],
                answers[i:i + batch_size],
                problems[i:i + batch_size],
            ))
            num_batches += 1
        print(f"task_queue finish: {num_batches}")

        num_gpus = torch.cuda.device_count()
        config["num_processes"] = num_gpus // config["tensor_parallel"]
        if config["num_processes"] < 1:
            # With zero workers nothing would ever report progress.
            raise SystemExit(
                f"Need at least tensor_parallel={config['tensor_parallel']} GPUs, found {num_gpus}."
            )
        os.makedirs(config["output_dir"], exist_ok=True)
        # Workers APPEND to partial_<pid>.jsonl and merge_results() reads every
        # partial_* file, so leftovers from an earlier run would be merged
        # again (duplicated records, or stale files from extra workers).
        for stale_path in glob.glob(os.path.join(config["output_dir"], "partial_*")):
            os.remove(stale_path)

        processes = []
        for pid in range(config["num_processes"]):
            p = Process(
                target=worker_process_dynamic,
                args=(pid, task_queue, progress_queue, config),
            )
            p.start()
            processes.append(p)

        # Progress display: each worker reports 1 per handled batch.
        with tqdm(total=num_batches, desc="Global Progress") as pbar:
            processed = 0
            while processed < num_batches:
                # Sample liveness BEFORE waiting: a worker that has exited has
                # already completed all its (synchronous) puts, so a timeout
                # afterwards means no more progress can ever arrive.
                any_alive = any(p.is_alive() for p in processes)
                try:
                    update = progress_queue.get(timeout=1)
                except queue.Empty:
                    if not any_alive:
                        break  # every worker is gone (e.g. crashed): don't spin forever
                    continue
                processed += update
                pbar.update(update)

        for p in processes:
            p.join()
        crashed = [(pid, p.exitcode) for pid, p in enumerate(processes) if p.exitcode != 0]
        if crashed:
            print(f"WARNING: workers exited abnormally (proc_id, exitcode): {crashed}")
        if processed < num_batches:
            print(f"WARNING: only {processed}/{num_batches} batches reported done; results are incomplete.")

    merge_results(config["output_dir"], config["final_output_path"], config["filtered_output_path"])
