#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
test_prob_vllm_clean.py

Features:
- Unified CLI for AIME(2024/2025), LiveCodeBench (LCB), GPQA, MATH-500, Hi-ToM
- Resume-safe (appends only missing completions)
- Per-temperature generation (loop temps)
- One completion per prompt invocation (repeat prompt N times instead of n>1)
- Mean token entropy logged per completion; token logprobs + per-position entropies optional (--save_logprobs)
- Optional LCB code execution judging
- Minimal dependencies: vllm, transformers, datasets, tqdm, aiofiles, numpy

Usage examples:
  AIME:
    python test_prob_vllm_clean.py \
      --dataset aime --aime_version 2025 \
      --model Qwen/Qwen2.5-7B-Instruct \
      --temperatures 0.2 0.4 0.6 \
      --num_samples 32 \
      --output_dir results_aime

  LCB (with correctness evaluation):
    python test_prob_vllm_clean.py \
      --dataset lcb --lcb_jsonl lcb_v6_with_prompts.jsonl \
      --model Qwen/Qwen2.5-7B-Instruct \
      --temperatures 0.8 \
      --num_samples 8 \
      --evaluate_lcb \
      --output lcb_temp08.jsonl

  GPQA:
    python test_prob_vllm_clean.py \
      --dataset gpqa --gpqa_jsonl gpqa_dataset.jsonl \
      --model Qwen/Qwen2.5-7B-Instruct \
      --temperatures 0.7 \
      --num_samples 16 \
      --output gpqa_temp07.jsonl
