"""Countdown puzzle lm_eval task utilities."""
import os
import sys

# Absolute path of the directory four levels above this file's directory.
# It is put at the front of sys.path (once) so that imports of top-level
# project packages made later in this module can be resolved from there.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 4))
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from reward_function.countdown import compute_score, extract_answer


# System prompt prepended to every question by doc_to_text; per the original
# note it must stay identical to the prompt used during training, so the
# literal below is split per sentence purely for readability.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind "
    "and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within "
    "<think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think>\n"
    "<answer> answer here </answer>."
)


def doc_to_text(doc):
    """Build the prompt string for one Countdown example.

    Args:
        doc: Mapping with a ``"nums"`` entry (the numbers to combine) and a
            ``"target"`` entry (the value the equation must reach).

    Returns:
        The system prompt, the user's question and a trailing
        ``"Assistant:"`` cue, separated by blank lines.
    """
    numbers = doc["nums"]
    goal = doc["target"]
    # Anything exposing tolist() (e.g. a numpy array) is turned into a plain
    # list before it is interpolated into the question.
    if hasattr(numbers, "tolist"):
        numbers = numbers.tolist()

    question = (
        f"Using the numbers {numbers}, create an equation that equals {goal}. "
        "You can use basic arithmetic operations (+, -, *, /) and parentheses. "
        "Each number must be used exactly once."
    )
    return "\n\n".join([SYSTEM_PROMPT, f"User: {question}", "Assistant:"])


def _score_single(model_output: str, doc: dict) -> int:
    """Grade one model response against its Countdown example.

    Compatible with compute_pass_at_k.py's _resolve_scorer interface.

    Args:
        model_output: Raw text produced by the model.
        doc: Example mapping holding ``"target"`` and ``"nums"``.

    Returns:
        1 when the accuracy reported by ``compute_score`` is at least 1.0,
        otherwise 0.
    """
    goal = int(doc["target"])
    raw_nums = doc["nums"]
    # Objects exposing tolist() (e.g. numpy arrays) are passed on as lists.
    numbers = raw_nums.tolist() if hasattr(raw_nums, "tolist") else raw_nums

    result = compute_score(
        model_output,
        str(goal),
        {"nums": numbers, "target": goal},
        method="exact",
    )

    # compute_score's result is handled both as a bare number and as a
    # mapping; in the mapping case the "acc" entry is used (0 if absent).
    if isinstance(result, (int, float)):
        accuracy = result
    else:
        accuracy = result.get("acc", 0)
    return int(accuracy >= 1.0)


def process_results(doc, results):
    """Turn the model's output for ``doc`` into lm_eval metrics.

    Only the first entry of ``results`` is graded.

    Returns:
        A dict with a single ``"exact_match"`` key holding 1 or 0.
    """
    first_response = results[0]
    verdict = _score_single(first_response, doc)
    return {"exact_match": verdict}
