"""Pass@k computation and gain calculation.

Imports scoring directly from compute_pass_at_k.py — no duplication.
Adds gain computation between (SFT, GSPO) checkpoint pairs.
"""
from __future__ import annotations

import sys
from pathlib import Path
from typing import Optional

import pandas as pd

# Import from existing infrastructure
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(_PROJECT_ROOT / "scripts" / "evals"))
from compute_pass_at_k import (
    _extract_task_name,
    load_and_score,
    pass_at_k,
)


def compute_passk_from_jsonl(
    jsonl_path: str | Path,
    k_values: list[int] = (1, 8, 32),
) -> dict:
    """Score one lm_eval samples JSONL file and report pass@k for each k.

    Args:
        jsonl_path: Path to a samples_*.jsonl file.
        k_values: The k values to report.

    Returns a dict with:
        task_name: str, derived from the path via _extract_task_name
        n_docs: int, number of scored documents
        n_reps: int, samples per document as reported by load_and_score
        pass_at_k: {k: float}. Each k is clamped to n_reps before scoring,
            so when k > n_reps the value stored under k is pass@n_reps.
        per_doc_correct: per-document correctness exactly as returned by
            load_and_score (documented as list[list[bool]])
    """
    path_str = str(jsonl_path)
    name = _extract_task_name(path_str)
    per_doc, _, reps, _doc_ids = load_and_score(path_str)

    # Never ask for more samples than were drawn per doc; the score is still
    # keyed by the k the caller requested.
    scores = {k: pass_at_k(per_doc, min(k, reps)) for k in k_values}

    return dict(
        task_name=name,
        n_docs=len(per_doc),
        n_reps=reps,
        pass_at_k=scores,
        per_doc_correct=per_doc,
    )


def compute_passk_for_dir(
    results_dir: str | Path,
    k_values: list[int] = (1, 8, 32),
    task_filter: Optional[str] = None,
) -> dict[str, dict]:
    """Score every samples_*.jsonl file found recursively under a directory.

    Args:
        results_dir: Directory searched recursively for samples_*.jsonl.
        k_values: k values forwarded to compute_passk_from_jsonl.
        task_filter: If truthy, keep only tasks whose name contains this
            substring.

    Returns a dict mapping task_name -> compute_passk_from_jsonl result.
    Files are visited in sorted path order. If two files resolve to the same
    task name, the later file overwrites the earlier one.
    """
    by_task: dict[str, dict] = {}
    for sample_path in sorted(Path(results_dir).rglob("samples_*.jsonl")):
        name = _extract_task_name(str(sample_path))
        keep = not task_filter or task_filter in name
        if keep:
            by_task[name] = compute_passk_from_jsonl(sample_path, k_values)
    return by_task


def compute_passk_gains(
    gspo_results: dict[str, dict],
    sft_results: dict[str, dict],
    k: int = 32,
) -> dict[str, float]:
    """Return per-task pass@k gains (gspo - sft) for tasks present in both.

    Each value is read from result["pass_at_k"][k]. A missing k on either
    side is treated as 0.0. Tasks appear in gspo_results iteration order.
    """
    def _score(results: dict[str, dict], name: str) -> float:
        return results[name]["pass_at_k"].get(k, 0.0)

    return {
        name: _score(gspo_results, name) - _score(sft_results, name)
        for name in gspo_results
        if name in sft_results
    }


