import os
import argparse
import json
import random
import multiprocessing as mp
from tqdm import tqdm
import pandas as pd
import re
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Seed the module-level random generator so runs are repeatable.
random.seed(42)

# GPQA prompt template: the question, four lettered options (A-D) and the
# answer-format instruction. The doubled braces render as a literal "{}"
# after str.format().
QUERY_TEMPLATE = """
{question}

A: {A}
B: {B}
C: {C}
D: {D}

Please reason step by step, and put your final answer within \\boxed{{}}.
Please only provide the letter of the answer in the box.
""".strip()

# Captures the answer letter from a "\boxed{X}" span. Only a single capital
# letter A-D placed directly inside the braces is matched.
BOXED_ANSWER_RE = re.compile(r"\\boxed\{([A-D])\}")

def build_prompt(example, tokenizer, system_prompt: str):
    """Render one GPQA example as a chat-formatted prompt string.

    Args:
        example: Mapping with a "question" entry and a "choices" sequence
            whose first four items become options A-D.
        tokenizer: Object exposing ``apply_chat_template``.
        system_prompt: "qwen" puts the reasoning instruction in a system
            message, "deepseek" appends it to the user message, and any
            other value sends the bare user message.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the built
        conversation with ``tokenize=False`` and
        ``add_generation_prompt=True``.
    """
    question_text = example["question"]
    options = example["choices"]
    lettered = {letter: options[idx] for idx, letter in enumerate("ABCD")}
    user_text = QUERY_TEMPLATE.format(question=question_text, **lettered)

    instruction = "Please reason step by step, and put your final answer within \\boxed{}."

    if system_prompt == "deepseek":
        conversation = [
            {"role": "user", "content": user_text + "\n" + instruction},
        ]
    elif system_prompt == "qwen":
        conversation = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_text},
        ]
    else:
        # "none": no extra instruction beyond the template itself.
        conversation = [{"role": "user", "content": user_text}]

    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def worker_process(proc_id, task_queue, progress_queue, args, output_dir):
    """Run one inference worker: pull batches, generate, write partial results.

    The worker restricts itself to ``args.tensor_parallel`` GPUs chosen by
    ``proc_id``, loads the tokenizer and vLLM model, then consumes batches
    from ``task_queue`` until no batch arrives within 10 seconds. One JSON
    line per example is written to ``partial_<proc_id>.jsonl`` in
    ``output_dir``.

    Args:
        proc_id: Zero-based worker index; selects the GPU slice and the
            partial file name.
        task_queue: Queue of batches, each a list of example dicts with
            "question", "choices" and "label" keys.
        progress_queue: Queue that receives ``1`` after each finished batch.
        args: Parsed CLI namespace; reads ``model_name_or_path``,
            ``tensor_parallel``, ``max_tokens``, ``decode`` and
            ``system_prompt``.
        output_dir: Directory for the partial result file.

    Raises:
        RuntimeError: If no GPU is visible to this process.
    """
    # Local import keeps this block self-contained; only queue.Empty is needed.
    import queue

    # Pick this worker's GPUs from the inherited device list (or all devices).
    inherited = os.environ.get("CUDA_VISIBLE_DEVICES")
    if inherited:
        gpus = [g.strip() for g in inherited.split(",") if g.strip()]
    else:
        gpus = [str(i) for i in range(torch.cuda.device_count())]
    if not gpus:
        raise RuntimeError(f"[Process {proc_id}] No visible GPUs to assign.")
    tp = args.tensor_parallel
    assigned = [gpus[(proc_id * tp + i) % len(gpus)] for i in range(tp)]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(assigned)
    print(f"[Process {proc_id}] Using GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")

    # Load the tokenizer and the model.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = LLM(
        model=args.model_name_or_path,
        trust_remote_code=True,
        dtype="bfloat16",
        tensor_parallel_size=args.tensor_parallel,
        max_model_len=args.max_tokens + 1024,
        max_seq_len_to_capture=args.max_tokens,
        gpu_memory_utilization=0.9,
        max_num_seqs=256,
    )
    stop_words = ["<|im_end|>", "<|endoftext|>"]
    if args.decode == "sample":
        sampling_params = SamplingParams(
            top_p=0.95,
            temperature=0.6,
            max_tokens=args.max_tokens,
            stop=stop_words,
            n=1,
        )
    else:  # greedy
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=args.max_tokens,
            stop=stop_words,
            n=1,
        )

    # Open this worker's partial output file; `with` guarantees it is closed
    # even when generation raises.
    partial_path = os.path.join(output_dir, f"partial_{proc_id}.jsonl")
    with open(partial_path, "w", encoding="utf-8") as fout:
        while True:
            try:
                batch = task_queue.get(timeout=10)
            except queue.Empty:
                # No work arrived within the timeout: treat the queue as drained.
                break
            except (EOFError, ConnectionError):
                # The queue's manager is gone; nothing more can be fetched.
                print(f"[Process {proc_id}] Task queue unavailable, stopping.")
                break

            prompts = [build_prompt(ex, tokenizer, args.system_prompt) for ex in batch]
            responses = model.generate(prompts, sampling_params, use_tqdm=False)

            for ex, resp in zip(batch, responses):
                completion = resp.outputs[0]
                match = BOXED_ANSWER_RE.search(completion.text)
                record = {
                    "question": ex["question"],
                    "choices": ex["choices"],
                    "label": ex["label"],
                    "prediction": match.group(1) if match else None,
                    "stop_reason": completion.stop_reason,
                }
                fout.write(json.dumps(record, ensure_ascii=False) + "\n")

            # Persist the batch before reporting it, and report once per
            # batch: the parent's progress bar counts batches, not examples.
            fout.flush()
            progress_queue.put(1)

