"""
Pattern puzzle lm-eval task utilities.

Supports pattern-puzzle datasets in ASCII format (3x3, 4x4, 5x5, etc.).
Reuses extraction/normalization from reward_function/generic_puzzle.py.
"""

import os
import sys
import logging

# The repository root sits three directories above this file. Prepend it to
# sys.path (only once) so the `reward_function` package below is importable.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from reward_function.generic_puzzle import extract_answer, normalize_grid

eval_logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Template loading (once at import time)
# ---------------------------------------------------------------------------
# The prompt template is read a single time when this module is imported; a
# missing file therefore raises at import, not at first use.
# Decode explicitly as UTF-8: without `encoding=`, open() falls back to the
# platform's locale encoding, which can mis-decode (or fail on) a UTF-8
# template containing non-ASCII characters, e.g. on Windows.
_TEMPLATE_PATH = os.path.join(_PROJECT_ROOT, "prompts", "pattern_formatted.txt")
with open(_TEMPLATE_PATH, encoding="utf-8") as _f:
    _TEMPLATE = _f.read()


# ---------------------------------------------------------------------------
# lm_eval interface
# ---------------------------------------------------------------------------
def doc_to_text(doc):
    """Build the model prompt for one document.

    The document's ``"problem"`` field is stringified and substituted for
    every ``{}`` placeholder in the module-level prompt template.
    """
    puzzle_text = str(doc["problem"])
    prompt = _TEMPLATE.replace("{}", puzzle_text)
    return prompt


def process_results(doc, results):
    """Compute metrics for one document from the model's responses.

    Following the lm_eval convention, only the first response in *results*
    is scored; the returned dict maps the metric name to a 0/1 score.
    """
    first_response = results[0]
    score = _score_single(first_response, doc)
    return {"exact_match": score}


def _score_single(model_output: str, doc: dict) -> int:
    """Score one response against ``doc["solution"]`` (per-sample, e.g. for pass@k).

    The answer is pulled out of *model_output* with ``extract_answer``; if
    nothing is extracted (``None``) the score is 0. Otherwise the extracted
    answer and the stringified solution are both passed through
    ``normalize_grid`` and compared for equality.

    Returns:
        1 on an exact (normalized) match, 0 otherwise.
    """
    answer = extract_answer(model_output)
    if answer is not None:
        predicted = normalize_grid(answer)
        expected = normalize_grid(str(doc["solution"]))
        return 1 if predicted == expected else 0
    return 0
