#!/usr/bin/env python3
"""Compute SFT perplexity on correct responses from lm_eval JSONL files.

Uses lm_eval's scoring functions for correctness, vLLM for prompt_logprobs,
and multiprocessing across GPUs for speed.

Usage (--labels is required: one label per JSONL file after shell glob
expansion, in the same order as --jsonl_files):
    python scripts/analysis/compute_sft_ppl.py \
        --sft_model checkpoints/olmo3_7b_multi_puzzle_dsr_v2/merged_ep5_fp32 \
        --jsonl_files results/foo/samples_*.jsonl results/bar/samples_*.jsonl \
        --labels foo bar \
        --task_type aime \
        --output results/novelty_signal/per_response_ppl.jsonl \
        --num_gpus 4
"""

import argparse
import json
import os
import re
import sys
from multiprocessing import Process, Queue
from pathlib import Path

# Project root = three `.parent` hops up from this file; it is pushed to the
# front of sys.path so project-local modules win over installed ones.
# NOTE(review): assumes this script lives two directories below the repo root
# (e.g. scripts/analysis/, as in the usage example) — confirm if it is moved.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# AIME scoring: prepend a task directory inside a hard-coded virtualenv path
# (verl-vllm012, python3.12) and import the top-level module named `utils`
# from it. The insert must stay before the import; statement order matters.
# NOTE(review): `utils` is a very generic module name — any other `utils` on
# sys.path ahead of this entry would shadow it; presumably this resolves to
# lm_eval's aime task utils — verify.
sys.path.insert(0, str(PROJECT_ROOT / "verl-vllm012" / "lib" / "python3.12" / "site-packages" / "lm_eval" / "tasks" / "aime"))
import utils as aime_utils

# Puzzle scoring: same trick for the project's custom puzzle tasks; `scoring`
# is resolved from evaluate/custom_tasks/puzzle/_common under PROJECT_ROOT.
sys.path.insert(0, str(PROJECT_ROOT / "evaluate" / "custom_tasks" / "puzzle" / "_common"))
from scoring import make_scorer


def score_response(response: str, doc: dict, task_type: str) -> int:
    """Return the correctness score of one response for the given task.

    Math-style tasks ("aime", "aime24", "aime25", "math") are delegated to
    ``aime_utils.process_results`` and the ``"exact_match"`` entry of its
    result is returned. Every other ``task_type`` is treated as a puzzle:
    the text before the first underscore selects the scorer ("bridges" has
    a dedicated one, anything else uses "generic").
    """
    math_tasks = ("aime", "aime24", "aime25", "math")
    if task_type not in math_tasks:
        # Puzzle path, e.g. "bridges_5x5dm" -> family "bridges".
        family = task_type.partition("_")[0]
        scorer_name = "bridges" if family == "bridges" else "generic"
        _, puzzle_score = make_scorer(scorer_name)
        return puzzle_score(response, doc)

    return aime_utils.process_results(doc, [response])["exact_match"]


def load_and_filter_correct(jsonl_paths: list[str], task_type: str, label: str) -> list[dict]:
    """Load sample JSONLs, deduplicate, score, and return correct-only records.

    Each non-blank line must be a JSON object with ``"doc"``, ``"doc_id"`` and
    ``"resps"`` keys; only ``resps[0]`` (a list of response strings) is read.
    Duplicate responses are dropped *within* a single line (i.e. per
    document), comparing the response text itself so that two different
    responses can never be merged by a hash collision.

    Args:
        jsonl_paths: Paths of the JSONL files to read (as UTF-8).
        task_type: Forwarded to ``score_response`` to pick the scorer.
        label: Stored as ``"source"`` on every returned record.

    Returns:
        One dict per unique response that ``score_response`` rates truthy,
        with keys ``source``, ``file``, ``doc_id``, ``resp_idx`` (position in
        ``resps[0]``), ``resp_len`` (characters) and ``response``.
    """
    records = []
    for path in jsonl_paths:
        file_name = Path(path).name
        with open(path, encoding="utf-8") as f:
            for line in f:
                # Tolerate blank / trailing empty lines instead of crashing
                # in json.loads.
                if not line.strip():
                    continue
                d = json.loads(line)
                doc = d["doc"]
                seen = set()  # reset per document: dedup is per-prompt only
                for j, r in enumerate(d["resps"][0]):
                    if r in seen:
                        continue
                    seen.add(r)
                    if not score_response(r, doc, task_type):
                        continue
                    records.append({
                        "source": label,
                        "file": file_name,
                        "doc_id": d["doc_id"],
                        "resp_idx": j,
                        "resp_len": len(r),
                        "response": r,
                    })
    print(f"  {label}: {len(records)} correct unique responses")
    return records


