#!/bin/bash
# AIME pass@32 evaluation — baseline vs multi-puzzle epoch 7
# 2 models × 2 tasks = 4 evals, AIME25 first then AIME24
#
# Usage:
#   ./scripts/evals/eval_aime_pass32.sh
#   nohup ./scripts/evals/eval_aime_pass32.sh > logs/aime_pass32.log 2>&1 &

# Strict mode: stop on command failure, unset variables, or a failed pipeline stage.
set -euo pipefail

# Work from the project root: PROJECT_DIR if the caller set it, otherwise the
# directory two levels above this script.
PROJECT_DIR="${PROJECT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)}"
cd "$PROJECT_DIR"

# Activate the verl-latest virtualenv for everything that follows.
. ~/src/kernels/verl-latest/.venv/bin/activate
echo "Activated verl-latest venv"

# Environment switches exported to the child processes started below.
export \
    VLLM_USE_TRTLLM_ATTENTION=0 \
    VLLM_ATTENTION_BACKEND=FLASH_ATTN \
    PYTHONUNBUFFERED=1 \
    WANDB_CONSOLE=off \
    VLLM_USE_V1=1

# --- Evaluation settings ------------------------------------------------------
# Models under comparison: the base checkpoint and a LoRA adapter applied to it.
BASE_MODEL="Qwen/Qwen2.5-7B-Instruct"
MULTI_PUZZLE_LORA_EPOCH7="checkpoints/qwen25_7b_multi_puzzle_dsr/lora_epoch_7"
# Task definitions and result location (relative to PROJECT_DIR).
CUSTOM_TASKS_PATH="evaluate/custom_tasks"
OUTPUT_DIR="results/aime_pass32"
# Values forwarded to lm_eval / vLLM by run_eval.
SEED=42
GPU_MEMORY_UTILIZATION=0.95
MAX_MODEL_LEN=28000
MAX_LORA_RANK=64

# Per-eval logs are written directly under OUTPUT_DIR, so it must exist up front.
mkdir -p "${OUTPUT_DIR}"

cat <<'BANNER'
==========================================
AIME pass@32 — Baseline vs Epoch 7
==========================================
BANNER

