"""
Bridges puzzle lm-eval task utilities.

Supports bridges intformat_json datasets (7x7, 8x8, etc.).
Reuses extraction/normalization from reward_function/bridges.py.
"""

import json
import os
import sys
import logging

# Make the repository root (three directories above this file) importable so
# the ``reward_function`` package below resolves no matter which working
# directory the evaluation is launched from. Prepend only once.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from reward_function.bridges import extract_answer, normalize_grid_format

eval_logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Template loading (once at import time)
# ---------------------------------------------------------------------------
# The prompt template is read exactly once, when this module is imported.
# Decode explicitly as UTF-8: without ``encoding=`` open() uses the locale's
# preferred encoding, so the same template could decode differently (or fail)
# on hosts whose default is not UTF-8.
# NOTE(review): assumes prompts/bridges_intformat.txt is stored as UTF-8.
_TEMPLATE_PATH = os.path.join(_PROJECT_ROOT, "prompts", "bridges_intformat.txt")
with open(_TEMPLATE_PATH, encoding="utf-8") as _f:
    _TEMPLATE = _f.read()


# ---------------------------------------------------------------------------
# Grid formatting
# ---------------------------------------------------------------------------
def _format_grid(problem):
    """Convert problem field (list-of-lists or JSON string) to comma-separated rows."""
    if isinstance(problem, str):
        problem = json.loads(problem)
    return "\n".join(",".join(str(c) for c in row) for row in problem)


def _normalize_solution(solution):
    """Normalize solution field to canonical comma-separated rows."""
    if isinstance(solution, list):
        return "\n".join(",".join(str(c) for c in row) for row in solution)
    return normalize_grid_format(str(solution))


# ---------------------------------------------------------------------------
# lm_eval interface
# ---------------------------------------------------------------------------
def doc_to_text(doc):
    """Build the full prompt for ``doc`` by inserting its grid into the template.

    The grid comes from ``doc["problem"]``, formatted by ``_format_grid``.
    Note that ``str.replace`` substitutes *every* ``{}`` in the template, not
    just the first one.
    """
    placeholder = "{}"
    return _TEMPLATE.replace(placeholder, _format_grid(doc["problem"]))


def process_results(doc, results):
    """Compute lm_eval metrics for one document.

    Only the first generated response in ``results`` is scored. The returned
    dict maps ``"exact_match"`` to 1 (correct) or 0 (incorrect).
    """
    first_response = results[0]
    score = _score_single(first_response, doc)
    return {"exact_match": score}


def _score_single(model_output: str, doc: dict) -> int:
    """Return 1 if one response's answer grid exactly matches the solution, else 0.

    The answer is pulled from ``model_output`` via ``extract_answer``. If
    nothing can be extracted, the response scores 0. Otherwise the answer is
    canonicalized with ``normalize_grid_format`` and ``doc["solution"]`` with
    ``_normalize_solution``, and the two strings are compared exactly. Scoring
    one response at a time is what pass@k evaluation needs.
    """
    answer = extract_answer(model_output)
    if answer is None:
        return 0

    # Both sides are canonicalized so formatting differences (JSON arrays,
    # code fences, etc.) do not affect the comparison.
    predicted = normalize_grid_format(answer)
    expected = _normalize_solution(doc["solution"])
    return 1 if predicted == expected else 0
