"""
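Stage-2 sampling on top of the 32B stage-1 results.

Reads the filtered stage-1 output, samples every problem 8 times with the
configured model, scores the samples, and writes both the full results and
the subset of problems that were not answered correctly in every sample.
(The original header text was corrupted by an encoding error; this summary
is reconstructed from the configuration below.)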
"""
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"]="spawn"
os.environ["TOKENIZERS_PARALLELISM"]="false"
import sys
pwd_path = os.path.dirname(__file__)
print(f"pwd_path: {pwd_path}")
sys.path.insert(0, os.path.join(pwd_path, "../../"))
import jsonlines
import torch
from tqdm import tqdm
from multiprocessing import Process, Manager
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
from concurrent.futures.process import ProcessPoolExecutor


from utils.MATH_evaluator_list import MATHEvaluator

def process_prompt(example):
    """Build the chat prompt for one dataset row.

    Args:
        example: Mapping with a "problem" entry holding the question text.

    Returns:
        A dict with a single "messages" key: a system turn followed by a
        user turn whose content is the problem text.
    """
    system_turn = {
        "role": "system",
        "content": "You are a helpful and harmless assistant. You should think step-by-step.",
    }
    user_turn = {"role": "user", "content": example["problem"]}
    return {"messages": [system_turn, user_turn]}

def preprocess_dataset(fp):
    """Load a JSON/JSONL dataset and return prompts, answers and problems.

    Each row is expected to carry "problem" and "answer" fields; a "messages"
    field is added by mapping `process_prompt` over the rows.

    Args:
        fp: Path (or paths) passed to `load_dataset` as `data_files`.

    Returns:
        A tuple `(messages, answers, problems)` of three lists aligned by
        index, in shuffled order.
    """
    dataset = load_dataset("json", data_files=fp, split="train")
    dataset = dataset.map(process_prompt, num_proc=16, batched=False)
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so
    # the result has to be kept; the original calls discarded it and were
    # therefore no-ops.
    dataset = dataset.shuffle(seed=4037)
    dataset = dataset.shuffle(seed=1)

    # Single pass over the rows keeps the three lists aligned and avoids
    # decoding the whole dataset three times.
    messages, answers, problems = [], [], []
    for row in dataset:
        messages.append(row["messages"])
        answers.append(row["answer"])
        problems.append(row["problem"])
    return messages, answers, problems

def check(evaluator, pred_ans, real_ans):
    """Score predictions against reference answers.

    Delegates to `evaluator.score(pred_ans, real_ans)` and returns its result
    unchanged.
    """
    return evaluator.score(pred_ans, real_ans)

def _select_worker_gpus(proc_id, tensor_parallel):
    """Return the GPU ids (as strings) assigned to worker `proc_id`."""
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if visible is not None:
        available_gpus = [x.strip() for x in visible.split(",") if x.strip() != ""]
    else:
        available_gpus = [str(i) for i in range(torch.cuda.device_count())]
    total_gpus = len(available_gpus)
    # Consecutive slots per worker, wrapping around the available devices.
    return [
        available_gpus[(proc_id * tensor_parallel + i) % total_gpus]
        for i in range(tensor_parallel)
    ]


def _build_llm(config):
    """Construct the vLLM engine; bitsandbytes options are added for "4bit" model paths."""
    llm_kwargs = dict(
        model=config["base_model_name"],
        trust_remote_code=True,
        dtype="auto",
        tensor_parallel_size=config["tensor_parallel"],
        max_model_len=config["max_tokens"],
        max_seq_len_to_capture=config["max_tokens"],
        gpu_memory_utilization=0.9,
        max_num_seqs=32,
    )
    if "4bit" in config["base_model_name"]:
        llm_kwargs.update(quantization="bitsandbytes", load_format="bitsandbytes")
    return LLM(**llm_kwargs)


def worker_process_dynamic(proc_id, task_queue, progress_queue, config):
    """Worker loop: pull batches from a shared queue, sample, score and append results.

    The worker restricts itself to its share of GPUs via CUDA_VISIBLE_DEVICES,
    builds the model, then repeatedly takes a batch from `task_queue` until a
    `get` fails (10 s timeout), which is treated as "no more work".

    Args:
        proc_id: Index of this worker; selects the GPUs and names the output
            file `partial_{proc_id}.jsonl` inside `config["output_dir"]`.
        task_queue: Queue of `(batch_messages, batch_answers, batch_problems)`
            tuples.
        progress_queue: Queue that receives `1` after each finished batch.
        config: Dict providing "tensor_parallel", "base_model_name",
            "max_tokens", "pred_max_tokens", "sample_num" and "output_dir".

    Notes:
        The output file is opened in append mode, so records from earlier runs
        in the same directory are kept. If generation fails for a batch, every
        problem in it is recorded with empty predictions.
    """
    # Pick this worker's GPUs from the globally visible devices and tensor_parallel.
    assigned_gpus = _select_worker_gpus(proc_id, config["tensor_parallel"])
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(assigned_gpus)

    # Initialise the model.
    model = _build_llm(config)

    sampling_params = SamplingParams(
        top_p=0.95,
        temperature=0.6,
        max_tokens=config["pred_max_tokens"],
        n=config["sample_num"]
    )
    evaluator = MATHEvaluator()
    sample_num = config["sample_num"]

    output_path = os.path.join(config["output_dir"], f"partial_{proc_id}.jsonl")
    with jsonlines.open(output_path, "a") as writer:
        while True:
            try:
                # Fetch one batch from the shared queue.
                batch_task = task_queue.get(timeout=10)
            except Exception:
                print(f"{proc_id}")
                break  # A failed/timed-out get is treated as an empty queue.
            batch_messages, batch_answers, batch_problems = batch_task

            # Run generation for the whole batch.
            try:
                responses = model.chat(batch_messages, sampling_params=sampling_params, use_tqdm=False)
            except Exception as err:
                # Best effort: keep the worker alive and record blank predictions,
                # but do not hide what went wrong (and let KeyboardInterrupt /
                # SystemExit propagate, unlike a bare `except:`).
                responses = [None]*len(batch_messages)
                print(f"{proc_id}ѣ{len(batch_messages)}")
                print(f"worker {proc_id}: generation error: {err!r}")

            # Collect the sampled texts for each problem as (problem, answer, preds).
            batch_results = []
            for _, ans, prob, resp in zip(batch_messages, batch_answers, batch_problems, responses):
                try:
                    # resp is None when generation failed; the attribute access
                    # then raises and we fall back to blank predictions.
                    preds = [output.text for output in resp.outputs]
                    if len(preds) != sample_num:
                        # Explicit check instead of `assert`, which is stripped under -O.
                        raise ValueError("unexpected number of samples")
                except Exception:
                    preds = [""] * sample_num
                batch_results.append((prob, ans, preds))

            # Flatten predictions with their matching reference answers so the
            # whole batch is scored in a single evaluator call.
            all_preds = []
            all_ans = []
            for _, ans, preds in batch_results:
                all_preds.extend(preds)
                all_ans.extend([ans] * len(preds))
            try:
                all_correctness = check(evaluator, all_preds, all_ans)
                all_correctness = [1 if x else 0 for x in all_correctness]
            except Exception:
                all_correctness = [0] * len(all_preds)

            # Slice the flat scores back per problem and write each record immediately.
            start_index = 0
            for prob, ans, preds in batch_results:
                num_preds = len(preds)
                sample_correctness = all_correctness[start_index:start_index + num_preds]
                correct_num = sum(sample_correctness)
                writer.write({
                    "problem": prob,
                    "pred_ans_list": preds,
                    "answer": ans,
                    "correct_num": correct_num,
                    "sample_num": sample_num,
                    "correct_ratio": correct_num / sample_num,
                    "pred_model": config["base_model_name"],
                    "max_tokens": config["max_tokens"],
                    "pred_max_tokens": config["pred_max_tokens"],
                    "correctness": sample_correctness,
                })
                start_index += num_preds

            progress_queue.put(1)
    print(f"{proc_id}")

def merge_results(output_dir, final_path, filtered_path):
    """Merge all per-worker partial files and write the full and filtered outputs.

    Args:
        output_dir: Directory scanned for files whose name starts with
            "partial_".
        final_path: Destination JSONL file for every merged record.
        filtered_path: Destination JSONL file for records whose
            "correct_ratio" is below 1.
    """
    all_data = []
    # os.listdir returns names in arbitrary order; sorting makes the merged
    # file reproducible from the same set of partial files.
    for fname in sorted(os.listdir(output_dir)):
        if not fname.startswith("partial_"):
            continue
        with jsonlines.open(os.path.join(output_dir, fname)) as reader:
            partial_data = list(reader)
        all_data.extend(partial_data)
        print(f"{fname}: {len(partial_data)}")

    # Write final output
    with jsonlines.open(final_path, "w") as writer:
        writer.write_all(all_data)
    print(f"all_data: {len(all_data)}")

    # Write filtered output: keep only problems not solved in every sample
    filtered_data = [d for d in all_data if d["correct_ratio"] < 1]
    print(f"filtered_data: {len(filtered_data)}")
    with jsonlines.open(filtered_path, "w") as writer:
        writer.write_all(filtered_data)

if __name__ == "__main__":
    # Driver: load the data, enqueue it in batches, start one worker per group
    # of `tensor_parallel` GPUs, track progress, then merge the partial files.
    import queue  # only the driver needs queue.Empty

    # Configuration parameters
    model_name="qwq_32b"
    model_dict={
        "r1_7b": "/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-7B",
        "r1_1.5b": "/home/cwy/LLM/DeepSeek-R1-Distill-Qwen-1.5B",
        "qwq_32b": "/home/cwy/LLM/QwQ-32B",
        "qwq_32b_4bit": "/home/cwy/LLM/QwQ-32B-unsloth-bnb-4bit",
    }
    # Alternative model locations:
    # model_dict={
    #     "r1_7b": "/data2/cuiwenyao/LLM/DeepSeek-R1-Distill-Qwen-7B",
    #     "r1_1.5b": "/data2/cuiwenyao/LLM/DeepSeek-R1-Distill-Qwen-1.5B",
    #     "qwq_32b": "/data2/cuiwenyao/LLM/QwQ-32B",
    # }
    config = {
        # The input is fixed to the qwq_32b stage-1 filtered file, independent of model_name.
        "input_data_path": os.path.join(pwd_path, "../../data/step1/step1_filtered_init_data-qwq_32b.jsonl"),
        "output_dir": os.path.join(pwd_path, f"../../data/step1/partials-{model_name}-stage2"),
        "final_output_path": os.path.join(pwd_path, f"../../data/step1/step1_init_data-{model_name}-stage2.jsonl"),
        "filtered_output_path": os.path.join(pwd_path, f"../../data/step1/step1_filtered_init_data-{model_name}-stage2.jsonl"),
        "base_model_name": model_dict[model_name],
        "max_tokens": 17*1024,
        "pred_max_tokens": 16*1024,
        "sample_num": 8,
        "max_data_num": -1,  # -1 (any non-positive value) means use all data
        "tensor_parallel": 4,
        # Number of problems per queued task. This is not vLLM's internal batch
        # size, which is governed by max_num_seqs in the worker.
        "batch_size": 32
    }

    # Load the data, optionally truncated to max_data_num items.
    messages, answers, problems = preprocess_dataset(config["input_data_path"])
    if config["max_data_num"] > 0:
        messages = messages[:config["max_data_num"]]
        answers = answers[:config["max_data_num"]]
        problems = problems[:config["max_data_num"]]

    manager = Manager()
    progress_queue = manager.Queue()
    task_queue = manager.Queue()
    task_len = 0
    # Split the data into batches; each task is
    # (batch_messages, batch_answers, batch_problems).
    for i in tqdm(range(0, len(messages), config["batch_size"])):
        batch_messages = messages[i:i+config["batch_size"]]
        batch_answers = answers[i:i+config["batch_size"]]
        batch_problems = problems[i:i+config["batch_size"]]
        task_queue.put((batch_messages, batch_answers, batch_problems))
        task_len += 1
    print(f"task_queue finish: {task_len}")

    num_gpus = torch.cuda.device_count()
    assert num_gpus % config["tensor_parallel"] == 0
    config["num_processes"] = num_gpus // config["tensor_parallel"]
    os.makedirs(config["output_dir"], exist_ok=True)

    processes = []
    for pid in range(config["num_processes"]):
        p = Process(
            target=worker_process_dynamic,
            args=(pid, task_queue, progress_queue, config)
        )
        p.start()
        processes.append(p)

    # Show global progress: each worker reports 1 per finished batch.
    with tqdm(total=task_len, desc="Global Progress") as pbar:
        processed = 0
        while processed < task_len:
            try:
                update = progress_queue.get(timeout=1)
            except queue.Empty:
                if any(p.is_alive() for p in processes):
                    continue
                # No worker is running any more (finished, crashed, or none
                # were started). Collect updates that arrived after the
                # timeout, then stop instead of waiting forever.
                while True:
                    try:
                        update = progress_queue.get_nowait()
                    except queue.Empty:
                        break
                    processed += update
                    pbar.update(update)
                break
            processed += update
            pbar.update(update)

    for p in processes:
        p.join()

    if processed < task_len:
        print(f"warning: only {processed} of {task_len} batches were reported as finished")

    merge_results(config["output_dir"], config["final_output_path"], config["filtered_output_path"])