def _prompt_token_nlps(prompt_token_ids, prompt_logprobs) -> list[float]:
    """Return the negative logprob of each actual prompt token.

    Position 0 is always skipped (there is no logprob for the first token),
    as are positions whose entry is ``None``. The entry for the real prompt
    token is looked up by token id; if ids are unavailable or the id is not
    in the dict, the first dict entry is used as a fallback.
    """
    nlps = []
    for pos, lp in enumerate(prompt_logprobs):
        if pos == 0 or lp is None:
            continue
        entry = None
        if prompt_token_ids is not None and pos < len(prompt_token_ids):
            entry = lp.get(prompt_token_ids[pos])
        if entry is None:
            entry = next(iter(lp.values()), None)
        if entry is not None:
            nlps.append(-entry.logprob)
    return nlps


def gpu_worker(gpu_id: int, model_path: str, texts: list[str], indices: list[int],
               max_model_len: int, result_queue: Queue):
    """Worker process: compute per-token prompt logprobs on one GPU.

    Args:
        gpu_id: Ordinal of the GPU this worker should use. If the parent
            already has ``CUDA_VISIBLE_DEVICES`` set, this selects the
            ``gpu_id``-th entry of that list instead of overriding the mask.
        model_path: Model passed to both vLLM and the HF tokenizer.
        texts: Texts to score as prompts.
        indices: Global record index of each entry in ``texts`` (same order).
        max_model_len: vLLM context length; texts longer than
            ``max_model_len - 10`` tokens are skipped.
        result_queue: Receives exactly one dict ``{index: nlps}``, where
            ``nlps`` is a list of per-token negative logprobs, or ``None``
            for skipped texts / outputs without prompt logprobs.

    The dict is put on the queue even when the worker fails part-way (it then
    holds only the indices finished so far), so the parent never blocks
    forever on a crashed worker; the exception is re-raised afterwards so the
    process still exits non-zero.
    """
    # Pin the device before vLLM (and thus CUDA) is imported in this process.
    # Respect a pre-existing mask (e.g. from a job scheduler) by indexing
    # into it rather than overwriting it with the raw ordinal.
    visible = [d.strip() for d in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
               if d.strip()]
    os.environ["CUDA_VISIBLE_DEVICES"] = visible[gpu_id] if gpu_id < len(visible) else str(gpu_id)

    results = {}
    skipped = 0
    try:
        # Lazy imports inside worker to avoid GPU init in parent
        from vllm import LLM, SamplingParams
        from transformers import AutoTokenizer

        llm = LLM(model=model_path, gpu_memory_utilization=0.90, max_model_len=max_model_len,
                  trust_remote_code=True, enforce_eager=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Generate a single throw-away token; only the prompt logprobs matter.
        sp = SamplingParams(max_tokens=1, prompt_logprobs=1, temperature=0)
        batch_size = 8

        for start in range(0, len(texts), batch_size):
            stop = min(start + batch_size, len(texts))
            batch_texts = []
            batch_indices = []
            for text, idx in zip(texts[start:stop], indices[start:stop]):
                # Keep a 10-token margin below the context limit
                # (presumably headroom for special tokens + the generated one).
                if len(tokenizer.encode(text)) > max_model_len - 10:
                    results[idx] = None
                    skipped += 1
                else:
                    batch_texts.append(text)
                    batch_indices.append(idx)

            if batch_texts:
                outputs = llm.generate(batch_texts, sp)
                for idx, out in zip(batch_indices, outputs):
                    lps = out.prompt_logprobs
                    if lps is None:
                        results[idx] = None
                        continue
                    # Save full per-token neg-logprobs (not just mean)
                    results[idx] = _prompt_token_nlps(
                        getattr(out, "prompt_token_ids", None), lps)

            # Progress line roughly every 32 texts.
            if stop % 32 < batch_size:
                print(f"  GPU {gpu_id}: {stop}/{len(texts)}", flush=True)

        print(f"  GPU {gpu_id}: done ({skipped} skipped)")
    except Exception:
        print(f"  GPU {gpu_id}: failed after {len(results)}/{len(texts)} texts; "
              "returning partial results", file=sys.stderr, flush=True)
        raise
    finally:
        # Always hand something back so the parent's queue read completes.
        result_queue.put(results)


def compute_ppl_multi_gpu(records: list[dict], model_path: str, num_gpus: int,
                          max_model_len: int) -> list[dict]:
    """Score every record's response across several GPU worker processes.

    Responses are dealt round-robin to ``num_gpus`` ``gpu_worker`` processes;
    each worker returns ``{record_index: per-token negative logprobs}``.

    The records are modified in place (and also returned): ``"response"`` is
    removed and these keys are added:

    * ``token_neg_logprobs`` (rounded to 6 decimals) and ``n_tokens``
    * ``sft_ppl``: the *mean* negative logprob (no ``exp`` is applied)
    * ``top100_ppl`` / ``top200_ppl`` / ``top500_ppl``: mean of the k largest
      negative logprobs, or the overall mean when fewer than k tokens exist
    * ``p95_ppl`` / ``p99_ppl``: percentiles of the negative logprobs

    Records without usable token data (skipped, missing, or empty) get
    ``sft_ppl = NaN``, an empty token list, ``n_tokens = 0`` and none of the
    top-k / percentile keys.

    Raises:
        ValueError: ``num_gpus`` is less than 1 while there is work to do.
        RuntimeError: all workers exited before delivering all results.
    """
    import numpy as np
    from queue import Empty

    if not records:
        return records
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")

    # Split across GPUs (interleaved for balance)
    gpu_texts = [[] for _ in range(num_gpus)]
    gpu_indices = [[] for _ in range(num_gpus)]
    for idx, rec in enumerate(records):
        g = idx % num_gpus
        gpu_texts[g].append(rec["response"])
        gpu_indices[g].append(idx)

    result_queue = Queue()
    procs = []
    for g in range(num_gpus):
        if not gpu_texts[g]:
            continue
        p = Process(target=gpu_worker,
                    args=(g, model_path, gpu_texts[g], gpu_indices[g], max_model_len, result_queue))
        p.start()
        procs.append(p)

    # Collect one result dict per worker. Poll with a timeout so that a worker
    # that died without reporting cannot block this process forever.
    all_results = {}
    pending = len(procs)
    while pending:
        try:
            shard = result_queue.get(timeout=5.0)
        except Empty:
            if any(p.is_alive() for p in procs):
                continue
            # Every worker is gone; allow one last read for data still in
            # flight before giving up.
            try:
                shard = result_queue.get(timeout=1.0)
            except Empty:
                raise RuntimeError(
                    f"{pending} GPU worker(s) exited without returning results"
                ) from None
        all_results.update(shard)
        pending -= 1

    for p in procs:
        p.join()

    # Attach per-token logprobs and summary stats to records
    for i, r in enumerate(records):
        token_nlps = all_results.get(i)
        del r["response"]  # don't save full text
        if not token_nlps:
            r["sft_ppl"] = float("nan")
            r["token_neg_logprobs"] = []
            r["n_tokens"] = 0
            continue
        arr = np.array(token_nlps)
        ordered = np.sort(arr)  # ascending; the tail holds the hardest tokens
        overall_mean = float(np.mean(arr))
        r["token_neg_logprobs"] = [round(float(v), 6) for v in token_nlps]  # full per-token data
        r["n_tokens"] = len(token_nlps)
        r["sft_ppl"] = overall_mean
        for k in (100, 200, 500):
            r[f"top{k}_ppl"] = float(np.mean(ordered[-k:])) if len(arr) >= k else overall_mean
        r["p95_ppl"] = float(np.percentile(arr, 95))
        r["p99_ppl"] = float(np.percentile(arr, 99))

    return records


def main():
    """CLI entry point: score responses, compute token stats, save, summarize.

    Steps:
      1. Parse arguments; ``--jsonl_files`` and ``--labels`` must have the
         same length (checked via ``parser.error`` so it also holds under
         ``python -O``).
      2. Load each file with its label and keep correct responses only;
         exit with status 1 if none are found.
      3. Compute per-token statistics across ``--num_gpus`` workers.
      4. Write one JSON object per record to ``--output`` (UTF-8). Records
         without token data carry ``sft_ppl = NaN``, which ``json.dumps``
         emits as the non-standard literal ``NaN``.
      5. Print per-source means/stds and within-prompt (per ``doc_id``) stds
         for the mean and top-k statistics.
    """
    import numpy as np
    from collections import defaultdict

    parser = argparse.ArgumentParser()
    parser.add_argument("--sft_model", required=True)
    parser.add_argument("--jsonl_files", nargs="+", required=True, help="Glob-expanded list of JSONL files")
    parser.add_argument("--labels", nargs="+", required=True, help="Label for each jsonl file (same order)")
    parser.add_argument("--task_type", required=True, help="aime, olymp_math, bridges, pattern, etc.")
    parser.add_argument("--output", required=True)
    parser.add_argument("--num_gpus", type=int, default=4)
    parser.add_argument("--max_model_len", type=int, default=28000)
    args = parser.parse_args()

    if len(args.jsonl_files) != len(args.labels):
        parser.error("Must have one label per jsonl file")
    if args.num_gpus < 1:
        parser.error("--num_gpus must be >= 1")

    # Load and score
    all_records = []
    for path, label in zip(args.jsonl_files, args.labels):
        all_records.extend(load_and_filter_correct([path], args.task_type, label))

    print(f"\nTotal correct responses to score: {len(all_records)}")
    if not all_records:
        print("No correct responses found!")
        sys.exit(1)

    # Compute perplexity
    print(f"\nComputing SFT perplexity on {len(all_records)} responses across {args.num_gpus} GPUs...")
    all_records = compute_ppl_multi_gpu(all_records, args.sft_model, args.num_gpus, args.max_model_len)

    # Save
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        for r in all_records:
            f.write(json.dumps(r) + "\n")
    print(f"\nSaved {len(all_records)} records to {args.output}")

    # Quick summary (records without token data have NaN and are excluded)
    topk_keys = ("top100_ppl", "top200_ppl", "top500_ppl")
    by_source = defaultdict(list)
    for r in all_records:
        if not np.isnan(r["sft_ppl"]):
            by_source[r["source"]].append(r)

    print(f"\n{'='*60}")
    print("CORRECT-ONLY SFT PERPLEXITY (mean vs top-k)")
    print(f"{'='*60}")
    for source in sorted(by_source):
        recs = by_source[source]
        ppls = np.array([r["sft_ppl"] for r in recs])
        print(f"\n  {source}: N={len(ppls)}")
        print(f"    Mean ppl:    {np.mean(ppls):.4f} ± {np.std(ppls):.4f}")
        for k in topk_keys:
            vals = np.array([r[k] for r in recs])
            print(f"    {k}: {np.mean(vals):.4f} ± {np.std(vals):.4f}")

        # Within-prompt variance: group this source's records by doc_id and
        # only consider prompts with at least two scored responses.
        doc_groups = defaultdict(list)
        for r in recs:
            doc_groups[r["doc_id"]].append(r)
        multi = [group for group in doc_groups.values() if len(group) >= 2]

        mean_stds = [np.std([r["sft_ppl"] for r in group]) for group in multi]
        if mean_stds:
            print(f"    Within-prompt std (mean ppl): {np.mean(mean_stds):.4f}")
            for k in topk_keys:
                topk_stds = [np.std([r[k] for r in group]) for group in multi]
                ratio = np.mean(topk_stds) / np.mean(mean_stds)
                print(f"    Within-prompt std ({k}): {np.mean(topk_stds):.4f}  ({ratio:.1f}x amplification)")


if __name__ == "__main__":
    main()