def build_gain_table(
    checkpoint_configs: list[dict],
    k_values: list[int] = (1, 8, 32),
    math_task_filter: str = "olymp_math_hard",
    puzzle_tasks: Optional[list[str]] = None,
) -> pd.DataFrame:
    """Build a table of puzzle and math pass@k gains (GSPO - SFT) per checkpoint.

    Args:
        checkpoint_configs: List of dicts with keys:
            - checkpoint_id: str
            - sft_base: str (e.g. 'v1' or 'v2'; copied into each row as-is)
            - math_results: str or None, GSPO results dir for math tasks
            - sft_math_results: str or None, SFT baseline results dir for math
            - puzzle_results: str or None, GSPO results dir for puzzle tasks
            - sft_puzzle_results: str or None, SFT baseline results dir for
              puzzles
            A domain is skipped for a checkpoint unless both its GSPO and
            SFT directories are set.
        k_values: k values to compute. Each k gets its own gain column.
        math_task_filter: Substring a task name must contain to be scored
            from the math results dirs.
        puzzle_tasks: Substrings. If non-empty, keep only puzzle tasks whose
            name contains at least one of them (all tasks if None/empty).

    Returns a DataFrame containing, for each checkpoint:
        - one row per (task, domain), sorted by task then domain. Columns
          are checkpoint_id, sft_base, task, domain ('math' or 'puzzle'),
          and pass@{k}_gain for each k.
        - one '_aggregate' row (domain 'aggregate') holding the mean gain
          over tasks in columns math_pass{k}_gain / puzzle_pass{k}_gain.
          A column is set only when that domain has gains for the
          checkpoint.
    """
    # Scoring a results dir is the expensive step, and SFT baseline dirs are
    # typically shared by many checkpoints, so memoize per (dir, filter).
    scored_cache: dict[tuple[str, Optional[str]], dict[str, dict]] = {}

    def _scored(results_dir, task_filter: Optional[str] = None) -> dict[str, dict]:
        """Return compute_passk_for_dir results for a dir, computing once."""
        key = (str(results_dir), task_filter)
        if key not in scored_cache:
            scored_cache[key] = compute_passk_for_dir(
                results_dir, k_values, task_filter
            )
        return scored_cache[key]

    def _only_puzzles(results: dict[str, dict]) -> dict[str, dict]:
        """Keep tasks matching any puzzle_tasks substring (all if unset)."""
        if not puzzle_tasks:
            return results
        return {
            name: res
            for name, res in results.items()
            if any(t in name for t in puzzle_tasks)
        }

    def _gains_by_k(gspo: dict[str, dict], sft: dict[str, dict]) -> dict[int, dict[str, float]]:
        """Compute per-task gains separately for every requested k."""
        return {k: compute_passk_gains(gspo, sft, k) for k in k_values}

    rows = []
    for cfg in checkpoint_configs:
        ckpt_id = cfg["checkpoint_id"]
        sft_base = cfg["sft_base"]

        # domain -> k -> task -> gain. The domain is taken from the results
        # dir a task was scored from, not guessed from the task name.
        domain_gains: dict[str, dict[int, dict[str, float]]] = {}

        math_dir = cfg.get("math_results")
        sft_math_dir = cfg.get("sft_math_results")
        if math_dir and sft_math_dir:
            domain_gains["math"] = _gains_by_k(
                _scored(math_dir, math_task_filter),
                _scored(sft_math_dir, math_task_filter),
            )

        puzzle_dir = cfg.get("puzzle_results")
        sft_puzzle_dir = cfg.get("sft_puzzle_results")
        if puzzle_dir and sft_puzzle_dir:
            domain_gains["puzzle"] = _gains_by_k(
                _only_puzzles(_scored(puzzle_dir)),
                _only_puzzles(_scored(sft_puzzle_dir)),
            )

        # Per-task rows: each pass@{k}_gain column holds the gain at that k.
        task_rows = []
        for domain, by_k in domain_gains.items():
            for task in set().union(*by_k.values()):
                row = {
                    "checkpoint_id": ckpt_id,
                    "sft_base": sft_base,
                    "task": task,
                    "domain": domain,
                }
                for k in k_values:
                    row[f"pass@{k}_gain"] = by_k[k].get(task)
                task_rows.append(row)
        task_rows.sort(key=lambda r: (r["task"], r["domain"]))
        rows.extend(task_rows)

        # Aggregate row: mean gain over tasks, per domain and per k.
        agg_row = {
            "checkpoint_id": ckpt_id,
            "sft_base": sft_base,
            "task": "_aggregate",
            "domain": "aggregate",
        }
        for domain, by_k in domain_gains.items():
            for k in k_values:
                gains = by_k[k]
                if gains:
                    agg_row[f"{domain}_pass{k}_gain"] = sum(gains.values()) / len(gains)
        rows.append(agg_row)

    return pd.DataFrame(rows)
