# evaluate_benchmarks.py
# Evaluate a fine-tuned Qwen2.5-32B model on MATH500, AIME24, AIME25, GPQA-Diamond,
# BIG Bench Extra Hard, Big Math, BRAINTEASER, Explore TOM, and MMLU-Pro
# using the custom group-think generation functions (group_think_cross_attend_generate,
# its optimized variant, or group_think_simulation_generate)

import os
import json
import time
import random
import argparse
import re
from sympy import GoldenRatio
import yaml
from typing import List, Dict, Optional, Any, Tuple
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Import the custom group-think generation functions
from inference_text_infilling import group_think_cross_attend_generate
from inference_simulation import group_think_simulation_generate
from inference_text_infilling_final_optimized import group_think_cross_attend_generate_final_optimized


################################################################################
# Prompt templates (standard formats commonly used in the literature)
################################################################################

# System prompt returned by every benchmark prompt builder in this file; a caller
# of evaluate_dataset can replace it through `system_prompt_override`.
SYSTEM_PROMPT_DEFAULT = "You are a helpful, precise, and rigorous assistant for solving problems."

################################################################################
# Small helpers (decoding + streaming + selection)
################################################################################

def _decode_group_think_output(out, tokenizer) -> Tuple[str, Optional[Dict[str, str]]]:
    """Return (joined_text, group_traces or None) from generator output."""
    # Case GT Simulation
    if isinstance(out, torch.Tensor):
        return tokenizer.decode(out[0], skip_special_tokens=True).strip(), None
    # Case: legacy dict with text
    if isinstance(out, dict) and "text" in out:
        return str(out["text"]), None
    # Case: list of path states with .ids
    if isinstance(out, list):
        per_path_texts = []
        for p in out:
            if hasattr(p, "ids"):
                try:
                    per_path_texts.append(
                        tokenizer.decode(p.ids[0], skip_special_tokens=True).strip()
                    )
                except Exception:
                    per_path_texts.append("")
        if per_path_texts:
            joined = "\n".join(per_path_texts)
            group_traces = {str(i): t for i, t in enumerate(per_path_texts)}
            return joined, group_traces
    # Fallback
    return str(out), None


def _build_aligned_row(row: Dict[str, Any]) -> Dict[str, Any]:
    golden_answer = row.get("gold").get('golden_answer')
    return {
        "dataset_name": f"{row['task']}_{row['index']}",
        # "question": row.get("question") or row.get("problem"),
        "question": row.get("usr_prompt") or row.get("question") or row.get("problem"),
        "answer": golden_answer,
        "answer_info": row.get("gold"),
        "group_traces": row.get("group_traces", {}),
        "evaluations": {},
        "predictions": row.get("predictions", []),
        "metric": {"correct_at_k": row.get("correct_at_k"), "k": row.get("k")},
    }


def _stream_write_jsonl(aligned_path: str, obj: Dict[str, Any]):
    try:
        with open(aligned_path, "a", encoding="utf-8") as af:
            af.write(json.dumps(obj, ensure_ascii=False) + "\n")
    except Exception:
        pass


def _stream_write_metric(metrics_path: str, metric_now: Dict[str, Any]):
    try:
        with open(metrics_path, "w", encoding="utf-8") as mf:
            json.dump(metric_now, mf, ensure_ascii=False, indent=2)
    except Exception:
        pass


def _select_eval_plan(args, ds_math500, ds_aime24, ds_aime25, ds_gpqa,
                      ds_bbeh, ds_bigmath, ds_brainteaser, ds_explore_tom, ds_mmlu_pro):
    requested = [s.strip().lower() for s in (args.datasets or "all").split(",")]
    all_map = [
        ("MATH", ds_math500, "math500"),
        ("AIME", ds_aime24, "aime24"),
        ("AIME", ds_aime25, "aime25"),
        ("GPQA", ds_gpqa, "gpqa_diamond"),
        ("BBEH", ds_bbeh, "bb_extra_hard"),
        ("BIGMATH", ds_bigmath, "bigmath"),
        ("BRAINTEASER", ds_brainteaser, "brainteaser"),
        ("EXPLORETOM", ds_explore_tom, "explore_tom"),
        ("MMLU_PRO", ds_mmlu_pro, "mmlu_pro"),
    ]
    if "all" in requested:
        return all_map
    name_to_item = {
        "math500": ("MATH", ds_math500, "math500"),
        "aime24": ("AIME", ds_aime24, "aime24"),
        "aime25": ("AIME", ds_aime25, "aime25"),
        "gpqa_diamond": ("GPQA", ds_gpqa, "gpqa_diamond"),
        "bb_extra_hard": ("BBEH", ds_bbeh, "bb_extra_hard"),
        "bigmath": ("BIGMATH", ds_bigmath, "bigmath"),
        "brainteaser": ("BRAINTEASER", ds_brainteaser, "brainteaser"),
        "explore_tom": ("EXPLORETOM", ds_explore_tom, "explore_tom"),
        "mmlu_pro": ("MMLU_PRO", ds_mmlu_pro, "mmlu_pro"),
    }
    plan = []
    for key in requested:
        if key in name_to_item:
            plan.append(name_to_item[key])
    return plan

def build_prompt_aime(problem: str) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for an AIME problem.

    The user prompt states that the answer is an integer in [0, 999] and asks for
    it in ``\\boxed{...}`` after the reasoning.
    """
    instructions = (
        "You are given a problem from the American Invitational Mathematics Examination (AIME).\n"
        "Solve the problem carefully. The final answer is a single integer between 0 and 999.\n"
        "Show your reasoning, and put only the final answer in the form \\boxed{<answer>} at the end.\n\n"
    )
    return SYSTEM_PROMPT_DEFAULT, instructions + f"Problem:\n{problem}\n"

def build_prompt_math(problem: str) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for a free-answer math problem.

    The user prompt asks for shown reasoning with the final answer in ``\\boxed{...}``.
    """
    lead_in = (
        "Solve the following math problem. Show your reasoning, and put your final answer in the form "
        "\\boxed{<answer>} at the end.\n\n"
    )
    return SYSTEM_PROMPT_DEFAULT, f"{lead_in}Problem:\n{problem}\n"

def build_prompt_gpqa(question: str, options: List[str]) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for a four-option GPQA question.

    Only the first four entries of ``options`` are shown, labeled ``A.``..``D.``;
    the model is asked to answer with the letter in ``\\boxed{...}``.
    """
    options_block = "\n".join(f"{'ABCD'[pos]}. {text}" for pos, text in enumerate(options[:4]))
    user_prompt = "".join([
        "You will be given a graduate-level multiple-choice science question with four options.\n",
        "Choose the single best answer. Respond with only the letter A, B, C, or D, and put it in ",
        "\\boxed{<letter>} at the end.\n\n",
        f"Question:\n{question}\n\nOptions:\n{options_block}\n",
    ])
    return SYSTEM_PROMPT_DEFAULT, user_prompt

# Generic multiple-choice prompt (options labeled A, B, C, ...; at most 26 shown)
def build_prompt_mc_generic(question: str, options: List[str]) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for a multiple-choice question.

    Up to the first 26 options are shown, labeled ``A.``, ``B.``, ...; the prompt
    lists exactly those letters as valid answers and asks for one in ``\\boxed{...}``.
    """
    shown = options[:26]
    letters = [chr(65 + pos) for pos in range(len(shown))]
    options_block = "\n".join(f"{letter}. {text}" for letter, text in zip(letters, shown))
    letters_str = ", ".join(letters)
    user_prompt = (
        "You will be given a multiple-choice question.\n"
        + f"Choose the single best answer. Respond with only one of the letters {letters_str}, "
        + "and put it in \\boxed{<letter>} at the end.\n\n"
        + f"Question:\n{question}\n\nOptions:\n{options_block}\n"
    )
    return SYSTEM_PROMPT_DEFAULT, user_prompt

