# cirbench/utils/prompt.py
from __future__ import annotations
from jinja2 import Environment
def _jinja_env() -> Environment:
    """Create the Jinja environment used to render prompt templates.

    Block and comment tags use ``[% ... %]`` and ``[# ... #]`` so that a
    literal ``{%`` inside a prompt (e.g. "MemoryPhi({%0,1}, ...)") is not
    parsed as the start of a Jinja block. Variable expressions keep the
    ``{{ ... }}`` form. Autoescaping and whitespace trimming are disabled.
    """
    delimiters = {
        "block_start_string": '[%',
        "block_end_string": '%]',
        "comment_start_string": '[#',
        "comment_end_string": '#]',
        "variable_start_string": '{{',
        "variable_end_string": '}}',
    }
    whitespace_and_escaping = {
        "autoescape": False,
        "trim_blocks": False,
        "lstrip_blocks": False,
    }
    return Environment(**whitespace_and_escaping, **delimiters)
from pathlib import Path
from typing import Dict, List, Optional
from ..cfg import CIRBenchConfig
import re

def _read_shot(p: Path) -> str:
    """Return the entire contents of the shot file ``p`` decoded as UTF-8."""
    with p.open("r", encoding="utf-8") as handle:
        return handle.read()


def _compat_vars(s: str) -> str:
    """Rewrite shell-style placeholders in ``s`` as Jinja variable expressions.

    Both ``${name}`` and ``$name`` become ``{{ name }}``, where ``name`` is an
    ASCII identifier (letter or underscore first, then letters, digits or
    underscores). Any other use of ``$`` is left untouched.
    """
    jinja_var = r"{{ \1 }}"
    # Order matters: the braced form is rewritten first, then the bare form.
    placeholder_patterns = (
        r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}",
        r"\$([A-Za-z_][A-Za-z0-9_]*)\b",
    )
    for pattern in placeholder_patterns:
        s = re.sub(pattern, jinja_var, s)
    return s


def build_prompt(cfg: CIRBenchConfig, track: str, base_prompt_txt: str,
                 vars: Dict, case_id: Optional[str] = None) -> str:
    """Render the prompt text for one case.

    Args:
        cfg: Benchmark configuration; only ``cfg.prompt_cfg`` is read here.
        track: Key used to look up a template in
            ``cfg.prompt_cfg.track_templates`` when ``base_prompt_txt`` is
            empty.
        base_prompt_txt: Template text. When falsy, the track template is
            used, and failing that ``cfg.prompt_cfg.default_template``.
        vars: Values made available to the template. ``None`` is treated as
            an empty mapping. The name ``SHOTS`` is reserved: a caller-supplied
            ``SHOTS`` entry is replaced by the shots assembled here.
        case_id: When given together with ``cfg.prompt_cfg.shots_dir``, the
            file ``<shots_dir>/<case_id>.shot.txt`` is loaded (if it exists)
            and placed before the configured shots.

    Returns:
        The rendered prompt.
    """
    pconf = cfg.prompt_cfg
    tpl_txt = base_prompt_txt or pconf.track_templates.get(track) or pconf.default_template

    # Per-case shot first (if present on disk), then the configured shots.
    shots: List[str] = []
    if pconf.shots_dir and case_id:
        sp = Path(pconf.shots_dir) / f"{case_id}.shot.txt"
        if sp.exists():
            shots.append(_read_shot(sp))
    shots.extend(pconf.shots or [])

    # The output-format instructions are currently disabled, so the suffix is
    # empty whether or not `enforce_blocks` is set.
    blocks = ""
    if pconf.enforce_blocks:
        blocks = (
            # "\nReturn ONLY the two delimited blocks if requested.\n"
            # "Start a structured JSON block as <CIR_JSON>{...}</CIR_JSON>.\n"
            # "If code is requested, return it in <IR_OUT>...</IR_OUT>.\n"
            ""
        )

    tpl_txt = _compat_vars(tpl_txt)
    # NOTE(review): `$name` placeholders inside shots are rewritten to
    # `{{ name }}`, but SHOTS is passed as a plain value and is not rendered
    # by this function, so they reach the output as literal `{{ name }}` text.
    # Confirm whether shots were meant to be rendered or left untouched.
    shots_txt = _compat_vars("\n\n".join(shots))

    env = _jinja_env()
    tpl = env.from_string(tpl_txt + blocks)

    # Pass a single mapping instead of `**vars, SHOTS=...`: keyword expansion
    # raises TypeError when `vars` already has a "SHOTS" key (duplicate
    # keyword) or contains non-string keys.
    context = dict(vars or {})
    context["SHOTS"] = shots_txt
    return tpl.render(context)