def merge_results(output_dir, final_output_path):
    """Concatenate all ``partial_*.jsonl`` files into one JSONL file.

    Partial files are read in ascending numeric order of the index in their
    name (non-numeric names come last, alphabetically), so the merged output
    is deterministic instead of following directory-listing order. Blank
    lines are skipped.

    Args:
        output_dir: Directory containing the partial files.
        final_output_path: Path of the merged JSONL file to write.

    Returns:
        List of all decoded result records, in the order they were written.
    """
    prefix, suffix = "partial_", ".jsonl"

    def _order(fname):
        # Numeric ordering so partial_10 sorts after partial_2.
        stem = fname[len(prefix):-len(suffix)]
        if stem.isdecimal():
            return (0, int(stem), fname)
        return (1, 0, fname)

    partial_names = sorted(
        (
            fname
            for fname in os.listdir(output_dir)
            if fname.startswith(prefix) and fname.endswith(suffix)
        ),
        key=_order,
    )

    all_results = []
    for fname in partial_names:
        with open(os.path.join(output_dir, fname), "r", encoding="utf-8") as fin:
            all_results.extend(json.loads(line) for line in fin if line.strip())
        print(f"Merged {fname}")

    with open(final_output_path, "w", encoding="utf-8") as fout:
        fout.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in all_results
        )
    return all_results

def evaluate_results(results):
    """Print and return the accuracy of ``results`` as a percentage.

    A record counts as correct when its "prediction" equals its "label";
    a missing prediction (``None``) therefore counts as wrong.

    Args:
        results: Sequence of dicts with "prediction" and "label" keys.

    Returns:
        Accuracy in percent as a float; ``0.0`` when ``results`` is empty.
    """
    total = len(results)
    correct = sum(1 for r in results if r["prediction"] == r["label"])
    # Guard the empty case so a run with no results reports 0% instead of
    # raising ZeroDivisionError.
    acc = correct / total * 100 if total else 0.0
    print(f"GPQA Evaluation Accuracy: {acc:.2f}% ({correct}/{total})")
    return acc

def build_examples(rows, rng):
    """Turn GPQA CSV rows into shuffled multiple-choice examples.

    Args:
        rows: Iterable of mappings with "Question", "Correct Answer" and
            "Incorrect Answer 1".."Incorrect Answer 3" entries.
        rng: ``random.Random`` instance used to permute the four choices.

    Returns:
        List of dicts with "question", "choices" (permuted) and "label"
        (letter A-D of the correct choice).
    """
    examples = []
    for row in rows:
        choices = [
            row["Correct Answer"],
            row["Incorrect Answer 1"],
            row["Incorrect Answer 2"],
            row["Incorrect Answer 3"],
        ]
        perm = rng.sample(range(4), 4)
        examples.append({
            "question": row["Question"],
            "choices": [choices[i] for i in perm],
            # Index 0 is the correct answer, so locate it by position in the
            # permutation; a value lookup breaks on duplicate or NaN choices.
            "label": "ABCD"[perm.index(0)],
        })
    return examples