#######################################
# Run a single lm_eval job on the vLLM backend for one model/task pair, unless
# a results file for that pair already exists.
# Globals (read):
#   OUTPUT_DIR, GPU_MEMORY_UTILIZATION, MAX_MODEL_LEN, MAX_LORA_RANK,
#   CUSTOM_TASKS_PATH, SEED
# Arguments:
#   $1 - model label; used for the result directory and log file names
#   $2 - lm_eval task name
#   $3 - value for the `pretrained=` model argument
#   $4 - optional LoRA adapter path; empty or omitted means no adapter
# Outputs:
#   [SKIP]/[START]/[DONE]/[FAIL] status lines on stdout; all lm_eval output
#   goes to ${OUTPUT_DIR}/<label>_<task>.log
# Returns:
#   0 when skipped or on success, otherwise lm_eval's exit status.
#######################################
run_eval() {
    local model_label=$1
    local task=$2
    local pretrained=$3
    local lora_path="${4:-}"

    local result_dir="${OUTPUT_DIR}/${model_label}/${task}"

    # Resume support: skip when any <result_dir>/*/results*.json already exists.
    # A glob into an array avoids parsing `ls` and the SIGPIPE/pipefail hazard of
    # `ls | head | grep`. An unmatched glob stays literal (or vanishes under
    # nullglob), so the -e test on the first element is false in both cases.
    local existing=("${result_dir}"/*/results*.json)
    if [[ -e "${existing[0]:-}" ]]; then
        echo "[SKIP] ${model_label} / ${task}"
        return 0
    fi

    local model_args="pretrained=${pretrained},tensor_parallel_size=1,gpu_memory_utilization=${GPU_MEMORY_UTILIZATION},max_model_len=${MAX_MODEL_LEN},trust_remote_code=True"
    if [[ -n "${lora_path}" ]]; then
        model_args="${model_args},lora_local_path=${lora_path},max_lora_rank=${MAX_LORA_RANK}"
    fi

    local log_file="${OUTPUT_DIR}/${model_label}_${task}.log"

    echo "[START] ${model_label} / ${task}  ($(date '+%H:%M:%S'))"

    # Capture the status with `|| rc=$?`: under `set -e` a bare failing command
    # would terminate the script before the [FAIL] line below could be printed.
    local rc=0
    CUDA_VISIBLE_DEVICES=0 lm_eval run \
        --model vllm \
        --model_args "${model_args}" \
        --include_path "${CUSTOM_TASKS_PATH}" \
        --tasks "${task}" \
        --batch_size auto \
        --apply_chat_template \
        --seed "${SEED}" \
        --output_path "${result_dir}" \
        --log_samples \
        > "${log_file}" 2>&1 || rc=$?

    if (( rc == 0 )); then
        echo "[DONE]  ${model_label} / ${task}  ($(date '+%H:%M:%S'))"
    else
        echo "[FAIL]  ${model_label} / ${task} (exit ${rc}) — see ${log_file}"
    fi
    # A non-zero return still aborts an errexit caller that does not guard the
    # call, but only after the failure has been reported.
    return "${rc}"
}

# Model table, one entry per model: "label|pretrained|lora_path".
# An empty third field means the model is evaluated without a LoRA adapter.
MODELS=(
    "baseline|${BASE_MODEL}|"
    "multi_puzzle_epoch7|${BASE_MODEL}|${MULTI_PUZZLE_LORA_EPOCH7}"
)

# AIME25 first (where we saw the signal), then AIME24
for task in aime25_r1_pass32 aime24_r1_pass32; do
    printf '\n=== Task: %s ===\n' "${task}"
    for entry in "${MODELS[@]}"; do
        # Split the entry on '|'. A trailing empty field produces no array
        # element, hence the ':-' default on the LoRA slot.
        IFS='|' read -r -a fields <<< "${entry}"
        run_eval "${fields[0]}" "${task}" "${fields[1]}" "${fields[2]:-}"
    done
done

printf '\n%s\n%s\n%s\n' \
    "==========================================" \
    "ALL EVALUATIONS COMPLETE" \
    "=========================================="

# =============================================================================
# Rescore with pass@k and Wilson CIs
# =============================================================================
python3 << 'PYEOF'
import json, glob, re, math, os
from pathlib import Path

def extract_boxed_answer(text):
    """Return the contents of the last ``\\boxed{...}`` group in ``text``.

    The search starts at the last occurrence of ``\\boxed`` and takes the first
    ``{`` at or after it, then walks forward matching nested braces.

    Returns None when ``text`` is empty/None, has no ``\\boxed``, or has no
    ``{`` after it. Returns an empty string when the braces never balance.
    Otherwise returns the text between the braces, stripped of whitespace.
    """
    if not text:
        return None

    marker_pos = text.rfind(r"\boxed")
    if marker_pos < 0:
        return None

    open_pos = text.find("{", marker_pos)
    if open_pos < 0:
        return None

    # If no balancing brace is found, close_pos stays at open_pos and the slice
    # below is empty.
    close_pos = open_pos
    nesting = 0
    for pos, ch in enumerate(text[open_pos:], start=open_pos):
        if ch == "{":
            nesting += 1
        elif ch == "}":
            nesting -= 1
            if nesting == 0:
                close_pos = pos
                break

    return text[open_pos + 1:close_pos].strip()

def extract_last_number(text):
    """Return the last standalone run of digits in ``text`` as a string.

    "Standalone" means delimited by word boundaries, so digits glued to letters
    or underscores (e.g. ``x12``) are not matched.

    ``text`` may be None (returns None) or a non-string such as an int, which
    is converted with ``str()`` first. Previously both cases raised TypeError
    inside ``re.findall``.

    Returns None when no number is found.
    """
    if text is None:
        return None
    numbers = re.findall(r"\b\d+\b", str(text))
    return numbers[-1] if numbers else None

def normalize(answer):
    """Canonicalize an answer string for equality comparison.

    None becomes the empty string. Otherwise the value is converted with
    ``str()``, all whitespace is removed and the result is lower-cased.

    A result consisting only of ASCII digits additionally has its leading zeros
    removed, so that numerically equal integer answers such as "070" and "70"
    compare equal. Anything else (signs, decimals, LaTeX) is left untouched.
    """
    if answer is None:
        return ""
    compact = re.sub(r"\s+", "", str(answer)).lower()
    if re.fullmatch(r"[0-9]+", compact):
        # Keep a single "0" when the value is all zeros.
        compact = compact.lstrip("0") or "0"
    return compact

def is_correct(response, target):
    """Return True when the answer extracted from ``response`` matches ``target``.

    The prediction is the last boxed answer in the response, falling back to
    the last standalone number. The ground truth is extracted from ``target``
    the same way, falling back to the raw target text. Both sides are compared
    after ``normalize``.

    None inputs are treated as empty text and other non-strings are converted
    with ``str()``, so a JSON null or numeric field cannot raise here.
    """
    response = "" if response is None else str(response)
    target = "" if target is None else str(target)

    pred = extract_boxed_answer(response) or extract_last_number(response)
    if pred is None:
        # Nothing could be extracted. Without this guard normalize(None) == ""
        # would be scored as correct against an empty target.
        return False

    gt = extract_boxed_answer(target) or extract_last_number(target) or target
    return normalize(pred) == normalize(gt)

def wilson_ci(s, n, z=1.96):
    """Wilson score interval for ``s`` successes out of ``n`` trials.

    ``z`` is the normal quantile for the desired confidence level; the default
    1.96 corresponds to 95%.

    Returns ``(p, lower, upper)`` where ``p = s / n`` and the bounds are
    clamped to [0, 1]. Returns ``(0, 0, 0)`` when ``n`` is 0.
    """
    if n == 0:
        return 0, 0, 0

    p_hat = s / n
    denom = 1 + z**2 / n
    centre = (p_hat + z**2 / (2*n)) / denom
    half_width = z * math.sqrt(p_hat*(1-p_hat)/n + z**2/(4*n**2)) / denom

    lower = max(0, centre - half_width)
    upper = min(1, centre + half_width)
    return p_hat, lower, upper

# Location and names of the evaluation outputs. The path must match OUTPUT_DIR
# in the shell part of this script; the heredoc is quoted, so shell variables
# are not expanded here.
output_dir = Path("results/aime_pass32")
models = ["baseline", "multi_puzzle_epoch7"]
labels = {"baseline": "Qwen2.5-7B Base", "multi_puzzle_epoch7": "Multi-Puzzle Ep7"}
tasks = ["aime25_r1_pass32", "aime24_r1_pass32"]
task_short = {"aime25_r1_pass32": "AIME25", "aime24_r1_pass32": "AIME24"}


def _latest_samples_file(model, task):
    """Return the newest samples_*.jsonl under output_dir/model/task, or None.

    glob order is arbitrary, so when several sample files exist (e.g. after a
    re-run) the most recently modified one is chosen explicitly.
    """
    pattern = str(output_dir / model / task / "**" / "samples_*.jsonl")
    candidates = glob.glob(pattern, recursive=True)
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def _score_samples(path):
    """Score one samples file; return {doc_id: [bool per response]}.

    If a doc_id occurs on more than one line, the last line wins.
    """
    per_doc = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            s = json.loads(line)
            resps = s.get("resps") or []
            # "resps" is either a list of response lists (take the first) or a
            # flat list of responses.
            responses = resps[0] if resps and isinstance(resps[0], list) else resps
            target = s.get("target", "")
            per_doc[s.get("doc_id")] = [is_correct(r, target) for r in responses]
    return per_doc


# summary[model][task] is filled while the report is printed, so each samples
# file is parsed and scored once and the JSON agrees with the printed table.
summary = {model: {} for model in models}

print(f"\n{'='*100}")
print(f"  AIME pass@32 — avg pass@1 with 95% Wilson CIs (30 problems × 32 samples = 960 trials)")
print(f"{'='*100}")

for task in tasks:
    short = task_short[task]
    print(f"\n  {short}:")
    results = {}
    for model in models:
        samples_path = _latest_samples_file(model, task)
        if samples_path is None:
            print(f"    {labels[model]:<22}     — (no results)")
            continue

        per_doc = _score_samples(samples_path)
        n_docs = len(per_doc)
        total_samples = sum(len(flags) for flags in per_doc.values())
        total_correct = sum(sum(flags) for flags in per_doc.values())
        if total_samples == 0:
            # An empty or truncated samples file would otherwise divide by zero.
            print(f"    {labels[model]:<22}     — (no samples in {samples_path})")
            continue

        # Average pass@1: every sampled response is pooled as one trial.
        p, lo, hi = wilson_ci(total_correct, total_samples)
        results[model] = (p, lo, hi, total_correct, total_samples)
        summary[model][task] = {
            "correct": total_correct,
            "total": total_samples,
            "avg_pass1": round(total_correct / total_samples, 4),
        }
        print(f"    {labels[model]:<22} {p*100:>6.1f}%  [{lo*100:.1f}%-{hi*100:.1f}%]  ({total_correct}/{total_samples})")

        # Also compute pass@8 and pass@32. pass@8 looks only at the first 8
        # responses of each problem; pass@32 at all of them.
        p8_pass = sum(1 for flags in per_doc.values() if any(flags[:8]))
        p32_pass = sum(1 for flags in per_doc.values() if any(flags))
        _, p8_lo, p8_hi = wilson_ci(p8_pass, n_docs)
        _, p32_lo, p32_hi = wilson_ci(p32_pass, n_docs)
        print(f"    {'':22} pass@8:  {p8_pass/n_docs*100:>5.1f}% [{p8_lo*100:.1f}%-{p8_hi*100:.1f}%]  pass@32: {p32_pass/n_docs*100:>5.1f}% [{p32_lo*100:.1f}%-{p32_hi*100:.1f}%]")

    # Overlap check
    # NOTE(review): the pass@1 intervals pool all samples as if independent,
    # although samples of the same problem are likely correlated, so the
    # intervals are probably too narrow. Treat the verdict as a rough
    # indicator, not a significance test.
    if len(results) == 2:
        b_p, b_lo, b_hi, _, _ = results["baseline"]
        e_p, e_lo, e_hi, _, _ = results["multi_puzzle_epoch7"]
        overlaps = e_lo < b_hi and b_lo < e_hi
        delta = (e_p - b_p) * 100
        verdict = "OVERLAPPING (not significant)" if overlaps else "NON-OVERLAPPING (significant!)"
        print(f"\n    Delta: {delta:+.1f}pp — CIs {verdict}")

print(f"\n{'='*100}")

# Save summary. Models/tasks without usable samples have no entry.
with open(str(output_dir / "summary.json"), "w") as f:
    json.dump(summary, f, indent=2)
print(f"\nSummary saved: {output_dir}/summary.json")
PYEOF