"""

from __future__ import annotations
import os
import re
import json
import math
import argparse
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Iterable

import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Local modules (included in clean package)
from dataset_loader import (
    load_aime_dataset,
    load_lcb_dataset,
    load_gpqa_dataset,
    load_math500_dataset,
    load_hitom_dataset,
    AIMEItem, LCBItem, GPQAItem, Math500Item, HiToMItem
)
from lcb_evaluator import evaluate_lcb_response  # safe to import; optional run

# --------------------------------------------------------------------
# Prompt / answer extraction helpers
# --------------------------------------------------------------------

# AIME prompt template. The literal "{Question}" marker is filled in with
# str.replace (see construct_prompt), not str.format, so the "{}" after
# \boxed is emitted verbatim.
PROMPT_TMPL_AIME = "Please reason step by step, and put your final answer within \\boxed{}\n\n{Question}"
# Captures the content of a \boxed{...} span up to the FIRST closing brace;
# content that itself contains braces is therefore cut at its first "}".
ANSWER_BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}")

def now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    current_utc = datetime.now(tz=timezone.utc)
    return current_utc.isoformat()

def extract_boxed_answer(text: str) -> Optional[str]:
    """Return the content of the last well-formed ``\\boxed{...}`` in *text*.

    Braces are matched by depth, so nested LaTeX groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole instead of being cut at the
    first closing brace.

    Args:
        text: Model response (or any string) to search.

    Returns:
        The stripped content of the last balanced ``\\boxed{...}`` span, or
        ``None`` when no balanced span exists.
    """
    marker = "\\boxed{"
    answer: Optional[str] = None
    search_from = 0
    while True:
        start = text.find(marker, search_from)
        if start == -1:
            break
        body_start = start + len(marker)
        depth = 1
        pos = body_start
        # Walk forward until the brace opened by the marker is closed.
        while pos < len(text) and depth:
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            answer = text[body_start:pos - 1].strip()
            # Resume after this span so matches never overlap.
            search_from = pos
        else:
            # Unbalanced span: ignore it but still look for later/inner ones.
            search_from = body_start
    return answer

# Matches a run of digits with an optional leading "-"; it has no notion of
# decimal points or thousands separators (each digit run is a separate match).
INT_RE = re.compile(r"-?\d+")

def normalize_int_str(s: Optional[str]) -> Optional[str]:
    """Canonicalize an answer string for comparison.

    Removes every whitespace character and then any degree marker
    (``^\\circ``, ``^{\\circ}``, ``\\degree``, ``\\deg``, ``°``, ``degree(s)``,
    matched case-insensitively). ``None`` is passed through unchanged.
    """
    if s is None:
        return None
    degree_markers = r"(?i)(\^\s*\{?\\circ\}?|\\degree|\\deg|°|degrees?)"
    without_space = re.sub(r"\s+", "", s)
    return re.sub(degree_markers, "", without_space)

def integer_answer_from_text(text: str) -> Optional[str]:
    """Pull an integer-like answer out of a model response.

    Priority:
      1. A non-empty boxed answer: its last integer if it contains one,
         otherwise the stripped boxed content itself.
      2. The last integer anywhere in *text*.

    Returns ``None`` when neither is available.
    """
    boxed = extract_boxed_answer(text)
    if boxed:
        boxed_ints = INT_RE.findall(boxed)
        # The boxed content may wrap the integer in extra LaTeX/text.
        return boxed_ints[-1] if boxed_ints else boxed.strip()

    all_ints = INT_RE.findall(text)
    return all_ints[-1] if all_ints else None

def is_correct_aime(gold: str, pred: Optional[str]) -> bool:
    """Decide whether *pred* matches the AIME gold answer.

    Both strings are normalized with ``normalize_int_str``. If the normalized
    strings differ, they are additionally compared by integer value so that
    zero-padded answers (e.g. ``"073"`` vs ``"73"``) are accepted.

    Args:
        gold: Reference answer.
        pred: Extracted prediction, or ``None`` when nothing was extracted.

    Returns:
        ``True`` on a match, ``False`` otherwise (including ``pred is None``).
    """
    if pred is None:
        return False
    gold_norm = normalize_int_str(gold)
    pred_norm = normalize_int_str(pred)
    if gold_norm == pred_norm:
        return True
    try:
        # Fall back to numeric equality; non-integer strings simply fail here.
        return int(gold_norm) == int(pred_norm)
    except (TypeError, ValueError):
        return False

def gpqa_extract_choice(text: str) -> Optional[str]:
    """Extract the chosen option letter (A-D) from a model response.

    Candidates are tried from most to least explicit; within a tier the last
    occurrence wins:

      1. A boxed letter: ``\\boxed{B}`` / ``\\boxed{\\text{B}}``.
      2. A stated answer: ``answer is B``, ``Answer: (B)``, ``**Answer:** B``
         (the word "answer" is case-insensitive, the letter must be
         upper-case so the article "a" is not mistaken for option A).
      3. Any standalone upper-case letter A-D.
      4. Any standalone letter a-d regardless of case (the original,
         most permissive behaviour).

    Args:
        text: Model response.

    Returns:
        The upper-case option letter, or ``None`` if no candidate is found.
    """
    boxed = re.findall(r"\\boxed\{\s*(?:\\text\{\s*)?\(?([A-D])\)?\s*\}", text)
    if boxed:
        return boxed[-1]

    stated = re.findall(r"(?i:answer)(?:\s+(?i:is))?[\s:*]*\(?([A-D])\b", text)
    if stated:
        return stated[-1]

    upper_case = re.findall(r"\b([A-D])\b", text)
    if upper_case:
        return upper_case[-1]

    # Last resort: case-insensitive standalone letter.
    any_case = re.findall(r"\b([A-D])\b", text.upper())
    return any_case[-1] if any_case else None

def is_correct_gpqa(pred: Optional[str], gold: str) -> bool:
    """Return True when *pred* equals *gold*, ignoring case and outer whitespace.

    A missing prediction (``None``) is always incorrect.
    """
    return pred is not None and pred.strip().upper() == gold.strip().upper()

def simple_norm(s: str) -> str:
    """Trim *s*, collapse each whitespace run to one space, and lower-case it."""
    trimmed = s.strip()
    single_spaced = re.sub(r"\s+", " ", trimmed)
    return single_spaced.lower()

def is_correct_math(pred: Optional[str], gold: str) -> bool:
    """Compare a MATH prediction with the gold answer after light normalization.

    If *pred* still contains a non-empty boxed answer, that boxed content is
    compared instead of the whole string. Comparison uses ``simple_norm``
    (whitespace-collapsed, case-insensitive exact match). ``None`` is incorrect.
    """
    if pred is None:
        return False
    candidate = extract_boxed_answer(pred) or pred
    return simple_norm(candidate) == simple_norm(gold)

def is_correct_hitom(pred: Optional[str], gold: str) -> bool:
    """Return True when *pred* and *gold* are equal under ``simple_norm``.

    A missing prediction (``None``) is always incorrect.
    """
    return pred is not None and simple_norm(pred) == simple_norm(gold)

def construct_prompt(question_or_prompt: str,
                     model_name: str,
                     dataset_type: str,
                     thinking: bool=False) -> str:
    """
    Build the user-message text for one problem.

    - ``dataset_type == "aime"``: the question is substituted into the boxed
      AIME template; every other dataset uses the text as given.
    - When ``model_name`` contains "Qwen" and ``thinking`` is False, the
      literal suffix "/no_think" is appended (no separator is inserted).

    The returned string is plain text; applying a chat template is left to
    the caller.
    """
    is_aime = dataset_type == "aime"
    core = (
        PROMPT_TMPL_AIME.replace("{Question}", question_or_prompt)
        if is_aime
        else question_or_prompt
    )

    # Simple model-family heuristic: only Qwen-named models get the suffix.
    if "Qwen" not in model_name or thinking:
        return core
    return core + "/no_think"

# --------------------------------------------------------------------
# Logprobs & entropy
# --------------------------------------------------------------------

def to_prob(logp: Optional[float]) -> float:
    """Convert a natural-log probability to a probability.

    Returns NaN when *logp* is ``None`` or when ``math.exp`` raises
    (e.g. overflow or a non-numeric value).
    """
    nan = float("nan")
    if logp is None:
        return nan
    try:
        prob = math.exp(logp)
    except Exception:
        return nan
    return prob

def pack_token_logprobs_vllm(output) -> Tuple[str, List[Dict[str, Any]]]:
    """
    Extract chosen text + token-level logprobs (top-k) from a vLLM sample.

    Args:
        output: A vLLM completion-like object. ``text`` (or
            ``outputs[0].text``), ``logprobs`` and ``token_ids`` are read
            when present.

    Returns:
        ``(text, tokens_info)`` where ``text`` is the stripped completion and
        ``tokens_info`` has one dict per generated position with keys
        ``token``, ``logprob``, ``prob`` (the generated token) and
        ``top_logprobs`` (all reported candidates, highest logprob first).
        ``tokens_info`` is empty when no logprobs were returned.

    The generated token is identified through ``output.token_ids``: vLLM keys
    each position's logprob dict by token id, and the rank-1 entry is merely
    the most likely candidate, which differs from the sampled token whenever
    sampling picks something else. The rank-1 entry is only used as a
    fallback when the sampled id cannot be resolved.
    """
    # Chosen text
    if hasattr(output, "text"):
        text = output.text
    elif hasattr(output, "outputs") and output.outputs:
        text = output.outputs[0].text
    else:
        text = ""
    text = text.strip()

    tokens_info: List[Dict[str, Any]] = []
    position_logprobs = getattr(output, "logprobs", None)
    if position_logprobs is None:
        return text, tokens_info

    sampled_ids = list(getattr(output, "token_ids", None) or [])

    for pos, position_dict in enumerate(position_logprobs):
        if not isinstance(position_dict, dict):
            # Keep positions aligned even when an entry is missing/malformed.
            tokens_info.append({
                "token": "",
                "logprob": None,
                "prob": None,
                "top_logprobs": []
            })
            continue

        top_alts = []
        rank_one = None
        for lp_obj in position_dict.values():
            logprob_value = getattr(lp_obj, "logprob", None)
            top_alts.append({
                "token": getattr(lp_obj, "decoded_token", ""),
                "logprob": logprob_value,
                "prob": to_prob(logprob_value) if logprob_value is not None else None
            })
            if getattr(lp_obj, "rank", None) == 1:
                rank_one = lp_obj

        # Prefer the token that was actually generated at this position.
        chosen = position_dict.get(sampled_ids[pos]) if pos < len(sampled_ids) else None
        if chosen is None:
            chosen = rank_one

        actual_token = getattr(chosen, "decoded_token", "") if chosen is not None else None
        actual_logprob = getattr(chosen, "logprob", None) if chosen is not None else None

        # Entries without a logprob sort last.
        top_alts.sort(key=lambda x: (x["logprob"] if x["logprob"] is not None else -1e9), reverse=True)
        tokens_info.append({
            "token": actual_token or "",
            "logprob": actual_logprob,
            "prob": to_prob(actual_logprob) if actual_logprob is not None else None,
            "top_logprobs": top_alts
        })
    return text, tokens_info

def calculate_entropy(tokens_info: List[Dict[str, Any]]) -> Tuple[float, List[float]]:
    """Compute Shannon entropy (in nats) over the reported candidates per token.

    For each entry, the non-negative numeric ``prob`` values in
    ``top_logprobs`` are renormalized to sum to 1 and their entropy is taken.
    Entries with no usable probabilities are skipped, so the returned list can
    be shorter than *tokens_info*.

    Returns:
        ``(mean_entropy, per_position_entropies)``; ``(nan, [])`` when nothing
        could be computed.
    """
    nan = float("nan")
    if not tokens_info:
        return nan, []

    per_position: List[float] = []
    for entry in tokens_info:
        usable: List[float] = []
        for alt in entry.get("top_logprobs", []):
            p = alt.get("prob")
            # NaN fails the ">= 0" comparison and is dropped here as well.
            if isinstance(p, (int, float)) and p is not None and p >= 0:
                usable.append(p)

        total = sum(usable) if usable else 0
        if total <= 0:
            continue

        normalized = [p / total for p in usable if p > 0]
        if not normalized:
            continue

        per_position.append(-sum(q * math.log(q) for q in normalized))

    if not per_position:
        return nan, []
    return float(np.mean(per_position)), per_position

# --------------------------------------------------------------------
# Resume logic
# --------------------------------------------------------------------

def count_done_by_key(path: str, dataset: str) -> Dict[Tuple, int]:
    """Count the completions already stored in a JSONL results file.

    Args:
        path: Results file to scan. A missing file yields empty counts.
        dataset: Dataset name; selects which record fields form the key
            (``(subset, item_id)`` for aime, ``(question_id, 0)`` for lcb,
            ``(id, 0)`` for gpqa/math500/hitom). Any other name yields
            empty counts.

    Returns:
        A ``defaultdict(int)`` mapping problem key -> number of records found.

    Blank lines, lines that are not valid JSON (e.g. a truncated final line
    from an interrupted run) and JSON values that are not objects are skipped
    rather than aborting the resume scan.
    """
    counts: Dict[Tuple, int] = defaultdict(int)
    if not os.path.exists(path):
        return counts
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError subclass.
                continue
            if not isinstance(obj, dict):
                # e.g. a bare number or list: has no record fields to key on.
                continue
            if dataset == "aime":
                key = (obj.get("subset"), obj.get("item_id"))
            elif dataset == "lcb":
                key = (obj.get("question_id"), 0)
            elif dataset in ("gpqa", "math500", "hitom"):
                key = (obj.get("id"), 0)
            else:
                continue
            counts[key] += 1
    return counts

# --------------------------------------------------------------------
# Generation core
# --------------------------------------------------------------------

def evaluate_dataset_for_temperature(
    *,
    dataset: str,
    items: List[Any],
    model: LLM,
    tokenizer,
    model_name: str,
    temperature: float,
    num_samples: int,
    output_path: str,
    batch_size: int,
    top_k: int,
    max_tokens: int,
    save_logprobs: bool,
    evaluate_lcb: bool,
    resume_counts: Dict[Tuple, int],
    dataset_specific_paths: Dict[str, str],
    thinking: bool=False
) -> None:
    """
    Generate and score the missing completions for a single temperature.

    Steps:
      - Filter items that still need completions (``num_samples`` minus the
        count in ``resume_counts``)
      - Build one prompt per missing completion (prompt repeated, n=1)
      - Generate in batches of ``batch_size``
      - Append one JSON record per completion to ``output_path``, flushing
        after every batch

    Args:
        dataset: One of "aime", "lcb", "gpqa", "math500", "hitom".
        items: Loaded dataset items.
        model: vLLM engine used for generation.
        tokenizer: Tokenizer whose chat template is applied for Qwen-named models.
        model_name: Model identifier; drives the Qwen heuristic.
        temperature: Sampling temperature for this pass.
        num_samples: Target number of completions per problem.
        output_path: JSONL file opened in append mode.
        batch_size: Number of prompts per ``model.generate`` call.
        top_k: Number of logprobs requested per position (used for entropy).
        max_tokens: Generation length limit.
        save_logprobs: Also store per-token info and per-position entropies.
        evaluate_lcb: Run the LCB judge on each LCB completion.
        resume_counts: Existing completion counts keyed like ``problem_key``.
        dataset_specific_paths: Accepted for interface compatibility; unused here.
        thinking: Forwarded to ``construct_prompt``; when True the
            "/no_think" suffix is not added for Qwen-named models.
    """
    def problem_key(item: Any) -> Tuple:
        """Resume key for an item; must match the keys in ``resume_counts``."""
        if dataset == "aime":
            return (item.subset, item.item_id)
        elif dataset == "lcb":
            return (item.question_id, 0)
        elif dataset in ("gpqa", "math500", "hitom"):
            return (item.id, 0)
        else:
            return ("unknown", 0)

    def identity_fields(item: Any) -> Dict[str, Any]:
        """Dataset-specific fields copied verbatim into every output record."""
        if dataset == "aime":
            return {
                "subset": item.subset,
                "item_id": item.item_id,
                "global_id": item.global_id,
                "gold_answer": item.answer
            }
        if dataset == "lcb":
            return {
                "question_id": item.question_id,
                "platform": item.platform,
                "difficulty": item.difficulty
            }
        if dataset == "gpqa":
            return {"id": item.id, "correct_answer": item.correct_answer}
        if dataset in ("math500", "hitom"):
            return {"id": item.id, "answer": item.answer}
        return {}

    jobs: List[Tuple[Any, int]] = []
    for it in items:
        need = max(0, num_samples - resume_counts.get(problem_key(it), 0))
        if need > 0:
            jobs.append((it, need))

    if not jobs:
        print(f"[{dataset.upper()}][T={temperature:.2f}] All completions already present -> skip")
        return

    print(f"[{dataset.upper()}][T={temperature:.2f}] Need {sum(n for _, n in jobs)} new completions across {len(jobs)} problems")

    prompts: List[str] = []
    meta: List[Dict[str, Any]] = []
    for item, need in jobs:
        # Forward `thinking`; otherwise "/no_think" would always be appended.
        base_prompt_text = construct_prompt(
            getattr(item, "question", getattr(item, "prompt", "")),
            model_name,
            dataset,
            thinking=thinking
        )
        if "Qwen" in model_name:
            # Qwen-style chat: wrap the user text with the chat template.
            messages = [{"role": "user", "content": base_prompt_text}]
            prompt_rendered = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_rendered = base_prompt_text

        fields = identity_fields(item)
        for _ in range(need):
            prompts.append(prompt_rendered)
            # Keep the source item so the LCB judge receives the real problem.
            meta.append({"timestamp": now_iso(), "fields": fields, "item": item})

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=1.0,
        max_tokens=max_tokens,
        n=1,
        logprobs=top_k
    )

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    total = len(prompts)
    # The context manager closes the file even if generation/judging raises.
    with open(output_path, "a", encoding="utf-8") as out_f:
        for start in tqdm(range(0, total, batch_size), desc=f"{dataset.upper()} T={temperature:.2f}"):
            batch_prompts = prompts[start:start + batch_size]
            batch_meta = meta[start:start + batch_size]
            results = model.generate(batch_prompts, sampling_params)

            for res, m in zip(results, batch_meta):
                out_record: Dict[str, Any] = {
                    "dataset": dataset,
                    "temperature": temperature,
                    "timestamp": m["timestamp"],
                }
                out_record.update(m["fields"])

                # Generation output
                gen = res.outputs[0]
                text, tokens_info = pack_token_logprobs_vllm(gen)
                out_record["response_text"] = text

                # Correctness
                if dataset == "aime":
                    pred_ans = integer_answer_from_text(text)
                    out_record["extracted_answer"] = pred_ans
                    out_record["is_correct"] = is_correct_aime(out_record["gold_answer"], pred_ans)
                elif dataset == "gpqa":
                    pred_choice = gpqa_extract_choice(text)
                    out_record["extracted_answer"] = pred_choice
                    out_record["is_correct"] = is_correct_gpqa(pred_choice, out_record["correct_answer"])
                elif dataset == "math500":
                    # Minimal exact normalized match (boxed if present)
                    pred_math = extract_boxed_answer(text) or text.strip()
                    out_record["extracted_answer"] = pred_math
                    out_record["is_correct"] = is_correct_math(pred_math, out_record["answer"])
                elif dataset == "hitom":
                    pred_short = extract_boxed_answer(text) or text.strip()
                    out_record["extracted_answer"] = pred_short
                    out_record["is_correct"] = is_correct_hitom(pred_short, out_record["answer"])
                elif dataset == "lcb":
                    if evaluate_lcb:
                        # Judge against the loaded item (with its own test
                        # cases) rather than a stub with empty test lists.
                        try:
                            eval_res = evaluate_lcb_response(text, m["item"], debug=False)
                            out_record["is_correct"] = bool(eval_res.get("all_passed", False))
                        except Exception as exc:
                            # Best effort: a judge failure must not abort the
                            # run, but leave a trace of what went wrong.
                            out_record["is_correct"] = False
                            out_record["eval_error"] = f"{type(exc).__name__}: {exc}"
                    else:
                        out_record["is_correct"] = False  # correctness deferred / external

                # Entropy
                avg_ent, pos_ents = calculate_entropy(tokens_info)
                out_record["entropy"] = avg_ent
                if save_logprobs:
                    out_record["position_entropies"] = pos_ents
                    out_record["tokens"] = tokens_info

                out_f.write(json.dumps(out_record, ensure_ascii=False) + "\n")

            # Flush per batch so an interrupted run can resume from disk.
            out_f.flush()

    print(f"[{dataset.upper()}][T={temperature:.2f}] Completed -> {output_path}")

# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Clean multi-dataset evaluator.")

    # Dataset selection and per-dataset input files.
    parser.add_argument(
        "--dataset", type=str, required=True,
        choices=["aime", "lcb", "gpqa", "math500", "hitom"],
        help="Dataset to evaluate.",
    )
    parser.add_argument(
        "--aime_version", type=str, default="2025",
        choices=["2024", "2025"],
        help="AIME version (when dataset=aime).",
    )
    jsonl_defaults = (
        ("--lcb_jsonl", "lcb_v6_with_prompts.jsonl"),
        ("--gpqa_jsonl", "gpqa_dataset.jsonl"),
        ("--math500_jsonl", "math500_level5.jsonl"),
        ("--hitom_jsonl", "hitom_dataset.jsonl"),
    )
    for flag, default_path in jsonl_defaults:
        parser.add_argument(flag, type=str, default=default_path)

    # Model.
    parser.add_argument("--model", type=str, required=True, help="HF repo ID or local path.")
    parser.add_argument("--cache_dir", type=str, default=None,
                        help="HF cache / model download dir (optional).")

    # Sampling.
    parser.add_argument(
        "--temperatures", type=float, nargs="+", required=True,
        help="List of temperatures to evaluate (e.g. 0.2 0.4 0.6).",
    )
    int_options = (
        ("--num_samples", 32, "Per-problem number of completions per temperature."),
        ("--batch_size", 16, "Batch size per generate() call."),
        ("--max_tokens", 1024, "Max new tokens."),
        ("--top_k", 20, "Top-k logprobs to request for entropy."),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    # Boolean switches.
    switches = (
        ("--save_logprobs", "Store token logprobs & per-position entropies."),
        ("--thinking", "Do not append '/no_think' for Qwen-like models."),
        ("--evaluate_lcb", "Run code execution judge for LCB outputs."),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    # Output destinations.
    parser.add_argument("--output", type=str, default=None,
                        help="Single output JSONL (only if one temperature).")
    parser.add_argument(
        "--output_dir", type=str, default=None,
        help="Directory for per-temperature outputs (mutually exclusive with --output unless one temp).",
    )
    return parser.parse_args()

def main():
    """CLI entry point: validate options, load data and model, run each temperature.

    Raises:
        ValueError: for conflicting/missing output options or an unsupported dataset.
    """
    args = parse_args()
    temps = args.temperatures
    several_temps = len(temps) > 1

    # Output option validation.
    if several_temps and args.output and args.output_dir:
        raise ValueError("Use either --output (single temperature) or --output_dir (multi), not both.")
    if several_temps and not args.output_dir:
        raise ValueError("For multiple temperatures, please provide --output_dir.")
    if len(temps) == 1 and not (args.output or args.output_dir):
        # Default single-file name
        args.output = f"{args.dataset}_temp{temps[0]:.2f}_results.jsonl"

    # Load dataset items via a per-dataset loader table.
    loaders = {
        "aime": lambda: load_aime_dataset(args.aime_version),
        "lcb": lambda: load_lcb_dataset(args.lcb_jsonl),
        "gpqa": lambda: load_gpqa_dataset(args.gpqa_jsonl),
        "math500": lambda: load_math500_dataset(args.math500_jsonl),
        "hitom": lambda: load_hitom_dataset(args.hitom_jsonl),
    }
    load_items = loaders.get(args.dataset)
    if load_items is None:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    items = load_items()

    print(f"[INFO] Loaded {len(items)} {args.dataset} items")

    # Initialize model + tokenizer
    print(f"[INFO] Initializing model: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        enforce_eager=True,
        download_dir=args.cache_dir
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=True, cache_dir=args.cache_dir
    )

    for t in temps:
        # --output_dir takes precedence: one results file per temperature.
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)
            results_name = f"{args.dataset}_temp{t:.2f}_results.jsonl"
            out_path = os.path.join(args.output_dir, results_name)
        else:
            out_path = args.output

        # Resume counts
        resume_counts = count_done_by_key(out_path, args.dataset)
        print(f"[INFO][T={t:.2f}] Found {len(resume_counts)} keys with existing completions (resume-safe)")

        input_paths = {
            "lcb_jsonl": args.lcb_jsonl,
            "gpqa_jsonl": args.gpqa_jsonl,
            "math500_jsonl": args.math500_jsonl,
            "hitom_jsonl": args.hitom_jsonl
        }
        evaluate_dataset_for_temperature(
            dataset=args.dataset,
            items=items,
            model=llm,
            tokenizer=tokenizer,
            model_name=args.model,
            temperature=t,
            num_samples=args.num_samples,
            output_path=out_path,
            batch_size=args.batch_size,
            top_k=args.top_k,
            max_tokens=args.max_tokens,
            save_logprobs=args.save_logprobs,
            evaluate_lcb=args.evaluate_lcb,
            resume_counts=resume_counts,
            dataset_specific_paths=input_paths,
            thinking=args.thinking
        )

    print("[DONE] All temperatures processed.")

if __name__ == "__main__":
    main()
