# Functions copied from XXXX


def create_prompt_with_tulu_chat_format(messages, tokenizer, bos="<s>", eos="</s>", add_bos=True):
    """
    Render a list of chat messages as a Tulu-style prompt string.

    Every message is a mapping with "role" and "content" keys. System and user
    turns become a ``<|role|>`` header line followed by the content and a
    newline. Assistant turns have their content stripped of surrounding
    whitespace and are terminated with ``eos`` before the newline. The prompt
    always ends with an open ``<|assistant|>`` header line.

    Args:
        messages: Iterable of message mappings, in conversation order.
        tokenizer: Not used by this function.
        bos: String prepended to the prompt when ``add_bos`` is true.
        eos: String appended after each assistant turn's content.
        add_bos: Whether to prepend ``bos`` to the prompt.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If a message role is not 'system', 'user' or 'assistant'.
    """
    chunks = []
    for turn in messages:
        role = turn["role"]
        if role == "assistant":
            # Only assistant turns are stripped and closed with eos.
            chunks.append("<|assistant|>\n" + turn["content"].strip() + eos + "\n")
        elif role in ("system", "user"):
            chunks.append("<|" + role + "|>\n" + turn["content"] + "\n")
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    role
                )
            )
    # Leave the prompt open for the model to produce the next assistant turn.
    chunks.append("<|assistant|>\n")
    body = "".join(chunks)
    if add_bos:
        return bos + body
    return body


# Weirdness with the OLMo tokenizer means that "|||IP_ADDRESS|||" is the eos token.
def create_prompt_with_olmo_chat_format(
    messages, tokenizer, bos="|||IP_ADDRESS|||", eos="|||IP_ADDRESS|||", add_bos=True
):
    """
    Render a list of chat messages as an OLMo-style prompt string.

    Every message is a mapping with "role" and "content" keys. Each turn is a
    ``<|role|>`` header line followed by the content and a newline; assistant
    content is stripped of surrounding whitespace and followed by ``eos``.
    The prompt ends with an open ``<|assistant|>`` header line.

    Args:
        messages: Iterable of message mappings, in conversation order.
        tokenizer: Not used by this function.
        bos: String that is always prepended to the prompt.
        eos: String appended after each assistant turn's content.
        add_bos: Ignored; ``bos`` is prepended regardless of this flag.

    Returns:
        The formatted prompt string, starting with ``bos``.

    Raises:
        ValueError: If a message role is not 'system', 'user' or 'assistant'.
    """
    rendered_turns = []
    for entry in messages:
        speaker = entry["role"]
        if speaker not in ("system", "user", "assistant"):
            raise ValueError(
                "Olmo chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    speaker
                )
            )
        text = entry["content"]
        if speaker == "assistant":
            # Assistant turns are trimmed and explicitly terminated.
            text = text.strip() + eos
        rendered_turns.append("<|" + speaker + "|>\n" + text + "\n")
    # bos is forcibly added, whatever add_bos says.
    return bos + "".join(rendered_turns) + "<|assistant|>\n"


