"""Shared scoring logic for puzzle lm_eval tasks.

Delegates to existing reward_function/ modules for answer extraction
and grid normalization.
"""

import json
import os
import sys

# Resolve the directory four levels above this file's directory and put it at
# the front of sys.path (only if it is not already listed).  The
# ``reward_function`` imports that follow are presumably resolved through
# this entry, so this must run before them.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 4))
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from reward_function.bridges import extract_answer as bridges_extract_answer
from reward_function.bridges import normalize_grid_format as bridges_normalize_grid
from reward_function.generic_puzzle import extract_answer as generic_extract_answer
from reward_function.generic_puzzle import normalize_grid as generic_normalize_grid


def _normalize_solution(solution):
    """Render a dataset solution as comma/newline-delimited grid text.

    Args:
        solution: Either a list of rows (each row an iterable of cells), or
            any other value.  Non-list values are stringified and stripped;
            if the resulting text is JSON for a non-empty list whose first
            element is a list, it is treated as a grid as well.

    Returns:
        For grids: one line per row, cells converted with ``str`` and joined
        by ``","``, rows joined by ``"\\n"``.  Otherwise the stripped string
        form of ``solution``.
    """

    def render(rows):
        # One output line per row, cells separated by commas.
        lines = []
        for row in rows:
            lines.append(",".join(map(str, row)))
        return "\n".join(lines)

    if isinstance(solution, list):
        return render(solution)

    text = str(solution).strip()
    try:
        decoded = json.loads(text)
        looks_like_grid = (
            isinstance(decoded, list)
            and bool(decoded)
            and isinstance(decoded[0], list)
        )
        if looks_like_grid:
            # Rendering stays inside the try block: a later row that is not
            # iterable raises TypeError, which falls back to the raw text.
            return render(decoded)
    except (json.JSONDecodeError, TypeError, ValueError):
        pass
    return text


def make_scorer(puzzle_type: str = "generic"):
    """Build the scoring callables for one puzzle type.

    Args:
        puzzle_type: ``"bridges"`` selects the bridges extraction and
            normalization functions; every other value selects the
            generic_puzzle ones.

    Returns:
        A ``(process_results, _score_single)`` tuple.  ``process_results``
        follows the lm_eval ``process_results(doc, results)`` signature and
        ``_score_single(model_output, doc)`` scores one response.
    """
    use_bridges = puzzle_type == "bridges"
    extract_fn = bridges_extract_answer if use_bridges else generic_extract_answer
    normalize_fn = bridges_normalize_grid if use_bridges else generic_normalize_grid

    def _score_single(model_output: str, doc: dict) -> int:
        """Return 1 when the extracted answer equals the solution, else 0.

        A response from which no answer can be extracted scores 0.
        """
        answer = extract_fn(model_output)
        if answer is None:
            return 0
        predicted = normalize_fn(answer)
        # The reference goes through _normalize_solution only, not through
        # the puzzle-specific normalizer applied to the model's answer.
        expected = _normalize_solution(doc["solution"])
        if predicted == expected:
            return 1
        return 0

    def process_results(doc, results):
        """Score the first entry of ``results`` and report it as exact_match."""
        first_response = results[0]
        return {"exact_match": _score_single(first_response, doc)}

    return process_results, _score_single
