#!/usr/bin/env python3
"""
Compute pass@k and cons@k (majority voting) from lm_eval --log_samples output.
Scores all responses in the 'resps' field (not just the filtered first one).

NOTE: lm_eval with repeats=N only scores the first response (take_first filter).
This script properly scores ALL N responses for accurate metrics.

Scoring: automatically resolves the task's scorer from evaluate/custom_tasks/<task>/utils.py
when available (e.g., OlymMATH uses math_verify for symbolic equivalence). Falls back to
simple string matching for tasks without a custom scorer.

Parallel scoring: use --workers N to parallelize math_verify calls across N processes.
Each worker imports the scorer independently, avoiding GIL limitations.

Usage:
    python scripts/evals/compute_pass_at_k.py results/pass_at_k/
    python scripts/evals/compute_pass_at_k.py results/pass_at_k/ --k_values 1,2,4,8,16,32,64
    python scripts/evals/compute_pass_at_k.py results/pass_at_k/ --cons_k 8,64 --json_output summary.json
    python scripts/evals/compute_pass_at_k.py results/pass_at_k/ --workers 8  # parallel scoring
"""

import json
import glob
import math
import sys
import re
import argparse
import importlib
import importlib.util
import os
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path


# ---------------------------------------------------------------------------
# Task scorer resolution — import _score_single from lm_eval task utils
# ---------------------------------------------------------------------------

# Map task-name keys to their utils.py file path (relative to project root).
# NOTE: keys are matched as substrings of the task name (`key in task_name`) by
# _resolve_scorer and _resolve_scorer_path, and entries are tried in insertion
# order, so a key that is contained in another task's name will also match it.
_TASK_TO_UTILS = {
    "olymp_math": "evaluate/custom_tasks/olymp_math/utils.py",
    "math500": "evaluate/custom_tasks/math500/utils.py",
    "aime24": "evaluate/custom_tasks/aime24/utils.py",
    "aime25": "evaluate/custom_tasks/aime25/utils.py",
    "beyondaime": "evaluate/custom_tasks/beyondaime/utils.py",
    "hmmt": "evaluate/custom_tasks/hmmt/utils.py",
    # Puzzle tasks (new unified structure)
    "bridges": "evaluate/custom_tasks/puzzle/bridges/utils.py",
    "galaxies": "evaluate/custom_tasks/puzzle/galaxies/utils.py",
    "loopy": "evaluate/custom_tasks/puzzle/loopy/utils.py",
    "pattern": "evaluate/custom_tasks/puzzle/pattern/utils.py",
    "undead": "evaluate/custom_tasks/puzzle/undead/utils.py",
    "omega": "evaluate/custom_tasks/omega/utils.py",
    "countdown": "evaluate/custom_tasks/puzzle/countdown/utils.py",
}

# Memo for _resolve_scorer: task name -> _score_single callable, or None when no
# scorer could be resolved (the failed lookup is cached too, so it is not retried).
_scorer_cache = {}
# Directory two levels above the directory containing this file; the relative
# paths in _TASK_TO_UTILS are joined onto it.
_project_root = str(Path(__file__).resolve().parent.parent.parent)

# ---------------------------------------------------------------------------
# Parallel scoring — worker process state
# ---------------------------------------------------------------------------
_worker_scorer = None   # set by _init_scorer_worker in each pool process


def _init_scorer_worker(utils_path: str):
    """Pool initializer that loads the task scorer into this worker process.

    Executes the module found at ``utils_path`` and stores its ``_score_single``
    attribute (None if the module has no such attribute) in the process-global
    ``_worker_scorer``. An empty ``utils_path`` is a no-op. Any exception raised
    while importing is reported on stderr and leaves ``_worker_scorer`` unchanged.
    """
    global _worker_scorer
    if utils_path:
        try:
            module_spec = importlib.util.spec_from_file_location("_task_scorer", utils_path)
            scorer_module = importlib.util.module_from_spec(module_spec)
            module_spec.loader.exec_module(scorer_module)
            _worker_scorer = getattr(scorer_module, "_score_single", None)
        except Exception as e:
            print(f"  Worker: could not import scorer from {utils_path}: {e}", file=sys.stderr)