def create_prompt_with_llama2_chat_format(messages, tokenizer, bos="<s>", eos="</s>", add_bos=True):
    """
    Render a list of chat messages as a Llama 2 chat prompt string.

    This function is adapted from the official llama2 chat completion script:
    XXXX

    Every message is a mapping with "role" and "content" keys. A leading
    system message is folded into the first user message, wrapped in
    ``<<SYS>>`` markers. Each user turn is emitted as
    ``bos + "[INST] <content> [/INST]"`` with the content stripped; each
    assistant turn is emitted as ``" <content> " + eos``.

    Args:
        messages: List of message mappings, in conversation order. May be empty.
        tokenizer: Not used by this function.
        bos: String placed before every user turn.
        eos: String placed after every assistant turn.
        add_bos: When false, the single leading ``bos`` (if the prompt starts
            with one) is removed; ``bos`` before later user turns is kept.

    Returns:
        The formatted prompt string (empty for an empty conversation).

    Raises:
        AssertionError: If the conversation starts with a system message that
            is not immediately followed by a user message.
        ValueError: If any remaining message has a role other than 'user' or
            'assistant' (including a system message that is not first).
    """
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    B_INST, E_INST = "[INST]", "[/INST]"
    # If you want to include system prompt, see this discussion for the template: XXXX
    # However, see here that removing the system prompt actually reduce the false refusal rates: XXXX
    # Guard on `messages` so an empty conversation renders as "" instead of raising IndexError.
    if messages and messages[0]["role"] == "system":
        assert (
            len(messages) >= 2 and messages[1]["role"] == "user"
        ), "LLaMa2 chat cannot start with a single system message."
        # Merge the system prompt into the first user turn; the caller's list is not mutated.
        messages = [
            {
                "role": "user",
                "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
            }
        ] + messages[2:]
    pieces = []
    for message in messages:
        if message["role"] == "user":
            pieces.append(bos + f"{B_INST} {(message['content']).strip()} {E_INST}")
        elif message["role"] == "assistant":
            pieces.append(f" {(message['content'])} " + eos)
        else:
            raise ValueError(
                "Llama2 chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    formatted_text = "".join(pieces)
    # The llama2 chat template by default has a bos token at the start of each user message.
    # Remove only a genuine leading bos when add_bos is False; slicing blindly would eat
    # real content whenever the prompt does not begin with bos (e.g. assistant-first input).
    if not add_bos and formatted_text.startswith(bos):
        formatted_text = formatted_text[len(bos) :]
    return formatted_text


def create_prompt_with_xwin_chat_format(messages, tokenizer, bos="<s>", eos="</s>", add_bos=True):
    """
    Render a list of chat messages as an Xwin-style prompt string.

    This function is adapted from the official xwin chat completion script:
    XXXX

    The prompt opens with a fixed preamble, then lists each user turn as
    ``USER: <content> `` and each assistant turn as ``ASSISTANT: <content>``
    followed by ``eos``. Messages with any other role are skipped without
    error. The prompt ends with an open ``ASSISTANT:`` tag.

    Args:
        messages: Iterable of mappings with "role" and "content" keys.
        tokenizer: Not used by this function.
        bos: Not used by this function.
        eos: String appended after each assistant turn's content.
        add_bos: Not used by this function.

    Returns:
        The formatted prompt string.
    """
    preamble = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    )
    pieces = [preamble]
    for turn in messages:
        role = turn["role"]
        if role == "user":
            pieces.append("USER: " + turn["content"] + " ")
        elif role == "assistant":
            pieces.append("ASSISTANT: " + turn["content"] + eos)
        # Any other role (e.g. system) contributes nothing to the prompt.
    pieces.append("ASSISTANT:")
    return "".join(pieces)


def create_prompt_with_zephyr_chat_format(messages, tokenizer, bos="<s>", eos="</s>", add_bos=True):
    """
    Render a list of chat messages as a Zephyr-style prompt string.

    This function is adapted from the official zephyr chat completion script:
    XXXX

    Every turn, whatever its role, is emitted as a ``<|role|>`` header line
    followed by the unmodified content, ``eos`` and a newline. The prompt ends
    with an open ``<|assistant|>`` header line. No empty system message is
    injected when the conversation does not provide one.

    Args:
        messages: Iterable of mappings with "role" and "content" keys.
        tokenizer: Not used by this function.
        bos: Not used by this function.
        eos: String appended after every turn's content.
        add_bos: Not used by this function.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If a message role is not 'system', 'user' or 'assistant'.
    """
    known_roles = ("system", "user", "assistant")
    segments = []
    for turn in messages:
        role = turn["role"]
        if role not in known_roles:
            raise ValueError(
                "Zephyr chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    role
                )
            )
        # All three roles share the same layout; only the header tag differs.
        segments.append("<|" + role + "|>\n" + turn["content"] + eos + "\n")
    segments.append("<|assistant|>\n")
    return "".join(segments)


# Registry of the chat-template functions defined above, keyed by template name.
# Every function takes (messages, tokenizer, bos, eos, add_bos) and returns the
# prompt string.
CHAT_TEMPLATES = {
    "tulu": create_prompt_with_tulu_chat_format,
    "olmo": create_prompt_with_olmo_chat_format,
    "llama2": create_prompt_with_llama2_chat_format,
    "xwin": create_prompt_with_xwin_chat_format,
    "zephyr": create_prompt_with_zephyr_chat_format,
}