# Yes/No prompt
def build_prompt_yesno(question: str) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for a Yes/No question with a boxed verdict."""
    parts = [
        "Answer the following question with Yes or No.\n",
        "Show your reasoning briefly, and put only the final answer in the form \\boxed{Yes} or \\boxed{No} at the end.\n\n",
        f"Question:\n{question}\n",
    ]
    return SYSTEM_PROMPT_DEFAULT, "".join(parts)


################################################################################
# Answer extraction and normalization
################################################################################

# Matches a non-nested ``\boxed{...}``; group 1 is the inner text. In a raw string the
# braces take a single backslash: ``\\{`` would demand a literal backslash before the
# brace and never match ``\boxed{42}``.
BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")

def extract_last_boxed(text: str) -> Optional[str]:
    """Return the stripped contents of the last balanced ``\\boxed{...}`` in ``text``.

    Braces are matched by depth, so nested content such as ``\\boxed{\\frac{1}{2}}``
    is returned whole. Whitespace between ``\\boxed`` and ``{`` is allowed. Empty
    boxes and unterminated boxes (e.g. a truncated generation) are skipped.
    Returns None when ``text`` is empty or contains no usable box.
    """
    if not text:
        return None
    marker = "\\boxed"
    n = len(text)
    last_inner = None
    pos = text.find(marker)
    while pos != -1:
        i = pos + len(marker)
        while i < n and text[i].isspace():
            i += 1
        if i < n and text[i] == "{":
            depth = 0
            for j in range(i, n):
                ch = text[j]
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        if j > i + 1:
                            last_inner = text[i + 1:j]
                        break
        pos = text.find(marker, pos + len(marker))
    return last_inner.strip() if last_inner is not None else None

def extract_aime_answer(text: str) -> Optional[int]:
    """Extract an AIME answer (an integer in [0, 999]) from a model completion.

    Candidates are tried from most to least reliable, and the first that parses to
    an int in [0, 999] wins:
      1. the content of the last ``\\boxed{...}``;
      2. the number in the last "answer is ... N" phrase (case-insensitive);
      3. the last standalone 1-3 digit number anywhere in the text.
    Returns None if no candidate qualifies.
    """
    body = text or ""
    candidates = []
    boxed = extract_last_boxed(text)
    if boxed is not None:
        candidates.append(boxed)
    # Must precede the bare-number fallback: that fallback always parses, so anything
    # ranked after it could never be reached.
    phrase_hits = re.findall(r"(?i)answer is[^0-9]*?(\d{1,3})\b", body)
    if phrase_hits:
        candidates.append(phrase_hits[-1])
    # Digit lookarounds keep longer numbers (e.g. the year 2024) from being split into
    # 1-3 digit fragments.
    bare_numbers = re.findall(r"(?<!\d)(\d{1,3})(?!\d)", body)
    if bare_numbers:
        candidates.append(bare_numbers[-1])
    for cand in candidates:
        try:
            val = int(str(cand).strip())
        except ValueError:
            continue
        if 0 <= val <= 999:
            return val
    return None

def normalize_latex_string(s: str) -> str:
    """Canonicalize a (LaTeX) answer string for exact-match comparison.

    Steps:
      * unwrap the first ``\\boxed{...}`` matched by BOXED_RE;
      * drop ``\\left`` / ``\\right``, thin spaces ``\\,`` ``\\!``, ``$`` and the math
        delimiters ``\\(`` ``\\)`` ``\\[`` ``\\]``;
      * map ``\\cdot`` / ``\\times`` to ``*``, ``\\div`` to ``/`` and ``\\%`` to ``%``;
      * rewrite ``\\dfrac`` / ``\\tfrac`` to ``\\frac``;
      * unwrap ``\\text{...}``;
      * remove one trailing period and all whitespace.
    ``None`` becomes "".
    """
    if s is None:
        return ""
    s = s.strip()
    m = BOXED_RE.search(s)
    if m:
        s = m.group(1)

    # `(?![A-Za-z])` stops a command name matching inside a longer one
    # (e.g. \left inside \leftarrow, \right inside \rightarrow).
    s = re.sub(r"\\left(?![A-Za-z])", "", s)
    s = re.sub(r"\\right(?![A-Za-z])", "", s)
    s = re.sub(r"\\,", "", s)
    s = re.sub(r"\\!", "", s)
    s = re.sub(r"\$", "", s)
    s = re.sub(r"\\%", "%", s)
    s = re.sub(r"\\cdot(?![A-Za-z])", "*", s)
    s = re.sub(r"\\times(?![A-Za-z])", "*", s)
    s = re.sub(r"\\div(?![A-Za-z])", "/", s)
    # \dfrac / \tfrac are display/text-style \frac and denote the same value.
    s = re.sub(r"\\[dt]frac(?![A-Za-z])", r"\\frac", s)
    s = re.sub(r"\\text\{([^}]*)\}", r"\1", s)
    s = re.sub(r"\\\(", "", s)
    s = re.sub(r"\\\)", "", s)
    s = re.sub(r"\\\[", "", s)
    s = re.sub(r"\\\]", "", s)
    s = re.sub(r"\.$", "", s)
    s = re.sub(r"\s+", "", s)
    return s

def extract_math_answer_string(text: str) -> Optional[str]:
    """Pull the final math answer out of a completion, normalized by normalize_latex_string.

    Preference order: the last ``\\boxed{...}``, then an "answer:"/"answer -" phrase
    whose remainder runs to the end of the text, then the last non-blank line.
    Returns None when the text has no non-blank line.
    """
    boxed = extract_last_boxed(text)
    if boxed is not None:
        return normalize_latex_string(boxed)

    body = text or ""
    tagged = re.search(r"(?i)answer\s*[:\-]\s*([^\n]+)$", body)
    if tagged:
        return normalize_latex_string(tagged.group(1))

    non_blank = [line.strip() for line in body.splitlines() if line.strip()]
    return normalize_latex_string(non_blank[-1]) if non_blank else None

def is_math_equal(pred: str, gold: str) -> bool:
    """Return True if ``pred`` and ``gold`` denote the same math answer.

    Both sides go through normalize_latex_string. An empty normalized side never
    matches. Equal strings match. Otherwise both are parsed with SymPy and compared
    with ``Expr.equals``, plus ``simplify(p - g) == 0`` when neither has free symbols.
    Any parsing or comparison error counts as "not equal".

    Security note: the fallback parser ``sympy.sympify`` evaluates its input with
    ``eval``, and ``pred`` comes from model output. Only run this on outputs you
    are willing to evaluate, or replace the fallback with a safer parser.
    """
    p = normalize_latex_string(pred)
    g = normalize_latex_string(gold)
    if not p or not g:
        # A missing prediction must not "match" a gold that also normalizes to "".
        return False
    if p == g:
        return True
    try:
        import sympy as sp
        try:
            from sympy.parsing.latex import parse_latex as sympy_parse_latex
            p_expr = sympy_parse_latex(p)
            g_expr = sympy_parse_latex(g)
        except Exception:
            # parse_latex needs SymPy's optional ANTLR runtime and rejects many strings;
            # fall back to sympify with LaTeX powers rewritten to Python syntax.
            # SECURITY(review): sympify eval()s its argument.
            p_expr = sp.sympify(p.replace("^", "**"))
            g_expr = sp.sympify(g.replace("^", "**"))
        if p_expr.equals(g_expr):
            return True
        if not (p_expr.free_symbols | g_expr.free_symbols):
            return sp.simplify(p_expr - g_expr) == 0
        return False
    except Exception:
        return False

def extract_choice_letter_general(text: str, allowed_letters: List[str]) -> Optional[str]:
    """Extract a multiple-choice label (one of ``allowed_letters``) from a completion.

    Tried in order:
      1. the last ``\\boxed{...}``: its whole content upper-cased, or else the first
         standalone allowed letter inside it (so ``\\boxed{(C)}``, ``\\boxed{C.}`` and
         ``\\boxed{\\text{C}}`` work);
      2. the last "answer: X" / "answer - X" phrase (case-insensitive);
      3. the last standalone label in the text, matched case-sensitively, so
         English words like "a" are not read as option A.
    Returns None if nothing matches or ``allowed_letters`` is empty.
    """
    if not allowed_letters:
        return None
    allowed = set(allowed_letters)

    boxed = extract_last_boxed(text)
    if boxed:
        inner = boxed.strip().upper()
        if inner in allowed:
            return inner
        for cand in re.findall(r"\b([A-Z])\b", inner):
            if cand in allowed:
                return cand

    body = text or ""
    # Longest labels first so multi-character labels are not shadowed by prefixes.
    labels = "|".join(re.escape(lbl) for lbl in sorted(allowed, key=len, reverse=True))

    # Checked before the bare scan; otherwise the bare scan would always win.
    for cand in reversed(re.findall(rf"(?i)answer\s*[:\-]\s*({labels})\b", body)):
        cand = cand.upper()
        if cand in allowed:
            return cand

    bare = re.findall(rf"\b({labels})\b", body)
    if bare:
        return bare[-1]
    return None

def extract_yesno(text: str) -> Optional[str]:
    """Extract a Yes/No verdict from a completion as "YES"/"NO", or None.

    Tried in order:
      1. the last ``\\boxed{...}`` (``\\text{...}`` wrappers and a trailing "."/"!"
         are tolerated);
      2. the last "answer: yes/no" phrase;
      3. the last standalone "yes"/"no" word.
    All matching is case-insensitive.
    """
    boxed = extract_last_boxed(text)
    if boxed:
        verdict = re.sub(r"\\text\{([^}]*)\}", r"\1", boxed).strip().strip(".!").lower()
        if verdict in ("yes", "no"):
            return verdict.upper()

    body = text or ""
    # The explicit phrase must be checked before the bare-word scan, which would
    # otherwise always win (any "answer: yes" also contains a bare "yes").
    phrases = re.findall(r"(?i)answer\s*[:\-]\s*(yes|no)\b", body)
    if phrases:
        return phrases[-1].upper()

    words = re.findall(r"(?i)\b(yes|no)\b", body)
    if words:
        return words[-1].upper()
    return None

def normalize_short_answer(s: str) -> str:
    """Normalize a short free-form answer for exact-match comparison.

    Lower-cases the text, deletes quote characters (``"``, ``'``, backtick), turns
    every run of characters outside ``[a-z0-9]`` into a single space, and drops
    the articles "a", "an", "the". ``None`` becomes "".
    """
    if s is None:
        return ""
    lowered = str(s).strip().lower()
    unquoted = re.sub(r"[\"'`]+", "", lowered)
    spaced = re.sub(r"[^a-z0-9]+", " ", unquoted)
    articles = {"a", "an", "the"}
    return " ".join(word for word in spaced.split() if word not in articles)

def extract_freeform_text_answer(text: str) -> Optional[str]:
    """Pull a short free-form answer from a completion, normalized with normalize_short_answer.

    Preference order: a non-empty last ``\\boxed{...}``, then an "answer:"/"answer -"
    phrase whose remainder runs to the end of the text, then the last non-blank line.
    Returns None when the text has no non-blank line.
    """
    boxed = extract_last_boxed(text)
    if boxed:
        return normalize_short_answer(boxed)

    source = text or ""
    tagged = re.search(r"(?i)answer\s*[:\-]\s*([^\n]+)$", source)
    if tagged:
        return normalize_short_answer(tagged.group(1))

    for line in reversed(source.splitlines()):
        if line.strip():
            return normalize_short_answer(line.strip())
    return None


################################################################################
# Dataset loading (from local disk)
################################################################################

def load_local_dataset(path: str):
    """Load an evaluation split from a local path.

    * Directory: tried first with ``datasets.load_from_disk``. For a dict of splits,
      the first present of test/validation/val/diamond/dev/train is returned,
      otherwise the first split. If that fails, every ``.json``/``.jsonl`` file in
      the directory is loaded as one JSON dataset (in sorted filename order, so row
      order, and hence any ``max_samples`` prefix, is stable across runs).
    * ``.json`` / ``.jsonl`` file: loaded with the ``json`` builder.

    Raises:
        RuntimeError: the directory holds neither a saved dataset nor JSON files,
            or the path is neither a directory nor a JSON/JSONL file.
    """
    from datasets import load_from_disk, load_dataset
    if os.path.isdir(path):
        try:
            ds = load_from_disk(path)
            if hasattr(ds, "keys"):
                for split in ("test", "validation", "val", "diamond", "dev", "train"):
                    if split in ds:
                        return ds[split]
                first_key = next(iter(ds.keys()))
                return ds[first_key]
            return ds
        except Exception as load_err:
            data_files = sorted(
                os.path.join(path, f)
                for f in os.listdir(path)
                if f.endswith((".jsonl", ".json"))
            )
            if not data_files:
                raise RuntimeError(
                    f"Could not load dataset from {path}. Not a saved dataset or JSON/JSONL directory."
                ) from load_err
            return load_dataset("json", data_files=data_files, split="train")
    if path.endswith((".json", ".jsonl")):
        return load_dataset("json", data_files=path, split="train")
    raise RuntimeError(f"Unsupported dataset path: {path}. Provide a saved dataset directory or a JSON/JSONL file.")


################################################################################
# Flexible field mapping for benchmarks
################################################################################

def get_field(example: Dict[str, Any], candidates: List[str], default=None):
    """Return the value of the first key in ``candidates`` that ``example`` holds with a non-None value.

    Falsy values such as 0 or "" count as present; only missing keys and None are
    skipped. Returns ``default`` when no candidate qualifies.
    """
    return next(
        (example[key] for key in candidates if key in example and example[key] is not None),
        default,
    )

def detect_math_problem_and_answer(example: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Return ``(problem, gold)`` for a math row using common field names.

    The gold may be a full worked ``solution`` when no short answer field exists;
    either value is None when none of its candidate keys is set.
    """
    problem_keys = ["problem", "question", "prompt", "input", "query"]
    answer_keys = ["answer", "final_answer", "target", "solution"]
    return get_field(example, problem_keys), get_field(example, answer_keys)

def detect_aime_problem_and_answer(example: Dict[str, Any]) -> Tuple[Optional[str], Optional[int]]:
    """Return ``(problem, gold_int)`` for an AIME row.

    A numeric gold is truncated to int. A string gold is parsed as a number first,
    so "70", "070" and "70.0" all give 70. Otherwise the last run of digits in it
    is used (e.g. "\\boxed{70}" gives 70). ``gold_int`` is None when no integer can
    be recovered.
    """
    problem = get_field(example, ["problem", "question", "prompt", "input", "query"])
    gold = get_field(example, ["answer", "final_answer", "target"])
    gold_int = None
    try:
        if isinstance(gold, str):
            text = gold.strip()
            try:
                gold_int = int(float(text))
            except ValueError:
                # Stripping every non-digit would turn "70.0" into 700, so only the
                # last digit run is used for non-numeric strings.
                runs = re.findall(r"\d+", text)
                gold_int = int(runs[-1]) if runs else None
        elif isinstance(gold, (int, float)):
            gold_int = int(gold)
    except (ValueError, OverflowError):
        # NaN / infinity golds
        gold_int = None
    return problem, gold_int

def extract_mc_options_from_text(text: str) -> List[str]:
    """Extract labeled answer options (labels A-J, either case) from question text.

    Recognized layouts, tried in order:
      * ``(A) text`` / ``[A] text`` anywhere in a line;
      * ``A. text`` at the start of a line;
      * ``A) text`` at the start of a line.
    For the first layout that yields options, the texts are returned in label order
    A, B, C, ... stopping at the first missing label. This keeps positions aligned
    with the A.. letters the prompt builders assign. Duplicate labels keep their
    first occurrence. Returns [] when no layout yields at least two contiguous
    options starting at A.
    """
    if not text:
        return []
    # Every pattern captures (label, option_text); option text never spans lines.
    patterns = (
        r"[\(\[]([A-Ja-j])[\)\]][ \t]*([^\n]+)",
        r"(?m)^[ \t]*([A-Ja-j])\.[ \t]*([^\n]+)",
        r"(?m)^[ \t]*([A-Ja-j])\)[ \t]*([^\n]+)",
    )
    for pattern in patterns:
        by_label: Dict[str, str] = {}
        for label, option in re.findall(pattern, text):
            by_label.setdefault(label.upper(), option.strip())
        opts: List[str] = []
        for i in range(10):
            letter = chr(ord("A") + i)
            if letter not in by_label:
                break
            opts.append(by_label[letter])
        if len(opts) >= 2:
            return opts
    return []

def detect_gpqa_fields(example: Dict[str, Any]) -> Tuple[Optional[str], List[str], Optional[str]]:
    """Detect ``(question, options, gold_letter)`` for a GPQA-style four-option row.

    Options come from a ``choices``/``options`` list. If its first item is a dict,
    the dict items are mapped to their ``text`` (else ``label``), provided at least
    two map. Failing that, options come from keys ``A``-``D`` (all four required),
    and failing that, from the question text. The gold letter is a letter-valued
    answer field (A-D), else a 0-based index field in [0, 4); otherwise None.
    Options are returned as ``[]`` when none were found.
    """
    question = get_field(example, ["question", "problem", "prompt", "input", "query"])
    options = get_field(example, ["choices", "options"])

    if options and isinstance(options, list) and isinstance(options[0], dict):
        mapped = [
            c["text"] if "text" in c else c["label"]
            for c in options
            if isinstance(c, dict) and ("text" in c or "label" in c)
        ]
        if len(mapped) >= 2:
            options = mapped

    if not options:
        per_letter = []
        for key in ("A", "B", "C", "D"):
            if example.get(key):
                per_letter.append(example.get(key))
        options = per_letter if len(per_letter) == 4 else None

    if not options and question:
        options = extract_mc_options_from_text(question)
    found = options or []

    answer = get_field(example, ["answer", "correct", "label", "gold", "answerKey"])
    if isinstance(answer, str) and answer.strip().upper() in ("A", "B", "C", "D"):
        return question, found, answer.strip().upper()

    ans_idx = get_field(example, ["answer_idx", "answer_index", "label_idx", "gold_idx"])
    if ans_idx is not None:
        try:
            idx = int(ans_idx)
        except (ValueError, TypeError):
            idx = None
        if idx is not None and 0 <= idx < 4:
            return question, found, "ABCD"[idx]
    return question, found, None

# Generic MCQA detector (A-J)
def detect_mcqa_fields_generic(example: Dict[str, Any]) -> Tuple[Optional[str], List[str], Optional[str], List[str]]:
    """Detect question, options and gold letter for a generic multiple-choice row.

    Options are taken, in order of preference, from:
      * a ``choices``/``options``/``answers``/``candidates`` list (dict items
        contribute their ``text``, ``label`` or ``option`` value);
      * per-letter keys ``A``..``J`` (upper or lower case);
      * the question text, via extract_mc_options_from_text.

    Returns:
        ``(question, options, gold_letter, allowed_letters)``. ``allowed_letters`` has
        one letter per option (at most 26), or A-D when no options were found.
        ``gold_letter`` is None when the gold cannot be mapped to a letter.
    """
    question = get_field(example, ["question", "problem", "prompt", "input", "query", "context"])
    options = get_field(example, ["choices", "options", "answers", "candidates"])

    if options and isinstance(options, list):
        flattened = []
        for item in options:
            if isinstance(item, dict):
                for key in ("text", "label", "option"):
                    if key in item:
                        flattened.append(item[key])
                        break
            elif isinstance(item, str):
                flattened.append(item)
        if flattened:
            options = flattened

    if not (isinstance(options, list) and len(options) >= 2):
        per_key = []
        for i in range(10):
            key = chr(65 + i)
            value = example.get(key) or example.get(key.lower())
            if value:
                per_key.append(value)
        options = per_key if len(per_key) >= 2 else None

    if not options and question:
        options = extract_mc_options_from_text(question)
    options = options or []
    allowed_letters = [chr(65 + i) for i in range(min(len(options), 26))] if options else list("ABCD")

    ans = get_field(example, ["answer", "correct", "label", "gold", "answerKey", "final_answer", "target"])
    gold_letter = _resolve_mc_gold_letter(ans, options, allowed_letters)
    return question, options, gold_letter, allowed_letters


def _match_option_text(answer_text: Any, options: List[Any], allowed_letters: List[str]) -> Optional[str]:
    """Return the letter of the option whose normalized text equals ``answer_text``, if any."""
    target = normalize_short_answer(answer_text)
    if not target:
        # An answer that normalizes to "" (e.g. "the") must not match an empty option.
        return None
    # zip() also guards against more options than letters (letters stop at Z).
    for letter, opt in zip(allowed_letters, options):
        if normalize_short_answer(opt) == target:
            return letter
    return None


def _resolve_mc_gold_letter(ans: Any, options: List[Any], allowed_letters: List[str]) -> Optional[str]:
    """Map a raw gold (letter, option text, index, or dict) onto one of ``allowed_letters``."""
    if isinstance(ans, str):
        upper = ans.strip().upper()
        if upper in allowed_letters:
            return upper
        # Full option text is checked before the loose letter search; otherwise an
        # answer such as "a cat" would be read as option "A".
        if options:
            letter = _match_option_text(ans, options, allowed_letters)
            if letter is not None:
                return letter
        # Forms like "Option C" or "(C)".
        m = re.search(r"\b([A-Z])\b", upper)
        if m and m.group(1) in allowed_letters:
            return m.group(1)

    if isinstance(ans, dict):
        if "label" in ans and str(ans["label"]).upper() in allowed_letters:
            return str(ans["label"]).upper()
        if "index" in ans:
            try:
                idx = int(ans["index"])
            except (TypeError, ValueError, OverflowError):
                idx = -1
            if 0 <= idx < len(allowed_letters):
                return allowed_letters[idx]
        if "text" in ans and options:
            return _match_option_text(ans["text"], options, allowed_letters)
        return None

    # 0-based index given as a number or a numeric string.
    if ans is not None:
        try:
            idx = int(ans)
        except (TypeError, ValueError, OverflowError):
            return None
        if 0 <= idx < len(allowed_letters):
            return allowed_letters[idx]
    return None

# Yes/No detector
def detect_yesno_fields(example: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Detect ``(question, gold)`` for a Yes/No row, with gold "YES", "NO" or None.

    Accepted golds:
      * bools;
      * the numbers 1/0;
      * the strings yes/true/y/1 and no/false/n/0 (case-insensitive, surrounding
        whitespace and a trailing "." or "!" ignored).
    Anything else (e.g. 2 or "maybe") gives None rather than a guessed verdict.
    """
    question = get_field(example, ["question", "problem", "prompt", "input", "query", "context"])
    ans = get_field(example, ["answer", "label", "gold", "target", "correct"])
    gold = None
    # bool is checked before int/float because bool is a subclass of int.
    if isinstance(ans, bool):
        gold = "YES" if ans else "NO"
    elif isinstance(ans, str):
        s = ans.strip().strip(".!").lower()
        if s in ("yes", "true", "y", "1"):
            gold = "YES"
        elif s in ("no", "false", "n", "0"):
            gold = "NO"
    elif isinstance(ans, (int, float)):
        if ans == 1:
            gold = "YES"
        elif ans == 0:
            gold = "NO"
    return question, gold


################################################################################
# Model loading (base + finetuned weights)
################################################################################

def load_model_and_tokenizer(
    model_path: str,
    dtype: str = "bfloat16",
    device_map: str = "auto",
    trust_remote_code: bool = True,
):
    """Load a causal LM and its tokenizer for evaluation.

    Args:
        model_path: checkpoint directory (or hub id) passed to ``from_pretrained``.
        dtype: "bfloat16"/"bf16", "float16"/"fp16"/"half", "float32"/"fp32"/"float",
            or "auto" (dtype chosen by transformers). Unrecognized values fall back
            to float16 with a warning.
        device_map: forwarded to ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: forwarded to both ``from_pretrained`` calls.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode, and the tokenizer has a
        pad token (the EOS token when none was defined). When the model has a
        ``generation_config``, it is set to sampling with temperature 0.7 and
        top_p 0.95.

    Raises:
        RuntimeError: the model could not be loaded.
    """
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
        "auto": "auto",
    }
    torch_dtype = dtype_map.get(str(dtype).lower())
    if torch_dtype is None:
        print(f"[warn] unrecognized dtype '{dtype}', falling back to float16")
        torch_dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
    except Exception as e:
        raise RuntimeError(f"Could not load model from {model_path}: {e}") from e

    if hasattr(model, "generation_config"):
        # NOTE(review): from_model_config rebuilds the config from model.config, so any
        # settings in the checkpoint's own generation_config are replaced, not merged.
        # Confirm this is intended.
        gen_cfg = GenerationConfig.from_model_config(model.config)
        gen_cfg.temperature = 0.7
        gen_cfg.top_p = 0.95
        gen_cfg.do_sample = True
        model.generation_config = gen_cfg

    model.eval()
    return model, tokenizer


################################################################################
# Evaluation core (pass@k)
################################################################################

def set_random_seed(seed: int):
    """Seed Python's ``random`` module and torch (CPU and every CUDA device) with ``seed``."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def pass_at_k_from_predictions(preds: List[Any], gold: Any, task: str) -> bool:
    """Return True if any completion in ``preds`` is correct for ``task``.

    * "AIME": the integer from extract_aime_answer equals ``int(gold)``.
    * "MATH": extract_math_answer_string matches ``str(gold)`` under is_math_equal;
      a None gold is always False.
    * "GPQA": the letter from extract_gpqa_choice_letter equals ``gold``
      (case- and whitespace-insensitive).

    Raises:
        ValueError: for any other ``task``.
    """
    if task == "AIME":
        def hit(pred_txt):
            pred_int = extract_aime_answer(pred_txt)
            return pred_int is not None and gold is not None and int(pred_int) == int(gold)
    elif task == "MATH":
        if gold is None:
            return False
        gold_str = str(gold)

        def hit(pred_txt):
            return is_math_equal(extract_math_answer_string(pred_txt) or "", gold_str)
    elif task == "GPQA":
        def hit(pred_txt):
            letter = extract_gpqa_choice_letter(pred_txt)
            return (
                letter is not None
                and gold is not None
                and letter.strip().upper() == str(gold).strip().upper()
            )
    else:
        raise ValueError(f"Unknown task: {task}")

    return any(hit(pred_txt) for pred_txt in preds)

def extract_gpqa_choice_letter(text: str) -> Optional[str]:
    """Extract a GPQA answer letter (A-D) from a completion, or None.

    Delegates to extract_choice_letter_general so GPQA and the other
    multiple-choice tasks use a single letter-extraction implementation.
    """
    return extract_choice_letter_general(text, ["A", "B", "C", "D"])

# Prefix used for failed samples; such placeholders are kept in the output but never scored.
_GENERATION_ERROR_PREFIX = "<<GENERATION_ERROR: "


def _pick_generation_fn(simulated_gt: bool, use_gt_opt_func: bool):
    """Return the group-think generator selected by the flags (simulation takes precedence)."""
    if simulated_gt:
        return group_think_simulation_generate
    if use_gt_opt_func:
        return group_think_cross_attend_generate_final_optimized
    return group_think_cross_attend_generate


def _prepare_eval_example(ex: Dict[str, Any], task_name: str):
    """Build ``(sys_prompt, usr_prompt, gold_norm, query_text, options)`` for one example.

    ``gold_norm`` is an int for AIME, a str for MATH/BIGMATH/GPQA, and a dict for
    BBEH/BRAINTEASER/EXPLORETOM/MMLU_PRO. Returns None when the example lacks the
    fields its task needs, in which case the caller skips it.

    Raises:
        ValueError: unknown ``task_name``.
        KeyError: a BRAINTEASER/EXPLORETOM/MMLU_PRO row missing one of the fixed keys
            those branches index directly.
    """
    if task_name == "AIME":
        problem, gold = detect_aime_problem_and_answer(ex)
        if not problem or gold is None:
            return None
        sys_prompt, usr_prompt = build_prompt_aime(problem)
        return sys_prompt, usr_prompt, int(gold), problem, None

    if task_name in ("MATH", "BIGMATH"):
        problem, gold = detect_math_problem_and_answer(ex)
        if not problem or gold is None:
            return None
        sys_prompt, usr_prompt = build_prompt_math(problem)
        return sys_prompt, usr_prompt, str(gold), problem, None

    if task_name == "GPQA":
        question, opts, gold_letter = detect_gpqa_fields(ex)
        if (not question) or (not opts) or (gold_letter is None):
            return None
        sys_prompt, usr_prompt = build_prompt_gpqa(question, opts)
        return sys_prompt, usr_prompt, gold_letter, question, opts

    if task_name == "BBEH":
        question, opts, gold_letter, allowed_letters = detect_mcqa_fields_generic(ex)
        if not question:
            return None
        if opts and gold_letter is not None:
            sys_prompt, usr_prompt = build_prompt_mc_generic(question, opts)
            gold_norm = {"gold_letter": gold_letter, "allowed": allowed_letters}
            return sys_prompt, usr_prompt, gold_norm, question, opts
        # Free-form fallback: generic reasoning prompt with a boxed answer.
        sys_prompt, usr_prompt = build_prompt_math(question)
        gold_free = get_field(ex, ["answer", "target", "gold", "final_answer", "label"])
        gold_norm = {"freeform": normalize_short_answer(gold_free) if gold_free else None}
        return sys_prompt, usr_prompt, gold_norm, question, None

    if task_name == "BRAINTEASER":
        # Raw HF layout: ex['question']['stem'], ex['question']['choices'], ex['answerKey'].
        question = ex["question"]["stem"]
        choices = ex["question"]["choices"]
        answer_key = ex["answerKey"]
        choices_str = "\n".join([f"{c['label']}: {c['text']}" for c in choices])
        allowed_letters = [choice["label"] for choice in choices]
        letters_str = ", ".join(allowed_letters)
        usr_prompt = (
            "You will be given a multiple-choice question.\n"
            f"Choose the single best answer. Respond with only one of the letters {letters_str}, "
            "and put it in \\boxed{<letter>} at the end.\n\n"
            f"Question:\n{question}\n\nAnswer Options:\n{choices_str}\n"
        )
        gold_norm = {"gold_letter": answer_key, "allowed": allowed_letters, "golden_answer": answer_key}
        return SYSTEM_PROMPT_DEFAULT, usr_prompt, gold_norm, question, choices

    if task_name == "EXPLORETOM":
        problem = ex["question"]
        infilled_story = ex["infilled_story"]
        query_text = f"Here a story: {infilled_story}\nFrom the story, answer this question: {problem}"
        golden_answer = normalize_short_answer(ex["expected_answer"])
        # NOTE(review): "------" is followed directly by the story with no newline. It is
        # left unchanged to keep prompts comparable across runs; confirm whether intended.
        usr_prompt = (
            "You are given a story and a question that can be answered based on the story.\n\n"
            "Provide the single best short answer.\n"
            "Put only the final answer at the end in \\boxed{<answer>} format.\n\n"
            "------"
            f"{query_text}"
        )
        gold_norm = {"freeform": golden_answer, "golden_answer": golden_answer}
        return SYSTEM_PROMPT_DEFAULT, usr_prompt, gold_norm, query_text, None

    if task_name == "MMLU_PRO":
        question = ex["question"]
        ans_choices = ex["options"]
        letter_of = {i: c for i, c in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ")}
        allowed_letters = [letter_of[i] for i, _ in enumerate(ans_choices)]
        answer_options = "".join(f"{letter_of[i]}. {ans}\n" for i, ans in enumerate(ans_choices))
        letters_str = ", ".join(allowed_letters)
        gold_letter = letter_of[ex["answer_index"]]
        usr_prompt = (
            "You will be given a multiple-choice question.\n"
            f"Choose the single best answer. Respond with only one of the letters {letters_str}, "
            "and put it in \\boxed{<letter>} at the end.\n\n"
            f"Question:\n{question}\n\nAnswer options:\n{answer_options}\n"
        )
        gold_norm = {"gold_letter": gold_letter, "allowed": allowed_letters, "golden_answer": gold_letter}
        return SYSTEM_PROMPT_DEFAULT, usr_prompt, gold_norm, question, ans_choices

    raise ValueError(f"Unknown task: {task_name}")


def _score_eval_predictions(pred_texts: List[str], gold_norm: Any, task_name: str) -> bool:
    """Return True if any prediction in ``pred_texts`` matches ``gold_norm`` for ``task_name``."""
    if task_name in ("AIME", "MATH", "GPQA"):
        return pass_at_k_from_predictions(pred_texts, gold_norm, task_name)
    if task_name == "BIGMATH":
        # BIGMATH golds are non-None strings checked exactly like MATH.
        return pass_at_k_from_predictions(pred_texts, gold_norm, "MATH")
    if task_name in ("EXPLORETOM", "BBEH", "MMLU_PRO", "BRAINTEASER"):
        if isinstance(gold_norm, dict) and "gold_letter" in gold_norm:
            allowed = gold_norm["allowed"]
            gold_letter = gold_norm["gold_letter"]
            for pred_txt in pred_texts:
                letter = extract_choice_letter_general(pred_txt, allowed)
                if letter is not None and letter == gold_letter:
                    return True
            return False
        gold_free = (gold_norm or {}).get("freeform")
        if not gold_free:
            return False  # no usable gold: counted as incorrect
        for pred_txt in pred_texts:
            pred = extract_freeform_text_answer(pred_txt)
            if pred and pred == gold_free:
                return True
        return False
    return False


def _pass_at_k_metric(task_name: str, n_eval: int, correct_count: int, k: int) -> Dict[str, Any]:
    """Build the pass@k metric dict (0.0 when nothing has been evaluated yet)."""
    return {
        "task": task_name,
        "num_examples": n_eval,
        "pass_at_k": correct_count / n_eval if n_eval > 0 else 0.0,
        "correct": correct_count,
        "k": k,
    }


def evaluate_dataset(
    model,
    tokenizer,
    dataset,
    task_name: str,
    k: int,
    seeds: List[int],
    num_paths: int,
    shift: int,
    max_path_tokens: int,
    system_prompt_override: Optional[str] = None,
    verbose: bool = False,
    aligned_path: Optional[str] = None,
    metrics_path: Optional[str] = None,
    stream: bool = False,
    max_samples: int = -1,
    simulated_gt: bool = False,
    use_gt_opt_func: bool = False
) -> Dict[str, Any]:
    """Run pass@k evaluation of the group-think generator on one benchmark split.

    For each of the first ``max_samples`` rows (all rows if ``max_samples <= 0``), a
    prompt is built for ``task_name`` and ``k`` samples are generated. Sample ``j``
    is seeded with ``seeds[j]`` (a random seed once ``seeds`` runs out). The example
    is counted correct if any successful sample matches the gold. Rows missing
    required fields are skipped.

    Args:
        model, tokenizer: passed straight to the generation function.
        dataset: indexable split (``len`` and ``dataset[i]`` are used).
        task_name: one of AIME, MATH, GPQA, BBEH, BIGMATH, BRAINTEASER, EXPLORETOM,
            MMLU_PRO.
        k: samples per example.
        seeds: per-sample seeds.
        num_paths, shift, max_path_tokens: group-think generation settings;
            sample ``j`` uses ``shift + 17 * j``.
        system_prompt_override: replaces the task's system prompt when truthy.
        verbose: print the system prompt when it changes, and progress every 25
            examples.
        aligned_path, metrics_path, stream: when ``stream`` is set, append each row
            to ``aligned_path`` (JSONL) and rewrite ``metrics_path`` after every row.
        simulated_gt, use_gt_opt_func: select the generator (simulation wins).

    Returns:
        ``{"metric": {...pass@k summary...}, "details": [per-example rows]}``.
    """
    generate_fn = _pick_generation_fn(simulated_gt, use_gt_opt_func)
    results: List[Dict[str, Any]] = []
    correct_count = 0
    processed = 0
    last_printed_prompt = None
    total = len(dataset)
    if max_samples > 0:
        total = min(total, max_samples)

    for idx in tqdm(range(total), total=total, desc=f"Eval {task_name}"):
        prepared = _prepare_eval_example(dataset[idx], task_name)
        if prepared is None:
            continue
        sys_prompt, usr_prompt, gold_norm, query_text, options = prepared

        system_prompt = system_prompt_override or sys_prompt
        if verbose and system_prompt != last_printed_prompt:
            print(f"System prompt:\n{system_prompt}")
            last_printed_prompt = system_prompt

        # Collect k samples; failed samples are kept as placeholders but not scored.
        pred_texts: List[str] = []
        scorable: List[str] = []
        group_traces_first_sample = None
        for j in range(k):
            seed = seeds[j] if j < len(seeds) else random.randint(0, 10**9)
            set_random_seed(seed)
            try:
                out = generate_fn(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=usr_prompt,
                    system_prompt=system_prompt,
                    num_paths=num_paths,
                    shift=shift + j * 17,  # per-sample offset; the constant 17 is not documented here
                    max_path_tokens=max_path_tokens,
                    verbose=False,
                    step_callback=None,
                )
                text, traces = _decode_group_think_output(out, tokenizer)
                if j == 0 and traces is not None:
                    group_traces_first_sample = traces
                scorable.append(text)
            except Exception as e:
                # Digits/letters inside an error message must never be read as an answer.
                text = f"{_GENERATION_ERROR_PREFIX}{e}>>"
            pred_texts.append(text)

        is_correct = _score_eval_predictions(scorable, gold_norm, task_name)
        correct_count += int(is_correct)
        processed += 1

        row = {
            "index": idx,
            "task": task_name,
            "gold": gold_norm,
            "predictions": pred_texts,
            "correct_at_k": bool(is_correct),
            "k": k,
            "usr_prompt": usr_prompt,
        }
        if task_name in ("AIME", "MATH", "BIGMATH"):
            row["problem"] = query_text
        row["question"] = query_text
        if options:
            row["options"] = options
        if group_traces_first_sample is not None:
            row["group_traces"] = group_traces_first_sample
        results.append(row)

        # Stream aligned JSONL and the running metric.
        if stream and aligned_path is not None:
            _stream_write_jsonl(aligned_path, _build_aligned_row(row))
        if stream and metrics_path is not None:
            _stream_write_metric(metrics_path, _pass_at_k_metric(task_name, len(results), correct_count, k))

        if verbose and processed % 25 == 0:
            print(f"[{task_name}] Processed {processed}/{total} ... Running pass@{k}: {correct_count/processed:.4f}")

    return {
        "metric": _pass_at_k_metric(task_name, len(results), correct_count, k),
        "details": results,
    }


################################################################################
# Hyperparameter saving
################################################################################

def build_hyperparameters_report(args) -> Dict[str, Any]:
    """Assemble the run-configuration report for ``args`` as a nested dict.

    Args:
        args: Parsed CLI namespace as produced by ``parse_args``. The newer
            attributes ``output_dir``, ``datasets`` and
            ``use_gt_optimized_inference`` are optional so namespaces built by
            older callers still work.

    Returns:
        Dict with the sections ``evaluation_config``, ``dataset_paths``,
        ``evaluation_params`` and ``group_think_config``.
    """
    use_optimized = bool(getattr(args, "use_gt_optimized_inference", False))
    # The --use_gt_optimized_inference flag is forwarded to evaluate_dataset as
    # ``use_gt_opt_func``; report the optimized generator when it is set instead
    # of always naming the default one.
    # NOTE(review): the name below is assumed to match the imported
    # group_think_cross_attend_generate_final_optimized — confirm against
    # evaluate_dataset's dispatch.
    generation_function = (
        "group_think_cross_attend_generate_final_optimized"
        if use_optimized
        else "group_think_cross_attend_generate"
    )
    return {
        "evaluation_config": {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "model_path": args.model_path,
            "dtype": args.dtype,
            "device_map": args.device_map,
            "system_prompt_override": args.system_prompt,
            "verbose": args.verbose,
            "output_dir": getattr(args, "output_dir", None),
        },
        "dataset_paths": {
            "math500_path": args.math500_path,
            "aime24_path": args.aime24_path,
            "aime25_path": args.aime25_path,
            "gpqa_diamond_path": args.gpqa_diamond_path,
            "bb_extra_hard_path": args.bb_extra_hard_path,
            "bigmath_path": args.bigmath_path,
            "brainteaser_path": args.brainteaser_path,
            "explore_tom_path": args.explore_tom_path,
            "mmlu_pro_path": args.mmlu_pro_path,
        },
        "evaluation_params": {
            "k": args.k,
            "seeds": args.seeds,
            "num_paths": args.num_paths,
            "shift": args.shift,
            "max_path_tokens": args.max_path_tokens,
            "max_samples": args.max_samples,
            "datasets": getattr(args, "datasets", None),
            "use_gt_optimized_inference": use_optimized,
        },
        "group_think_config": {
            "num_paths": args.num_paths,
            "shift": args.shift,
            "max_path_tokens": args.max_path_tokens,
            "generation_function": generation_function,
        },
    }


def save_hyperparameters_report(args, output_dir: str):
    """Write the run configuration to ``<output_dir>/report.yaml``.

    Args:
        args: Parsed CLI namespace (see ``build_hyperparameters_report``).
        output_dir: Directory for the report; created if it does not exist.

    Returns:
        Path of the written YAML file.
    """
    report = build_hyperparameters_report(args)
    os.makedirs(output_dir, exist_ok=True)
    report_path = os.path.join(output_dir, "report.yaml")
    with open(report_path, "w", encoding="utf-8") as f:
        yaml.dump(report, f, default_flow_style=False, indent=2, allow_unicode=True)
    print(f"Hyperparameters saved to: {report_path}")
    return report_path


################################################################################
# CLI and main
################################################################################

def parse_args():
    """Define the evaluation CLI and parse it from ``sys.argv``.

    Returns:
        argparse.Namespace holding the model path, the nine mandatory local
        dataset paths, pass@k / seed settings, dtype and device placement,
        group-think generation knobs, and output / selection options.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate Qwen2.5-32B fine-tuned model on multiple benchmarks with pass@k."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path - base model or finetuned checkpoint directory.",
    )

    # Every benchmark is read from a local path, and each one is mandatory.
    required_dataset_paths = (
        ("--math500_path", "Local path to MATH500 dataset (saved dataset dir or JSON/JSONL)."),
        ("--aime24_path", "Local path to AIME 2024 dataset."),
        ("--aime25_path", "Local path to AIME 2025 dataset."),
        ("--gpqa_diamond_path", "Local path to GPQA-Diamond dataset."),
        ("--bb_extra_hard_path", "Local path to BIG Bench Extra Hard dataset."),
        ("--bigmath_path", "Local path to Big Math dataset."),
        ("--brainteaser_path", "Local path to BRAINTEASER dataset."),
        ("--explore_tom_path", "Local path to Explore TOM dataset."),
        ("--mmlu_pro_path", "Local path to MMLU-Pro dataset."),
    )
    for flag, help_text in required_dataset_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Sampling setup for pass@k.
    parser.add_argument("--k", type=int, default=1,
                        help="pass@k (number of independent generations per example).")
    parser.add_argument("--seeds", type=int, nargs="*", default=[42, 43, 44],
                        help="Seeds for independent generations (length >= k recommended).")

    # Model loading.
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16"], help="Model dtype.")
    parser.add_argument("--device_map", type=str, default="auto",
                        help="Transformers device_map (e.g., 'auto').")

    # Integer knobs forwarded to the group-think generator.
    generation_knobs = (
        ("--num_paths", 4, "num_paths for group_think_cross_attend_generate."),
        ("--shift", 3000, "shift parameter for group_think_cross_attend_generate."),
        ("--max_path_tokens", 256, "max_path_tokens for group_think_cross_attend_generate."),
    )
    for flag, default_value, help_text in generation_knobs:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    # Prompting, logging, output and dataset-selection options.
    parser.add_argument("--system_prompt", type=str, default=None,
                        help="Override system prompt string (optional).")
    parser.add_argument("--verbose", action="store_true",
                        help="Print progress per dataset.")
    parser.add_argument("--output_dir", type=str, default="eval_outputs",
                        help="Directory to store detailed outputs.")
    parser.add_argument(
        "--datasets",
        type=str,
        default="all",
        help="Comma-separated subset to evaluate: math500,aime24,aime25,gpqa_diamond,bb_extra_hard,bigmath,brainteaser,explore_tom,mmlu_pro or 'all'.",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=-1,
        help="Maximum number of samples to evaluate per dataset. -1 means evaluate all samples.",
    )
    parser.add_argument("--use_gt_optimized_inference", action="store_true",
                        help="Use the optimized group think inference function.")

    namespace = parser.parse_args()
    return namespace


def main():
    """Run every selected benchmark and write per-dataset and summary results.

    Pipeline:
      1. Parse CLI args and save ``report.yaml`` into ``--output_dir``.
      2. Load the model/tokenizer and all nine local datasets, then build the
         evaluation plan with ``_select_eval_plan``.
      3. For each ``(task_name, dataset, outfile)`` in the plan, run
         ``evaluate_dataset`` and write into ``--output_dir``:
           * ``{outfile}_group_think_eval.jsonl`` (truncated here, then streamed
             to by ``evaluate_dataset``),
           * ``{outfile}_details.jsonl`` (one JSON object per evaluated example),
           * ``{outfile}_metrics.json`` (interim metrics streamed during the run,
             overwritten at the end with the final metrics plus ``elapsed_sec``).
      4. Keep ``summary.json`` (metrics of every finished dataset, keyed by
         ``outfile``) up to date and print a one-line summary per dataset.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("Saving hyperparameters report...")
    save_hyperparameters_report(args, args.output_dir)

    print("Loading model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(
        model_path=args.model_path,
        dtype=args.dtype,
        device_map=args.device_map,
    )

    print("Loading datasets from local paths...")
    # NOTE(review): all nine datasets are loaded even when --datasets selects a
    # subset, because _select_eval_plan receives every one of them.
    ds_math500 = load_local_dataset(args.math500_path)
    ds_aime24  = load_local_dataset(args.aime24_path)
    ds_aime25  = load_local_dataset(args.aime25_path)
    ds_gpqa    = load_local_dataset(args.gpqa_diamond_path)
    ds_bbeh    = load_local_dataset(args.bb_extra_hard_path)
    ds_bigmath = load_local_dataset(args.bigmath_path)
    ds_brainteaser = load_local_dataset(args.brainteaser_path)
    ds_explore_tom = load_local_dataset(args.explore_tom_path)
    ds_mmlu_pro = load_local_dataset(args.mmlu_pro_path)

    eval_plan = _select_eval_plan(args, ds_math500, ds_aime24, ds_aime25, ds_gpqa,
                                  ds_bbeh, ds_bigmath, ds_brainteaser, ds_explore_tom, ds_mmlu_pro)

    all_metrics = {}
    summaries = []
    combined_path = os.path.join(args.output_dir, "summary.json")

    def _write_summary():
        """Overwrite summary.json with the metrics collected so far."""
        with open(combined_path, "w", encoding="utf-8") as f:
            json.dump(all_metrics, f, ensure_ascii=False, indent=2)

    for task_name, ds, outfile in eval_plan:
        print(f"Evaluating {outfile} with pass@{args.k} ...")
        start = time.time()
        aligned_path = os.path.join(args.output_dir, f"{outfile}_group_think_eval.jsonl")
        # Start from an empty file because evaluate_dataset streams rows into it;
        # presumably the streaming writer appends, so a leftover file from an
        # earlier run would otherwise be mixed in — TODO confirm.
        with open(aligned_path, "w", encoding="utf-8"):
            pass
        # Computed once: evaluate_dataset streams interim metrics to this path,
        # and the final metrics below overwrite the same file.
        metrics_path = os.path.join(args.output_dir, f"{outfile}_metrics.json")
        out = evaluate_dataset(
            model=model,
            tokenizer=tokenizer,
            dataset=ds,
            task_name=task_name,
            k=args.k,
            seeds=args.seeds,
            num_paths=args.num_paths,
            shift=args.shift,
            max_path_tokens=args.max_path_tokens,
            system_prompt_override=args.system_prompt,
            verbose=args.verbose,
            aligned_path=aligned_path,
            metrics_path=metrics_path,
            stream=True,
            max_samples=args.max_samples,
            simulated_gt=False,
            use_gt_opt_func=args.use_gt_optimized_inference
        )
        elapsed = time.time() - start

        details_path = os.path.join(args.output_dir, f"{outfile}_details.jsonl")
        with open(details_path, "w", encoding="utf-8") as f:
            for row in out["details"]:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

        metrics = out["metric"]
        metrics["elapsed_sec"] = elapsed
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, ensure_ascii=False, indent=2)

        all_metrics[outfile] = metrics
        # Persist after each dataset so a failure in a later dataset does not
        # lose the summary of the ones already finished.
        _write_summary()
        summaries.append((outfile, metrics["pass_at_k"], metrics["num_examples"], elapsed))
        print(f"{outfile}: pass@{args.k} = {metrics['pass_at_k']:.4f} ({metrics['correct']}/{metrics['num_examples']}), time={elapsed:.1f}s")

    # Final write also covers an empty evaluation plan (summary.json == {}).
    _write_summary()

    print("Done. Summary:")
    for name, score, n, t in summaries:
        print(f"- {name}: pass@{args.k}={score:.4f} (N={n}), time={t:.1f}s")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
