from typing import List, Dict, Optional

# Generic chain-of-thought trigger phrase; also used as the prefix of the
# gsm8k instruction below.
DEFAULT_INSTRUCTION = "Let's think step by step."

# Instruction shared by the multiple-choice tasks (gpqa, date, salient).
# NOTE: Python concatenates adjacent string literals with NO separator, so
# every fragment except the last must end with a space.
_MCQ_INSTRUCTION = (
    "Answer the multiple-choice question. "
    "Do minimal reasoning, then return the choice of the correct answer by "
    "selecting one of the options (e.g., '(A)', '(B)')."
)


# Per-task prompt style, keyed by lowercase task name. Each entry holds:
#   "instruction": text placed at the top of the prompt,
#   "input_tag":   label written before each input (e.g. "Question"),
#   "target_tag":  label written before each target/answer (e.g. "Answer").
# Every task gets its own dict literal so that mutating one entry can never
# silently change another.
TASK_STYLES: Dict[str, Dict[str, str]] = {
    "gsm8k": {
        "instruction": DEFAULT_INSTRUCTION
                       + " Return the final integer in the form \\boxed{answer} on the last line.",
        "input_tag": "Question",
        "target_tag": "Answer",
    },
    "xsum": {
        "instruction": "Write a single-sentence abstractive summary that captures the main point. "
                       "Return exactly one sentence.",
        "input_tag": "Document",
        "target_tag": "Summary",
    },
    "fp": {
        "instruction": "Classify the sentiment as one of: positive, neutral, negative. "
                       "Answer with exactly one word.",
        "input_tag": "Sentence",
        "target_tag": "Label",
    },
    "gpqa": {
        "instruction": _MCQ_INSTRUCTION,
        "input_tag": "Question",
        "target_tag": "Answer",
    },
    "date": {
        "instruction": _MCQ_INSTRUCTION,
        "input_tag": "Question",
        "target_tag": "Answer",
    },
    "salient": {
        "instruction": _MCQ_INSTRUCTION,
        "input_tag": "Question",
        "target_tag": "Answer",
    },
}




def render_example(ex: Dict[str, str], input_tag: str, target_tag: str) -> str:
    """Render one few-shot example as four newline-separated lines.

    The layout is ``"<input_tag>:"``, the input text, ``"<target_tag>:"``,
    the target text. Surrounding whitespace is stripped from both texts.

    Args:
        ex: Example mapping with ``'input'`` and ``'target'`` string entries.
        input_tag: Label written before the input text.
        target_tag: Label written before the target text.

    Returns:
        The rendered example, with no trailing newline.
    """
    source = ex['input'].strip()
    answer = ex['target'].strip()
    return "\n".join((f"{input_tag}:", source, f"{target_tag}:", answer))

def build_prompt(
    task: str,
    *,
    instruction: Optional[str] = None,
    query: str,
    examples: Optional[List[Dict[str, str]]] = None,  # few-shot: [{'input','target'}, ...]
    examples_header: str = "### Examples",
    task_header: str = "### Now, solve the following task.",
    add_output_tag: bool = True,                  # end with "<target_tag>:" (e.g. "Answer:")
) -> str:
    """Assemble the full prompt for ``task``.

    The prompt consists of blocks joined by a blank line, in this order:

    1. the instruction, if it is non-empty after stripping;
    2. ``examples_header`` and the rendered few-shot examples, if any are
       given;
    3. ``task_header``;
    4. the query, labelled with the task's input tag and optionally followed
       by the task's target tag.

    Args:
        task: Key into ``TASK_STYLES``. It is matched after stripping
            surrounding whitespace and lower-casing.
        instruction: Overrides the task's default instruction. ``None`` uses
            the default. An empty or whitespace-only string omits the
            instruction block.
        query: The input to solve. Surrounding whitespace is stripped.
        examples: Optional few-shot examples. Each one is a dict with
            ``'input'`` and ``'target'`` keys.
        examples_header: Heading placed before the few-shot examples.
        task_header: Heading placed before the query.
        add_output_tag: If true, append ``"<target_tag>:"`` after the query.

    Returns:
        The assembled prompt string.

    Raises:
        ValueError: If ``task`` does not name a known task style.
    """
    key = task.strip().lower()
    if key not in TASK_STYLES:
        raise ValueError(f"Unknown task '{task}'. Known: {list(TASK_STYLES.keys())}")

    style = TASK_STYLES[key]
    input_tag = style["input_tag"]
    target_tag = style["target_tag"]
    # Only None falls back to the default, so callers can pass "" to drop the
    # instruction block (``instruction or default`` made that impossible).
    raw_instr = style.get("instruction", "") if instruction is None else instruction
    instr = raw_instr.strip()

    blocks: List[str] = []
    if instr:
        blocks.append(instr)

    if examples:
        blocks.append(examples_header)
        blocks.append("\n\n".join(render_example(ex, input_tag, target_tag) for ex in examples))

    blocks.append(task_header)
    q = f"{input_tag}:\n{query.strip()}"
    if add_output_tag:
        # NOTE(review): examples put "<target_tag>:" after a single newline,
        # but the query uses a blank line here. Confirm the mismatch is
        # intended before changing it, because it alters every prompt.
        q += f"\n\n{target_tag}:"
    blocks.append(q)

    return "\n\n".join(blocks)