# Instruction / system-prompt delimiter strings.
# B_INST and E_INST wrap the whole gsm8k in-context prompt built by
# apply_icl_prompt. B_SYS and E_SYS are not referenced by the functions below.
# NOTE(review): these look like the Llama-2 chat-format markers — confirm
# against the target model's chat template before reusing them elsewhere.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def get_prompt_template(prompt_template_style="base", prompt_template_pth=None):
    r"""Return the prompt template string for a named prompt style.

    Styles are matched in this order:

    * ``"gsm8k"`` (exact match) or any name starting with ``"geo880"`` ->
      ``"Question: %s\nAnswer:"``, to be filled with the ``%`` operator.
    * any name containing ``"sciq"`` or ``"squad"`` ->
      ``"Support: {support}\nQuestion: {question}\nAnswer:"``, to be filled
      with ``str.format`` using ``support`` and ``question`` keys.

    Args:
        prompt_template_style: Name of the prompt style.
            NOTE(review): the default ``"base"`` matches none of the styles
            above, so calling with the default raises ``ValueError``.
        prompt_template_pth: Unused; kept so existing callers keep working.

    Returns:
        The template string.

    Raises:
        ValueError: if *prompt_template_style* is not a string or is not a
            recognized style.
    """
    # Guard first: a non-string (e.g. None) would otherwise fail with an
    # AttributeError on ``.startswith`` instead of the documented ValueError.
    if not isinstance(prompt_template_style, str):
        raise ValueError(
            f"Invalid prompt template style: {prompt_template_style!r}"
        )

    # %-style template: one positional slot for the question text.
    if prompt_template_style == "gsm8k" or prompt_template_style.startswith("geo880"):
        return "Question: %s\nAnswer:"
    # str.format-style template: named slots filled from a dict prompt.
    if "sciq" in prompt_template_style or "squad" in prompt_template_style:
        return "Support: {support}\nQuestion: {question}\nAnswer:"
    raise ValueError(
        f"Invalid prompt template style: {prompt_template_style!r}"
    )

def apply_prompt_template(prompt_template_style='base', dataset=None, tokenizer=None, prefix="", return_dialogs=False):
    """Render every prompt in *dataset* with a style template, then tokenize it.

    The template comes from ``get_prompt_template(prompt_template_style)``.
    ``dict`` prompts are rendered with ``template.format(**prompt)`` (for the
    ``{support}``/``{question}`` style templates); every other prompt is
    rendered as ``template % (prefix + prompt)``.

    Args:
        prompt_template_style: Style name forwarded to ``get_prompt_template``.
            NOTE(review): the default ``'base'`` is not a style that function
            recognizes, so calling with the default raises ``ValueError``.
        dataset: Iterable of prompts (``dict`` or string-like). Required.
        tokenizer: Object whose ``encode`` method is called once per rendered
            prompt; only touched when *dataset* is non-empty.
        prefix: Text prepended to non-``dict`` prompts; ignored for ``dict``
            prompts.
        return_dialogs: When true, also return the rendered prompt strings.

    Returns:
        ``chats`` (the list of ``tokenizer.encode`` results, in dataset order),
        or ``(chats, dialogs)`` when *return_dialogs* is true.

    Raises:
        ValueError: for an unknown style (from ``get_prompt_template``), or
            when a ``dict`` prompt is paired with a ``%s``-style template.
        TypeError: when *dataset* is None.
    """
    template = get_prompt_template(prompt_template_style)
    if dataset is None:
        raise TypeError("dataset must be an iterable of prompts, got None")

    # str.format() leaves "%s" untouched, so a dict prompt rendered with a
    # %-style template would silently become the bare template with the
    # question text dropped. Decide once whether that mismatch is possible.
    percent_style = "%s" in template

    dialogs = []
    chats = []
    for prompt in dataset:
        if isinstance(prompt, dict):
            if percent_style:
                raise ValueError(
                    f"Prompt style {prompt_template_style!r} uses a '%s' "
                    "template and expects non-dict prompts; got a dict."
                )
            text = template.format(**prompt)
        else:
            text = template % (prefix + prompt)
        dialogs.append(text)
        chats.append(tokenizer.encode(text))

    if return_dialogs:
        return chats, dialogs
    return chats

def apply_icl_prompt(test_dataset_inputs, train_datsaet_inputs, train_dataset_outputs, idx_mat, prompt_template_style='gsm8k'):
    """Assemble one in-context-learning (ICL) prompt per test input.

    For the test input at position ``k``, ``idx_mat[k]`` gives the indices of
    the training examples to show. Each example is rendered as
    ``train_input + " " + train_output``, and the examples are separated by
    blank lines. The test input is then attached according to the style:

    * ``'gsm8k'``: the demonstrations, then a step-by-step instruction asking
      for a ``#### {result}`` answer line, then the question. The whole
      prompt is wrapped between ``B_INST`` and ``E_INST``.
    * names starting with ``'geo880'`` or containing ``'sciq'``/``'squad'``:
      the demonstrations, a blank line, then the test input.

    The misspelled parameter ``train_datsaet_inputs`` is kept as-is because
    callers may pass it by keyword.

    Returns:
        A list of prompt strings, one per test input, in input order.

    Raises:
        ValueError: for any other style. The style is checked per test input,
            so an empty *test_dataset_inputs* returns ``[]`` without raising.
    """
    prompts = []
    for row, question in enumerate(test_dataset_inputs):
        demos = "\n\n".join(
            train_datsaet_inputs[j] + " " + train_dataset_outputs[j]
            for j in idx_mat[row]
        )
        if prompt_template_style == 'gsm8k':
            whole = (
                B_INST + " " + demos
                + "\n\nLet's think step by step. You need to solve the final question and answer in the format: \n#### {result}\n"
                + question + " " + E_INST
            )
        elif (prompt_template_style.startswith("geo880")
              or 'sciq' in prompt_template_style
              or 'squad' in prompt_template_style):
            whole = demos + '\n\n' + question
        else:
            raise ValueError("Invalid prompt template style.")
        prompts.append(whole)
    return prompts