#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Super GPQA Evaluation (GPQA-style, saves raw model output)
"""
import os, argparse, json, random, multiprocessing as mp, re, shutil
from tqdm import tqdm
import torch, datasets
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# ------------------------ constants & templates -----------------------------
# User-turn prompt. `{question}` and `{options}` are filled in by build_prompt
# via str.format; the doubled braces `{{}}` render as a literal `{}` so the
# model sees "\boxed{}".
QUERY_TEMPLATE = """
{question}

{options}

Please reason step by step, and put your final answer within \\boxed{{}}.
Please only provide the letter of the answer in the box.
""".strip()

# Matches a boxed answer holding exactly one uppercase letter A-J, e.g.
# "\boxed{C}". Anything else inside the braces (spaces, parentheses, LaTeX
# wrappers) does not match.
BOXED_ANSWER_RE = re.compile(r"\\boxed\{([A-J])\}")
# Labels assigned to the answer choices, in display order (at most ten).
OPTIONS_LETTERS = list("ABCDEFGHIJ")
# Category names used to pre-seed the per-category tally in evaluate_results;
# they are compared against each record's "category" field.
CATEGORIES = [
    'Engineering', 'Medicine', 'Science', 'Philosophy', 'Military Science',
    'Economics', 'Management', 'Sociology', 'Literature and Arts', 'History',
    'Agronomy', 'Law', 'Education'
]

# ----------------------------- helpers --------------------------------------
def form_options(opts):
    """Render the answer choices as a lettered list under an "Options are:" header.

    Each choice is paired with the next letter from OPTIONS_LETTERS; pairing
    stops as soon as either the letters or the choices run out.
    """
    lines = []
    for letter, option in zip(OPTIONS_LETTERS, opts):
        lines.append(f"({letter}): {option}")
    return "Options are:\n" + "\n".join(lines)

def build_prompt(ex, tok, sys_prompt):
    """Build the chat-formatted prompt string for one example.

    Args:
        ex: example dict providing "question" and "choices".
        tok: tokenizer exposing ``apply_chat_template``.
        sys_prompt: "qwen" puts the reasoning instruction in a system message,
            "deepseek" appends it to the user message, and any other value
            sends the user message on its own.

    Returns:
        The untokenized chat prompt with the generation prompt appended.
    """
    reasoning_hint = "Please reason step by step, and put your final answer within \\boxed{}."
    user_text = QUERY_TEMPLATE.format(
        question=ex["question"],
        options=form_options(ex["choices"]),
    )

    messages = []
    if sys_prompt == "qwen":
        messages.append({"role": "system", "content": reasoning_hint})
    elif sys_prompt == "deepseek":
        user_text = user_text + "\n" + reasoning_hint
    messages.append({"role": "user", "content": user_text})

    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def extract_prediction(text:str):
    """Return the answer letter the model committed to in its response.

    The prompt asks for the final answer in a box, so when the response
    contains several boxed letters (e.g. intermediate guesses in the reasoning)
    the LAST one is taken as the prediction.

    Args:
        text: raw model output.

    Returns:
        The boxed letter (A-J), or a uniformly random letter from
        OPTIONS_LETTERS when no boxed letter matching BOXED_ANSWER_RE is found.
    """
    # findall returns the captured letters in order of appearance.
    boxed = BOXED_ANSWER_RE.findall(text)
    if boxed:
        return boxed[-1]
    # No parseable answer: fall back to a random guess.
    return random.choice(OPTIONS_LETTERS)

# --------------------------- multiprocessing --------------------------------
def worker_process(pid, tq, pq, args, out_dir):
    """Worker loop: pull batches from the task queue, generate, and log results.

    Args:
        pid: zero-based worker index; selects the GPUs and the output file name.
        tq: task queue; each item is a list of example dicts (one batch).
        pq: progress queue; receives a ``1`` for every finished batch.
        args: parsed CLI arguments (model_path, tensor_parallel, max_tokens,
            decode, system_prompt are read here).
        out_dir: directory that receives ``partial_{pid}.jsonl``.

    The worker exits once the task queue has stayed empty for 10 seconds.
    """
    import queue  # local import: only queue.Empty is needed in this function

    # GPU binding (same as GPQA): take `tp` entries from the visible GPU list,
    # starting at pid*tp and wrapping around modulo the list length.
    gpus = os.getenv(
        "CUDA_VISIBLE_DEVICES",
        ",".join(map(str, range(torch.cuda.device_count()))),
    ).split(",")
    tp = args.tensor_parallel
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        gpus[(pid * tp + i) % len(gpus)] for i in range(tp)
    )
    print(f"[Process {pid}] GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")

    tok = AutoTokenizer.from_pretrained(args.model_path)
    llm = LLM(model=args.model_path, trust_remote_code=True, dtype="bfloat16",
              tensor_parallel_size=tp, max_model_len=args.max_tokens+1024,
              max_seq_len_to_capture=args.max_tokens, gpu_memory_utilization=0.9, max_num_seqs=256)

    stop_words = ["<|im_end|>", "<|endoftext|>"]
    if args.decode == "sample":
        sp = SamplingParams(top_p=0.95, temperature=0.6, max_tokens=args.max_tokens, stop=stop_words)
    else:
        # Greedy decoding.
        sp = SamplingParams(temperature=0.0, max_tokens=args.max_tokens, stop=stop_words)

    out_path = os.path.join(out_dir, f"partial_{pid}.jsonl")
    # Context manager guarantees the partial file is closed even if generation fails.
    with open(out_path, "w", encoding="utf-8") as fout:
        while True:
            try:
                batch = tq.get(timeout=10)
            except queue.Empty:
                # All batches are enqueued before workers start, so an empty
                # queue means there is no work left.
                break

            prompts = [build_prompt(e, tok, args.system_prompt) for e in batch]
            outputs = llm.generate(prompts, sp, use_tqdm=False)
            for ex, out in zip(batch, outputs):
                text = out.outputs[0].text
                pred = extract_prediction(text)
                fout.write(json.dumps({
                    "category": ex["category"],
                    "question": ex["question"],
                    "choices": ex["choices"],
                    "label": ex["label"],
                    "prediction": pred,
                    "stop_reason": out.outputs[0].stop_reason,
                    "response": text,  # keep the full raw model output
                }, ensure_ascii=False) + "\n")
            fout.flush()
            # Report progress once per batch: the parent's progress bar counts batches.
            pq.put(1)

# --------------------------- merge & metrics ---------------------------------
def merge_results(part_dir, final_path):
    """Concatenate every ``partial_*.jsonl`` in *part_dir* into *final_path*.

    Args:
        part_dir: directory holding the per-worker partial JSONL files.
        final_path: path of the merged JSONL file to (over)write.

    Returns:
        The list of all decoded records, in the order they were written.
    """
    res = []
    # os.listdir order is arbitrary; sort so the merged file is reproducible.
    for fn in sorted(os.listdir(part_dir)):
        if not (fn.startswith("partial_") and fn.endswith(".jsonl")):
            continue
        with open(os.path.join(part_dir, fn), "r", encoding="utf-8") as f:
            # Skip blank lines so stray newlines do not crash json.loads.
            res.extend(json.loads(line) for line in f if line.strip())
    with open(final_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in res)
    return res

def evaluate_results(res):
    """Print per-category and overall accuracy for the merged records.

    Args:
        res: iterable of record dicts carrying "category", "prediction" and
            "label". A prediction counts as correct when it equals the label.

    Categories listed in CATEGORIES are always reported (0.00% when they have
    no records); categories not listed there are added as they are encountered.
    An empty *res* reports an overall accuracy of 0.00% instead of raising.
    """
    # Per category: [number correct, number wrong].
    per = {c: [0, 0] for c in CATEGORIES}
    for r in res:
        # setdefault tolerates categories that are missing from CATEGORIES.
        tally = per.setdefault(r["category"], [0, 0])
        tally[0 if r["prediction"] == r["label"] else 1] += 1

    total_ok = sum(v[0] for v in per.values())
    total = sum(sum(v) for v in per.values())

    print("----- Per-category accuracy -----")
    for c, (ok, ng) in per.items():
        acc = ok / (ok + ng) * 100 if ok + ng else 0
        print(f"{c:20s}: {acc:6.2f}% ({ok}/{ok+ng})")
    print("---------------------------------")
    # Guard the division so an empty result set does not crash the report.
    overall = total_ok / total * 100 if total else 0
    print(f"Overall Accuracy: {overall:.2f}% ({total_ok}/{total})")

# -------------------------------- main --------------------------------------
def main():
    """Command-line entry point: run the SuperGPQA evaluation end to end.

    Steps: parse arguments, load the dataset and shuffle each question's
    options with a fixed seed, enqueue the examples in batches, launch one
    worker process per group of ``--tensor_parallel`` GPUs, wait for them,
    then merge the partial files and print accuracy.

    Raises:
        ValueError: if ``--tensor_parallel`` is below 1 or exceeds the number
            of GPUs reported by torch.
    """
    import queue  # local import: only queue.Empty is needed in this function

    pa = argparse.ArgumentParser()
    pa.add_argument("--model_path", required=True)
    pa.add_argument("--output_dir", required=True)
    pa.add_argument("--decode", choices=["greedy", "sample"], default="greedy")
    pa.add_argument("--system_prompt", choices=["none", "qwen", "deepseek"], default="none")
    pa.add_argument("--max_tokens", type=int, default=8192)
    pa.add_argument("--batch_size", type=int, default=16)
    pa.add_argument("--tensor_parallel", type=int, default=1)
    pa.add_argument("--n_repeats", type=int, default=1)
    args = pa.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # load & prep dataset
    ds = datasets.load_dataset("m-a-p/SuperGPQA", split="train")
    rng = random.Random(42)  # fixed seed: the option shuffle is reproducible
    examples = []
    for e in tqdm(ds, desc="Preparing"):
        n_opts = len(e["options"])
        # perm[j] is the original index of the option shown at position j.
        perm = rng.sample(range(n_opts), n_opts)
        choices = [e["options"][i] for i in perm]
        # Find where the originally correct option landed after the shuffle.
        correct_letter = OPTIONS_LETTERS[perm.index(OPTIONS_LETTERS.index(e["answer_letter"]))]
        examples.append({"category": e["discipline"], "question": e["question"],
                         "choices": choices, "label": correct_letter})
    examples *= args.n_repeats
    num = len(examples)
    print(f"Total examples: {num}")

    # mp setup
    gpus = torch.cuda.device_count()
    if args.tensor_parallel < 1:
        raise ValueError("--tensor_parallel must be at least 1.")
    if gpus < args.tensor_parallel:
        raise ValueError("Not enough GPUs.")
    procs = gpus // args.tensor_parallel
    print(f"Launching {procs} processes...")

    # Shrink the batch size when there are too few examples to give every
    # worker a full batch.
    bs = min(args.batch_size, max(1, num // procs))
    part_dir = os.path.join(args.output_dir, "partials")
    # Start from a clean directory so stale partial files are never merged.
    if os.path.isdir(part_dir):
        shutil.rmtree(part_dir)
    os.makedirs(part_dir, exist_ok=True)

    mgr = mp.Manager()
    tq = mgr.Queue()
    pq = mgr.Queue()
    # Enqueue every batch before any worker starts.
    for i in range(0, num, bs):
        tq.put(examples[i:i+bs])

    ps = []
    for pid in range(procs):
        p = mp.Process(target=worker_process, args=(pid, tq, pq, args, part_dir))
        p.start()
        ps.append(p)

    tot_batches = (num + bs - 1) // bs
    with tqdm(total=tot_batches, desc="Super GPQA Batches") as bar:
        done = 0
        while done < tot_batches:
            # Sample liveness BEFORE polling: if no worker was alive and the
            # queue is still empty afterwards, no further progress can arrive.
            any_alive = any(p.is_alive() for p in ps)
            try:
                done += pq.get(timeout=1)
            except queue.Empty:
                if not any_alive:
                    break  # workers are gone; do not spin forever
                continue
            bar.update(1)
    for p in ps:
        p.join()

    failed = [pid for pid, p in enumerate(ps) if p.exitcode != 0]
    if failed:
        print(f"WARNING: worker process(es) {failed} exited abnormally; results may be incomplete.")

    final_path = os.path.join(args.output_dir, f"SuperGPQA-L{args.max_tokens}-D{args.decode}.jsonl")
    results = merge_results(part_dir, final_path)
    # Evaluation is unaffected by the extra "response" field: it reads "prediction" only.
    evaluate_results(results)
    # Tell the user where the merged raw outputs were written.
    print(f"\nRaw model outputs saved in: {final_path}")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__=="__main__":
    main()
