import regex
import random

# NOTE(review): presumably the default index into one of the FORMATS lists
# (i.e. the default ``prompt_idx``) -- confirm against callers.
DEFAULT_PROMPT_FORMAT = 0

# Prompt templates grouped by prompting strategy. Every entry carries:
#   "prompt"   - template with a ``{fact}`` placeholder (filled by
#                ``preprocess_batch`` through ``str.format``),
#   "label"    - the answer word that ``postprocess_answers`` maps to 1,
#   "negative" - the opposite answer word (not read by the helpers below).
FORMATS = {
    "zero_shot": [
        {
            "prompt": (
                'You can answer only with "true" or "false". '
                "Is the fact true? Fact: {fact} Answer: "
            ),
            "label": "true",
            "negative": "false",
        },
        {
            "prompt": (
                'You can answer only with "yes" or "no". '
                "Is the fact true? Fact: {fact} Answer: "
            ),
            "label": "yes",
            "negative": "no",
        },
    ],
    "few_shot": [
        {
            "prompt": (
                "Fact: the earth is round. Label: true.\n"
                "Fact: the sun is cold. Label: false.\n"
                "Fact: {fact} Label:"
            ),
            "label": "true",
            "negative": "false",
        },
        {
            "prompt": (
                "Fact: the earth is round. Label: yes.\n"
                "Fact: the sun is cold. Label: no.\n"
                "Fact: {fact} Label:"
            ),
            "label": "yes",
            "negative": "no",
        },
    ],
}


def preprocess_batch(batch, prompt_format, prompt_idx):
    """Render every fact in ``batch`` with the selected prompt template.

    Args:
        batch: iterable of facts; each one is substituted for the ``{fact}``
            placeholder of the template.
        prompt_format: indexable collection of entries that each expose a
            ``"prompt"`` template string.
        prompt_idx: position of the entry inside ``prompt_format`` to use.

    Returns:
        A list with one formatted prompt per input fact, in input order.
    """
    rendered = []
    for fact in batch:
        # Looked up per item (not hoisted) so that an empty batch never
        # touches ``prompt_format`` at all.
        template = prompt_format[prompt_idx]["prompt"]
        rendered.append(template.format(fact=fact))
    return rendered

def postprocess_answers(prompts, answers, prompt_format, prompt_idx):
    """Convert decoder-only output (prompt+answer) to binary truth labels.

    Args:
        prompts: prompts that were fed to the model, aligned by position
            with ``answers``.
        answers: raw model outputs; each may contain its prompt verbatim.
        prompt_format: indexable collection of entries that each expose a
            ``"label"`` string (the positive answer word).
        prompt_idx: position of the entry inside ``prompt_format`` to use.

    Returns:
        A tuple ``(completions, labels)``: ``completions`` are the answers
        with every occurrence of the matching prompt removed, and ``labels``
        holds 1 where the positive label occurs as a whole word
        (case-insensitive) in the completion, else 0.

    Raises:
        IndexError: if ``answers`` has more items than ``prompts``.
    """
    import re  # stdlib; imported locally so the block is self-contained

    # ``str.replace`` drops the prompt wherever it appears, so the echo is
    # removed even when the output carries extra text before it.
    completions = [s.replace(prompts[idx], "") for idx, s in enumerate(answers)]
    if not completions:
        # Nothing to label; do not touch ``prompt_format`` for an empty batch.
        return completions, []

    positive = prompt_format[prompt_idx]["label"].lower()
    # Whole-word match: a bare substring test would label "untrue" as true
    # and "yesterday" as yes. Lookarounds (rather than ``\b``) stay correct
    # even if the label itself starts or ends with a non-word character.
    whole_word = re.compile(rf"(?<!\w){re.escape(positive)}(?!\w)")
    labels = [int(whole_word.search(text.lower()) is not None) for text in completions]
    return completions, labels

def prompt_answer(type:str, text:str) -> str:
    """Return ``text`` in the answer form matching the given model ``type``.

    Args:
        type: model kind; only the exact value ``"decoder"`` is special-cased.
        text: the answer text.

    Returns:
        ``text`` unchanged when ``type`` is ``"decoder"``; for any other
        value, ``text`` wrapped as ``"$answer$ = <text> ;"``.
    """
    # Guard clause: everything that is not a plain decoder gets the wrapper.
    if type != "decoder":
        return f"$answer$ = {text} ;"
    return text