def _score_doc_parallel(args):
    """Score every response of one document; executed inside a worker process.

    ``args`` is a ``(responses, doc, target)`` tuple. Returns one bool per
    response. The process-global ``_worker_scorer`` is used when it is set;
    otherwise each response is compared with ``target`` via ``is_correct``.
    """
    responses, doc, target = args
    if not _worker_scorer:
        # No task scorer was loaded in this process: fall back to the JSONL target field.
        return [is_correct(resp, target) for resp in responses]
    return [bool(_worker_scorer(resp, doc)) for resp in responses]


def _resolve_scorer(task_name: str):
    """Import and return ``_score_single`` from the task's utils module.

    The utils file is chosen exactly like ``_resolve_scorer_path`` chooses it:
    the first ``_TASK_TO_UTILS`` entry whose key occurs in ``task_name`` AND
    whose file exists under the project root. Keeping both lookups aligned
    ensures serial scoring and worker-based scoring use the same scorer.

    Args:
        task_name: Task name extracted from the samples filename.

    Returns:
        A callable ``(model_output: str, doc: dict) -> int``, or None when no
        utils file matches, the import fails, or the module defines no
        ``_score_single``. The outcome (including None) is cached per task name
        in ``_scorer_cache``; import problems are reported once on stderr.
    """
    if task_name in _scorer_cache:
        return _scorer_cache[task_name]

    scorer = None
    for prefix, utils_relpath in _TASK_TO_UTILS.items():
        if prefix not in task_name:
            continue
        utils_path = Path(_project_root) / utils_relpath
        if not utils_path.exists():
            # A later key may still match an existing file (same rule as
            # _resolve_scorer_path), so keep searching instead of giving up.
            continue
        try:
            spec = importlib.util.spec_from_file_location(
                f"task_utils_{prefix}", str(utils_path)
            )
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)
        except Exception as e:
            print(f"  Warning: could not import scorer for {task_name}: {e}", file=sys.stderr)
        else:
            scorer = getattr(mod, "_score_single", None)
            if scorer is None:
                print(f"  Warning: {utils_relpath} has no _score_single function", file=sys.stderr)
        # Only the first existing utils file is a candidate: worker processes
        # load exactly that file and never try another one.
        break

    _scorer_cache[task_name] = scorer
    return scorer


def _resolve_scorer_path(task_name: str) -> str:
    """Locate the utils.py file that belongs to ``task_name``.

    Walks ``_TASK_TO_UTILS`` in order and returns, as an absolute path string,
    the first entry whose key occurs in ``task_name`` and whose file exists
    under the project root. Returns "" when nothing qualifies. A path (rather
    than the scorer function) is returned so it can be handed to worker
    processes, since paths are picklable.
    """
    root = Path(_project_root)
    candidates = (
        root / relpath
        for key, relpath in _TASK_TO_UTILS.items()
        if key in task_name
    )
    return next((str(path) for path in candidates if path.exists()), "")


def _extract_task_name(samples_file: str) -> str:
    """Derive the task name from a samples filename.

    Example: ``samples_aime24_r1_avg8_2026-01-02T...jsonl`` -> ``aime24_r1_avg8``.

    The name is whatever sits between the leading ``samples_`` and the first
    ``_YYYY-MM-DD`` date stamp. When the filename carries no date stamp, only
    the leading ``samples_`` prefix is removed; occurrences of ``samples_``
    inside the task name are kept.
    """
    fname = Path(samples_file).stem  # e.g., samples_aime24_r1_avg8_2026-...
    m = re.match(r"samples_(.+?)_\d{4}-\d{2}-\d{2}", fname)
    if m:
        return m.group(1)
    # Fallback: strip only the prefix. str.replace would also delete any later
    # "samples_" inside the task name.
    return fname.removeprefix("samples_")


# ---------------------------------------------------------------------------
# Default answer extraction (fallback when no task scorer available)
# ---------------------------------------------------------------------------

