"""Load lm_eval JSONL samples and score responses.

Provides unified loading for both puzzle and math evaluation results,
reusing the scorer infrastructure from compute_pass_at_k.py.
"""
from __future__ import annotations

import json
import re
import sys
from dataclasses import dataclass, field
from pathlib import Path

# Import scoring infrastructure from existing code
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(_PROJECT_ROOT / "scripts" / "evals"))
from compute_pass_at_k import (
    _extract_task_name,
    _resolve_scorer,
    is_correct,
    load_and_score,
)


@dataclass
class TraceSample:
    """One scored response to one prompt.

    A prompt may be sampled several times; each of those responses is
    recorded as its own TraceSample.
    """
    # Label of the checkpoint that produced this response.
    checkpoint_id: str
    # Name of the evaluation task the prompt belongs to.
    task_name: str
    # Identifier of the prompt (doc) within the task.
    doc_id: int
    # Position of this response among the repeats for the same doc.
    trace_id: int
    # The response text that was scored.
    response: str
    # Scoring verdict for ``response``.
    correct: bool
    # Source doc record; left out of repr() output.
    doc: dict = field(repr=False)


@dataclass
class PromptSamples:
    """The full set of traces collected for one prompt under one checkpoint."""
    # Label of the checkpoint that produced the traces.
    checkpoint_id: str
    # Name of the evaluation task the prompt belongs to.
    task_name: str
    # Identifier of the prompt (doc) within the task.
    doc_id: int
    # One entry per sampled response to this prompt.
    traces: list[TraceSample]
    # Source doc record; left out of repr() output.
    doc: dict = field(repr=False)

    @property
    def correct_mask(self) -> list[bool]:
        """Per-trace correctness flags, in the same order as ``traces``."""
        flags = []
        for trace in self.traces:
            flags.append(trace.correct)
        return flags

    @property
    def n_total(self) -> int:
        """Number of traces held for this prompt."""
        return len(self.traces)

    @property
    def n_correct(self) -> int:
        """Number of traces whose ``correct`` flag is set."""
        return sum(self.correct_mask)


def load_samples(jsonl_path: str | Path, checkpoint_id: str = "") -> list[PromptSamples]:
    """Load one lm_eval samples JSONL file into a list of PromptSamples.

    Each non-blank line is parsed as one JSON record describing a single doc,
    with all repeats expected in ``resps[0]``. When ``resps[0]`` is not a
    list, ``resps`` itself is used as the flat list of responses.

    The task name is obtained from the file path via ``_extract_task_name``
    and passed to ``_resolve_scorer``. If that returns a truthy scorer, each
    response is scored as ``bool(scorer(response, doc))``; otherwise
    ``is_correct(response, target)`` is used.

    Args:
        jsonl_path: Path to the samples JSONL file.
        checkpoint_id: Label copied onto every returned sample.

    Returns:
        One PromptSamples per record, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    jsonl_path = Path(jsonl_path)
    task_name = _extract_task_name(str(jsonl_path))
    scorer = _resolve_scorer(task_name)

    prompts: list[PromptSamples] = []
    # Decode as UTF-8 explicitly: without it open() falls back to the
    # platform locale encoding, which can mis-decode non-ASCII responses.
    with open(jsonl_path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # instead of letting json.loads fail on an empty string.
            if not line.strip():
                continue

            d = json.loads(line)
            # Prefer "doc_id", fall back to "idx", then to 0.
            doc_id = d.get("doc_id", d.get("idx", 0))
            doc = d.get("doc", {})
            target = d.get("target", "")
            resps = d.get("resps", [[]])
            # Nested form [[r0, r1, ...]] -> take the inner list; a flat
            # (or empty) list is used as-is.
            responses = resps[0] if resps and isinstance(resps[0], list) else resps

            # Decide the scoring path once per record rather than per response.
            if scorer:
                verdicts = [bool(scorer(resp, doc)) for resp in responses]
            else:
                verdicts = [is_correct(resp, target) for resp in responses]

            traces = [
                TraceSample(
                    checkpoint_id=checkpoint_id,
                    task_name=task_name,
                    doc_id=doc_id,
                    trace_id=i,
                    response=resp,
                    correct=ok,
                    doc=doc,
                )
                for i, (resp, ok) in enumerate(zip(responses, verdicts))
            ]

            prompts.append(PromptSamples(
                checkpoint_id=checkpoint_id,
                task_name=task_name,
                doc_id=doc_id,
                traces=traces,
                doc=doc,
            ))

    return prompts


def discover_jsonl_files(results_dir: str | Path) -> list[Path]:
    """Return every ``samples_*.jsonl`` path below ``results_dir``, sorted.

    The search is recursive, so files in nested subdirectories are included.
    """
    root = Path(results_dir)
    matches = list(root.rglob("samples_*.jsonl"))
    matches.sort()
    return matches


def load_results_dir(
    results_dir: str | Path,
    checkpoint_id: str = "",
) -> dict[str, list[PromptSamples]]:
    """Load every samples file found under a results directory.

    Files are located with ``discover_jsonl_files`` and each one is read
    with ``load_samples``. Files that yield no prompts are left out of the
    result. When more than one file resolves to the same task name, the
    file visited last replaces the earlier entry.

    Returns a mapping of task_name -> list[PromptSamples].
    """
    by_task: dict[str, list[PromptSamples]] = {}
    for sample_file in discover_jsonl_files(results_dir):
        task = _extract_task_name(str(sample_file))
        loaded = load_samples(sample_file, checkpoint_id=checkpoint_id)
        if not loaded:
            # Nothing usable in this file; do not create an empty entry.
            continue
        by_task[task] = loaded
    return by_task