def clamp_batch_size(requested, num_samples, num_procs):
    """Cap the batch size so every process gets work, never below 1."""
    return max(1, min(requested, num_samples // num_procs))


if __name__ == "__main__":
    import queue
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, required=True, help="GPQA CSV file path")
    parser.add_argument("--target_path", type=str, required=True, help="Output directory")
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--max_tokens", type=int, default=32*1024)
    parser.add_argument("--decode", choices=["greedy","sample"], default="greedy")
    parser.add_argument("--system_prompt", choices=["none","qwen","deepseek"], default="none")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--tensor_parallel", type=int, default=1)
    parser.add_argument("--variant", type=str, default="diamond", help="GPQA variant (diamond, etc.)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times each question is repeated")
    args = parser.parse_args()

    if args.tensor_parallel < 1:
        raise ValueError("--tensor_parallel must be at least 1.")

    os.makedirs(args.target_path, exist_ok=True)
    df = pd.read_csv(args.csv_path)
    rng = random.Random(42)

    # Build the example list, then repeat it n_repeats times.
    examples = build_examples(
        (row for _, row in tqdm(df.iterrows(), total=len(df))),
        rng,
    )
    examples = examples * args.n_repeats

    num_samples = len(examples)
    print(f"Total examples: {num_samples}")
    if num_samples == 0:
        # Fail before loading any model rather than after.
        raise ValueError("No examples to evaluate (empty CSV or --n_repeats < 1).")

    num_gpus = torch.cuda.device_count()
    if num_gpus < args.tensor_parallel:
        raise ValueError("Not enough GPUs for the requested tensor_parallel setting.")
    num_procs = num_gpus // args.tensor_parallel
    print(f"Launching {num_procs} processes...")

    # With fewer samples than processes the integer division yields 0, which
    # would make range() below raise; clamp to at least 1.
    args.batch_size = clamp_batch_size(args.batch_size, num_samples, num_procs)

    # Start from an empty partials directory so stale files are never merged.
    partial_dir = os.path.join(args.target_path, "partials")
    if os.path.isdir(partial_dir):
        shutil.rmtree(partial_dir)
    os.makedirs(partial_dir, exist_ok=True)

    manager = mp.Manager()
    task_q = manager.Queue()
    prog_q = manager.Queue()

    # Enqueue every batch before the workers start.
    for i in range(0, num_samples, args.batch_size):
        task_q.put(examples[i:i + args.batch_size])

    procs = []
    for proc_id in range(num_procs):
        p = mp.Process(
            target=worker_process,
            args=(proc_id, task_q, prog_q, args, partial_dir)
        )
        p.start()
        procs.append(p)

    total_batches = (num_samples + args.batch_size - 1) // args.batch_size
    with tqdm(total=total_batches, desc="GPQA Batches") as pbar:
        done = 0
        while done < total_batches:
            # Sample liveness BEFORE the blocking get: if every worker was
            # already dead and the queue is still empty afterwards, no more
            # progress can arrive and waiting would hang forever.
            workers_alive = any(p.is_alive() for p in procs)
            try:
                done += prog_q.get(timeout=1)
                pbar.update(1)
            except queue.Empty:
                if not workers_alive:
                    break

    for p in procs:
        p.join()

    failed = [idx for idx, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        print(f"Warning: worker process(es) {failed} exited abnormally; results may be incomplete.")

    final_path = os.path.join(
        args.target_path,
        f"GPQA-{args.variant}-L{args.max_tokens}-D{args.decode}.jsonl"
    )
    results = merge_results(partial_dir, final_path)
    if len(results) != num_samples:
        print(f"Warning: expected {num_samples} results but merged {len(results)}.")
    evaluate_results(results)