def extract_boxed_answer(text: str):
    """Return the stripped content of the last ``\\boxed{...}`` in ``text``.

    Nested braces inside the box are supported. Returns None when ``text`` is
    empty, contains no ``\\boxed``, or has no ``{`` after the last ``\\boxed``.
    Returns "" when the opening brace is never balanced by a closing one.
    """
    if not text:
        return None
    marker = text.rfind(r"\boxed")
    if marker == -1:
        return None
    open_pos = text.find("{", marker)
    if open_pos == -1:
        return None
    depth = 0
    for pos in range(open_pos, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace found: everything in between is the answer.
                return text[open_pos + 1:pos].strip()
    # Unbalanced braces: nothing usable inside the box.
    return ""


def extract_last_number(text: str):
    """Fallback extractor: return the last standalone run of digits in ``text``.

    The digits are returned as a string; None is returned when ``text``
    contains no digit run delimited by word boundaries.
    """
    digit_runs = re.findall(r"\b\d+\b", text)
    if not digit_runs:
        return None
    return digit_runs[-1]


def normalize(answer: str) -> str:
    """Canonicalize an answer for comparison.

    The value is converted with ``str()``, all whitespace is removed and the
    result is lower-cased. None maps to the empty string.
    """
    if answer is not None:
        without_spaces = re.sub(r"\s+", "", str(answer))
        return without_spaces.lower()
    return ""


def get_prediction(response: str) -> str:
    """Return the normalized answer predicted by ``response``.

    The boxed answer is preferred; when it is missing or empty, the last
    number in the response is used instead. The chosen value is passed
    through ``normalize``.
    """
    boxed = extract_boxed_answer(response)
    raw_answer = boxed if boxed else extract_last_number(response)
    return normalize(raw_answer)


def is_correct(response: str, target: str) -> bool:
    """String-matching fallback scorer: does ``response`` answer ``target``?

    Args:
        response: Raw model output.
        target: Reference answer. Normally a string, but any value is accepted
            (e.g. an integer answer); it is converted with ``str()`` before
            extraction, and None is treated as an empty target.

    Returns:
        True when the normalized prediction is non-empty and equals the
        normalized target, where the target is reduced to its boxed content,
        else its last number, else the full target text.
    """
    pred = get_prediction(response)
    if pred == "":
        # An empty prediction never counts as correct, whatever the target is.
        return False
    # The extraction helpers only work on strings; coerce so that non-string
    # targets do not raise.
    target_text = "" if target is None else str(target)
    gt = normalize(
        extract_boxed_answer(target_text) or extract_last_number(target_text) or target_text
    )
    return pred == gt


# ---------------------------------------------------------------------------
# pass@k computation
# ---------------------------------------------------------------------------

def _pass_at_k_combinatorial(n: int, c: int, k: int) -> float:
    """Per-problem pass@k estimate from Chen et al. (Codex paper).

    Computes ``1 - C(n - c, k) / C(n, k)`` from all ``n`` samples, of which
    ``c`` are correct. When fewer than ``k`` samples are incorrect the result
    is 1.0.
    """
    n_incorrect = n - c
    if n_incorrect >= k:
        all_wrong_draws = math.comb(n_incorrect, k)
        return 1.0 - all_wrong_draws / math.comb(n, k)
    return 1.0


def pass_at_k(correct_lists: list[list[bool]], k: int) -> float:
    """Mean pass@k over documents.

    ``correct_lists`` holds one list of per-response correctness flags per
    document. A document with more than ``k`` responses is scored with the
    combinatorial estimator (Chen et al. 2021) over all of its responses; a
    document with ``k`` or fewer responses scores 1.0 if any of them is correct
    and 0.0 otherwise. Returns 0.0 for an empty input.
    """
    if not correct_lists:
        return 0.0

    def _doc_score(flags):
        n_samples = len(flags)
        if n_samples > k:
            return _pass_at_k_combinatorial(n_samples, sum(flags), k)
        # Not enough samples for the estimator: naive "any correct" check.
        return 1.0 if any(flags[:k]) else 0.0

    per_doc = [_doc_score(flags) for flags in correct_lists]
    return sum(per_doc) / len(per_doc)


# ---------------------------------------------------------------------------
# cons@k (majority voting) computation
# ---------------------------------------------------------------------------

def cons_at_k_single(responses: list[str], target: str, k: int, scorer=None, doc=None) -> int:
    """Majority vote over the first ``k`` responses of a single document.

    Args:
        responses: Raw model outputs for the document.
        target: Reference answer used by the string-matching fallback. Any
            value is accepted; it is converted with ``str()`` (None -> "").
        k: Number of leading responses that take part in the vote.
        scorer: Optional task scorer ``(model_output, doc) -> score``.
        doc: Document passed to ``scorer``; the scorer is only used when both
            ``scorer`` and ``doc`` are truthy.

    Returns:
        1 if the most frequent non-empty prediction is correct, else 0. Always
        a plain int, even when the scorer returns a bool or float. Ties are
        resolved in favour of the prediction that appeared first.
    """
    preds = [p for p in (get_prediction(r) for r in responses[:k]) if p]  # drop empty predictions
    if not preds:
        return 0
    most_common_pred = Counter(preds).most_common(1)[0][0]
    if scorer and doc:
        # Reconstruct a response with just the majority answer in \boxed{}
        # and score it with the task scorer.
        synthetic_response = f"\\boxed{{{most_common_pred}}}"
        # Collapse the scorer's result to 0/1 so the declared int contract
        # holds and sums over documents stay integer counts.
        return 1 if scorer(synthetic_response, doc) else 0
    # The extraction helpers only work on strings; coerce non-string targets.
    target_text = "" if target is None else str(target)
    gt = normalize(
        extract_boxed_answer(target_text) or extract_last_number(target_text) or target_text
    )
    return 1 if most_common_pred == gt else 0


def cons_at_k(responses_per_doc: list[tuple[list[str], str]], k: int) -> float:
    """Share of documents whose majority vote over the first ``k`` responses is correct.

    ``responses_per_doc`` holds one ``(responses, target)`` pair per document.
    Each pair is judged by ``cons_at_k_single`` without a task scorer. Returns
    0.0 for an empty input.
    """
    if not responses_per_doc:
        return 0.0
    votes = [cons_at_k_single(doc_resps, doc_target, k) for doc_resps, doc_target in responses_per_doc]
    return sum(votes) / len(responses_per_doc)


# ---------------------------------------------------------------------------
# Loading and scoring
# ---------------------------------------------------------------------------

def _read_samples_jsonl(samples_file: str) -> dict:
    """Parse a samples JSONL file into ``{doc_id: (responses, target, doc)}``.

    The file is read as UTF-8 and blank lines are skipped. When a doc id
    occurs on several lines, the last line wins.
    """
    doc_resps = {}
    with open(samples_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            s = json.loads(line)
            doc_id = s.get("doc_id", s.get("idx"))
            resps = s.get("resps", [[]])
            # resps is [[resp1, resp2, ...]] — flatten
            responses = resps[0] if resps and isinstance(resps[0], list) else resps
            doc_resps[doc_id] = (responses, s.get("target", ""), s.get("doc", {}))
    return doc_resps


def load_and_score(samples_file: str, num_workers: int = 1) -> tuple[list[list[bool]], list[tuple[list[str], str]], int, list]:
    """
    Load samples jsonl. For each doc, score all responses in 'resps'.
    Uses the task's custom scorer (via _score_single) when available,
    falls back to simple string matching otherwise.

    Args:
        samples_file: Path to a samples_*.jsonl file; the task name (and with
                      it the scorer) is derived from the filename.
        num_workers: Number of parallel processes for scoring. >1 uses ProcessPoolExecutor
                     with per-process scorer imports (avoids GIL for CPU-bound math_verify).
                     The pool is only used when a task scorer file was found.

    Returns (correct_per_doc, responses_per_doc, n_repeats, doc_ids), all
    ordered by doc id. n_repeats is the largest number of responses any single
    doc has (0 when the file holds no docs).
    """
    task_name = _extract_task_name(samples_file)
    scorer_path = _resolve_scorer_path(task_name)

    if scorer_path:
        matched = next((p for p in _TASK_TO_UTILS if p in task_name), "?")
        print(f"  (using task scorer from {_TASK_TO_UTILS.get(matched, '?')})", file=sys.stderr)

    # Sort on the doc id only; ids are unique dict keys, so the payload tuples
    # never need to be compared.
    sorted_items = sorted(_read_samples_jsonl(samples_file).items(), key=lambda item: item[0])

    # Score — parallel or serial
    if num_workers > 1 and scorer_path:
        n_docs = len(sorted_items)
        total_resps = sum(len(responses) for _, (responses, _, _) in sorted_items)
        print(f"  Scoring {n_docs} docs × {total_resps // max(n_docs, 1)} resps with {num_workers} workers ...", file=sys.stderr)

        work_items = [(responses, doc, target) for _, (responses, target, doc) in sorted_items]
        with ProcessPoolExecutor(
            max_workers=num_workers,
            initializer=_init_scorer_worker,
            initargs=(scorer_path,),
        ) as pool:
            # pool.map preserves input order, so results line up with sorted_items.
            correct_per_doc = list(pool.map(_score_doc_parallel, work_items, chunksize=4))
    else:
        # Serial path (original behavior, or no task scorer)
        scorer = _resolve_scorer(task_name) if scorer_path else None
        correct_per_doc = [
            [bool(scorer(r, doc)) for r in responses] if scorer
            else [is_correct(r, target) for r in responses]
            for _, (responses, target, doc) in sorted_items
        ]

    # Result assembly is identical for both scoring paths.
    responses_per_doc = [(responses, target) for _, (responses, target, _) in sorted_items]
    doc_ids = [doc_id for doc_id, _ in sorted_items]
    n_repeats = max((len(responses) for responses, _ in responses_per_doc), default=0)

    return correct_per_doc, responses_per_doc, n_repeats, doc_ids


def summarize(output_dir: str, k_values: list[int], cons_k_values: list[int] = None,
              per_problem: bool = False, json_output: str = None, num_workers: int = 1):
    """Score every samples_*.jsonl under ``output_dir`` and print a pass@k / cons@k table.

    Args:
        output_dir: Directory searched recursively for ``samples_*.jsonl`` files.
        k_values: k values for the pass@k columns.
        cons_k_values: k values for the cons@k columns; None or empty disables them.
        per_problem: Also print a per-document correct/total table for each file.
        json_output: If set, path of a JSON file that receives the aggregate and
            per-problem results of all files (parent directories are created).
        num_workers: Forwarded to ``load_and_score``.

    Notes:
        * Each requested k is clamped to the largest number of responses any
          document in the file has before the aggregate metric is computed.
        * cons@k is computed without passing a task scorer to the voting helpers.
        * Per-problem ``pass_at_k`` entries are booleans meaning "any of the
          first k responses is correct".
    """
    output_dir = Path(output_dir)
    cons_k_values = cons_k_values or []
    samples_files = sorted(glob.glob(str(output_dir / "**" / "samples_*.jsonl"), recursive=True))

    if not samples_files:
        print(f"No samples_*.jsonl found under {output_dir}")
        return

    # Build header
    k_headers = [f"pass@{k}" for k in k_values]
    cons_headers = [f"cons@{k}" for k in cons_k_values]
    all_headers = k_headers + cons_headers
    header = f"{'Label / Task':<55} " + " ".join(f"{h:>8}" for h in all_headers) + f" {'n_docs':>7} {'n_reps':>7}"
    print(f"\n{header}")
    print("-" * len(header))

    all_results = {}

    for sf in samples_files:
        parts = Path(sf).parts
        # Extract label/task from path: the two components following output_dir
        try:
            idx = parts.index(output_dir.name)
            label = parts[idx + 1]
            task = parts[idx + 2]
        except (ValueError, IndexError):
            label = Path(sf).parent.parent.name
            task = Path(sf).parent.name
        tag = f"{label}/{task}"

        # Extract samples task name from filename for unique keying
        # e.g. "samples_olymp_math_easy_pass32_2026-..." -> "olymp_math_easy_pass32"
        samples_task = _extract_task_name(sf)

        correct_per_doc, responses_per_doc, n_reps, doc_ids = load_and_score(sf, num_workers=num_workers)
        if not correct_per_doc:
            # Pad the whole tag (not just the task part) so the row lines up
            # with the first column of the table.
            print(f"{tag:<55} (no data)")
            continue

        # Compute pass@k for each k
        pk_values = {}
        pk_strs = []
        for k in k_values:
            effective_k = min(k, n_reps)
            pk = pass_at_k(correct_per_doc, effective_k)
            pk_values[k] = pk
            pk_strs.append(f"{pk:>8.1%}")

        # Compute cons@k for each k
        ck_values = {}
        ck_strs = []
        for k in cons_k_values:
            effective_k = min(k, n_reps)
            ck = cons_at_k(responses_per_doc, effective_k)
            ck_values[k] = ck
            ck_strs.append(f"{ck:>8.1%}")

        print(f"{tag:<55} " + " ".join(pk_strs + ck_strs) + f" {len(correct_per_doc):>7} {n_reps:>7}")

        # Use samples_task in key to avoid collisions when multiple samples files
        # share the same parent directory (e.g. easy + hard in one checkpoint dir)
        result_key = f"{label}/{samples_task}"
        per_problem_data = []
        for doc_id, correctness, (resps, target) in zip(doc_ids, correct_per_doc, responses_per_doc):
            n_total = len(correctness)
            per_problem_data.append({
                "doc_id": doc_id,
                "n_correct": sum(correctness),
                "n_total": n_total,
                "pass_at_k": {str(k): any(correctness[:min(k, n_total)]) for k in k_values},
                "cons_at_k": {
                    str(k): cons_at_k_single(resps, target, min(k, n_total))
                    for k in cons_k_values
                },
            })

        if per_problem:
            print(f"\n  Per-problem breakdown for {tag}:")
            print(f"  {'doc_id':>8}  {'correct':>8}  {'total':>6}")
            print(f"  {'-'*26}")
            for entry in per_problem_data:
                print(f"  {entry['doc_id']:>8}  {entry['n_correct']:>8}  {entry['n_total']:>6}")
            print()

        all_results[result_key] = {
            "pass_at_k": {str(k): round(v, 4) for k, v in pk_values.items()},
            "cons_at_k": {str(k): round(v, 4) for k, v in ck_values.items()},
            "n_docs": len(correct_per_doc),
            "n_reps": n_reps,
            "per_problem": per_problem_data,
        }

    if json_output:
        json_path = Path(json_output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding so the summary does not depend on the platform default.
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)
        print(f"\nJSON summary saved: {json_path}")


def _parse_k_list(text: str) -> list[int]:
    """Parse a comma-separated list of positive integers such as ``"1,4,8"``.

    Whitespace around items and empty items (e.g. from a trailing comma) are
    ignored.

    Raises:
        ValueError: If an item is not an integer, is smaller than 1, or the
            list contains no values at all.
    """
    values = []
    for token in text.split(","):
        token = token.strip()
        if not token:
            continue
        value = int(token)  # ValueError for non-integer items
        if value < 1:
            raise ValueError(f"k must be >= 1, got {value}")
        values.append(value)
    if not values:
        raise ValueError(f"expected at least one integer in {text!r}")
    return values


def main(argv=None):
    """Command-line entry point: parse the arguments and print the summary.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour when None is
    passed). Invalid k lists are reported through the parser's usual error
    exit instead of a traceback.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir", help="Results directory to scan for samples_*.jsonl")
    parser.add_argument("--k_values", type=str, default="1,4,8",
                        help="Comma-separated k values for pass@k (default: 1,4,8)")
    parser.add_argument("--cons_k", type=str, default=None,
                        help="Comma-separated k values for cons@k / majority voting (e.g. 8,64)")
    parser.add_argument("--per_problem", action="store_true",
                        help="Print per-problem correct counts")
    parser.add_argument("--json_output", type=str, default=None,
                        help="Path to write JSON summary with pass@k, cons@k, and per-problem data")
    parser.add_argument("--workers", type=int, default=1,
                        help="Number of parallel worker processes for scoring (default: 1 = serial). "
                             "Use 4-8 for math_verify-based tasks like OlymMATH.")
    args = parser.parse_args(argv)
    try:
        k_list = _parse_k_list(args.k_values)
        # An omitted or empty --cons_k disables cons@k, as before.
        cons_list = _parse_k_list(args.cons_k) if args.cons_k else None
    except ValueError as e:
        parser.error(f"invalid k list: {e}")
    summarize(args.output_dir, k_list, cons_list, args.per_problem, args.json_output, args.workers)


if __name__ == "__main__":
    main()
