"""Shared template loading and grid formatting for puzzle lm_eval tasks."""

import json
import os

# Repository root: four parent-directory hops above the directory holding this
# file, normalized to an absolute path.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 4))
)
# Maps template file name -> template text; populated lazily by load_template().
_template_cache = {}


def load_template(template_name: str) -> str:
    """Load a prompt template from the project's ``prompts/`` directory.

    The file text is read once and memoized in the module-level
    ``_template_cache``; later calls with the same name return the cached
    string without touching the filesystem again.

    Args:
        template_name: File name of the template, joined onto
            ``<project root>/prompts/``.

    Returns:
        The full text of the template file.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    if template_name not in _template_cache:
        path = os.path.join(_PROJECT_ROOT, "prompts", template_name)
        # Decode explicitly as UTF-8: relying on the platform default encoding
        # (e.g. cp1252 on Windows) would garble or reject non-ASCII prompt text
        # and make prompts differ between machines.
        with open(path, encoding="utf-8") as f:
            _template_cache[template_name] = f.read()
    return _template_cache[template_name]


def format_grid(problem) -> str:
    """Render a puzzle's ``problem`` field as prompt text.

    A list of rows (what HF datasets yield after auto-parsing JSON) becomes
    one line per row, each cell converted with ``str`` and cells separated by
    commas. Every other value is returned as ``str(problem)``. That covers:

    - a JSON string of list-of-lists, kept verbatim for parity with
      eval_lora_checkpoints.py (the model sees the raw JSON in the prompt);
    - plain strings from ASCII/formatted datasets.
    """
    if not isinstance(problem, list):
        return str(problem)
    rendered_rows = [",".join(map(str, row)) for row in problem]
    return "\n".join(rendered_rows)


def make_doc_to_text(template_name: str):
    """Factory: return a ``doc_to_text`` function for a given prompt template.

    The returned callable takes a document ``doc``, renders ``doc["problem"]``
    with ``format_grid``, and inserts the result into the template named
    ``template_name``.

    Two placeholder spellings are supported:

    - ``{problem}``, filled via ``str.format``, so ``{{``/``}}`` escapes work;
    - a bare ``{}``.

    If ``str.format`` rejects the template, both spellings are substituted
    literally and every other character is left untouched. This happens with
    a positional ``{}``, literal braces such as a JSON example, or an
    unbalanced brace.
    """
    def doc_to_text(doc):
        # The template is fetched on each call (load_template caches it), so
        # a missing template only raises once doc_to_text is first called.
        template = load_template(template_name)
        grid_str = format_grid(doc["problem"])
        try:
            return template.format(problem=grid_str)
        except (KeyError, IndexError, ValueError):
            # Literal fallback. The old version replaced only "{}". A template
            # using "{problem}" next to literal braces then silently lost the
            # problem text. Split on "{problem}" and join with the grid, so
            # the inserted grid is never re-scanned for "{}" placeholders.
            pieces = template.split("{problem}")
            return grid_str.join(piece.replace("{}", grid_str) for piece in pieces)
    return doc_to_text
