# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
from typing import List, Literal, TypedDict


# Roles a chat message may carry. "system" is only accepted as the first
# message of a dialog; the formatting helpers in this module fold it into the
# first user turn before checking that roles alternate user/assistant.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """A single chat message: who sent it and its text."""

    role: Role
    content: str


# A conversation: an ordered list of messages.
Dialog = List[Message]

# Llama-2 chat prompt delimiters: each user instruction is wrapped in
# "[INST] ... [/INST]"; a system prompt is wrapped in "<<SYS>> ... <</SYS>>".
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def format_tokens(dialogs, tokenizer):
    """Tokenize chat dialogs into Llama-2 chat prompt token ids.

    Each dialog is a list of ``{"role": ..., "content": ...}`` messages. An
    optional leading ``"system"`` message is folded into the first user turn,
    wrapped in ``B_SYS``/``E_SYS``. The remaining messages must alternate
    user/assistant, starting and ending with a user message.

    Every completed (user, assistant) turn is encoded as
    ``"[INST] {user} [/INST] {answer} "`` followed by ``tokenizer.eos_token_id``.
    The trailing user message is encoded as ``"[INST] {user} [/INST]"``. Each
    turn is a separate ``tokenizer.encode`` call. Any special tokens the
    tokenizer adds by default (presumably BOS; depends on the tokenizer) are
    therefore added once per turn.

    Args:
        dialogs: iterable of dialogs (lists of message dicts).
        tokenizer: object exposing ``encode(str)`` returning a list of token
            ids, and an ``eos_token_id`` attribute.

    Returns:
        A list with one list of token ids per dialog.

    Raises:
        AssertionError: if a dialog is empty, a system message is not followed
            by another message, the roles do not alternate user/assistant, or
            the last message is not from the user.
    """
    prompt_tokens = []
    for dialog in dialogs:
        assert dialog, "dialog must contain at least one message"
        if dialog[0]["role"] == "system":
            # Without this check a system-only dialog fails with an opaque
            # IndexError on dialog[1].
            assert len(dialog) > 1, "a system message must be followed by a user message"
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS
                    + dialog[0]["content"]
                    + E_SYS
                    + dialog[1]["content"],
                }
            ] + dialog[2:]
        assert all(msg["role"] == "user" for msg in dialog[::2]) and all(
            msg["role"] == "assistant" for msg in dialog[1::2]
        ), (
            "model only supports 'system','user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
        # Fail fast, before spending time encoding the earlier turns.
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"

        # NOTE: "[INST]" / "[/INST]" are inserted here as plain text. Verify
        # that your tokenizer handles them as intended.
        dialog_tokens: List[int] = []
        for prompt, answer in zip(dialog[::2], dialog[1::2]):
            # Extend in place; sum(lists, []) would be quadratic in the number
            # of turns.
            dialog_tokens += tokenizer.encode(
                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
            )
            dialog_tokens.append(tokenizer.eos_token_id)
        dialog_tokens += tokenizer.encode(
            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        )
        prompt_tokens.append(dialog_tokens)
    return prompt_tokens

def extract_dialogs(dialogs, tokenizer=None):
    """Render chat dialogs as Llama-2 chat prompt strings, without tokenizing.

    Applies the same formatting rules as ``format_tokens``, but returns text.
    An optional leading system message is merged into the first user turn.
    Each completed (user, assistant) turn becomes
    ``"[INST] {user} [/INST] {answer} " + eos_token``. The trailing user
    message becomes ``"[INST] {user} [/INST]"``.

    Args:
        dialogs: iterable of dialogs (lists of ``{"role", "content"}`` dicts).
        tokenizer: object exposing an ``eos_token`` string. It is only needed
            when a dialog contains at least one completed user/assistant turn.
            If omitted, a module-level ``tokenizer`` global is used when a
            caller has set one. The previous version of this function read
            that global implicitly, because it had no ``tokenizer`` parameter.

    Returns:
        A list with one prompt string per dialog.

    Raises:
        AssertionError: if a dialog is empty, a system message is not followed
            by another message, the roles do not alternate user/assistant, or
            the last message is not from the user.
        ValueError: if an EOS token is required but no tokenizer is available.
    """
    if tokenizer is None:
        # Backward compatibility with the old implicit global lookup.
        tokenizer = globals().get("tokenizer")

    extracted_dialogs = []

    for dialog in dialogs:
        assert dialog, "dialog must contain at least one message"
        if dialog[0]["role"] == "system":
            assert len(dialog) > 1, "a system message must be followed by a user message"
            dialog = [{
                "role"   : dialog[1]["role"],
                "content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
            }] + dialog[2:]

        assert all(msg["role"] == "user" for msg in dialog[::2]) and all(
            msg["role"] == "assistant" for msg in dialog[1::2]
        ), (
            "model only supports 'system','user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
        assert dialog[-1]["role"] == "user", f"Last message must be from user, got {dialog[-1]['role']}"

        # Collect the pieces of the whole dialogue, then join them once.
        parts = []
        for prompt, answer in zip(dialog[::2], dialog[1::2]):
            if tokenizer is None:
                raise ValueError(
                    "extract_dialogs needs a tokenizer (for its eos_token) when a "
                    "dialog contains completed user/assistant turns"
                )
            parts.append(
                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
                + tokenizer.eos_token
            )
        parts.append(f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}")

        extracted_dialogs.append("".join(parts))

    return extracted_dialogs

def batch_format_tokens(dialogs, tokenizer, batch_size, device="cuda"):
    """Format chat dialogs as Llama-2 prompts and tokenize them in batches.

    Each dialog is rendered to text using the same rules as
    ``format_tokens``:

    - an optional leading system message is merged into the first user turn;
    - each completed (user, assistant) turn becomes
      ``"[INST] {user} [/INST] {answer} " + tokenizer.eos_token``;
    - the trailing user message becomes ``"[INST] {user} [/INST]"``.

    Consecutive groups of ``batch_size`` prompt strings are then tokenized in a
    single ``tokenizer(...)`` call with ``add_special_tokens=True``,
    ``return_tensors="pt"`` and ``padding=True``. The final group may be
    smaller than ``batch_size``.

    Args:
        dialogs: iterable of dialogs (lists of ``{"role", "content"}`` dicts).
        tokenizer: callable tokenizer that also exposes ``eos_token``. The
            keyword arguments used match the Hugging Face tokenizer
            ``__call__`` API.
        batch_size: number of dialogs per encoded batch; must be >= 1.
        device: target passed to ``.to(...)`` on each encoded batch. Defaults
            to ``"cuda"``, which was previously hard-coded. Pass ``None`` to
            leave the batch where the tokenizer created it.

    Returns:
        A list of encoded batches, one per group of up to ``batch_size``
        dialogs.

    Raises:
        ValueError: if ``batch_size`` is less than 1. Previously this silently
            produced a single batch holding every dialog.
        AssertionError: if a dialog is empty, a system message is not followed
            by another message, the roles do not alternate user/assistant, or
            the last message is not from the user.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    def _encode(texts):
        # Tokenize one batch, padded to its longest prompt, and optionally
        # move it to the target device.
        # NOTE(review): decoder-only generation usually wants left padding;
        # that depends on tokenizer.padding_side, which is configured elsewhere.
        encoded = tokenizer(texts, add_special_tokens=True, return_tensors="pt", padding=True)
        return encoded if device is None else encoded.to(device)

    batched_prompt_tokens = []
    # Prompt strings waiting to be encoded as the next batch.
    batch_dialogs = []

    for dialog in dialogs:
        assert dialog, "dialog must contain at least one message"
        if dialog[0]["role"] == "system":
            assert len(dialog) > 1, "a system message must be followed by a user message"
            dialog = [{
                "role"   : dialog[1]["role"],
                "content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
            }] + dialog[2:]

        assert all(msg["role"] == "user" for msg in dialog[::2]) and all(
            msg["role"] == "assistant" for msg in dialog[1::2]
        ), (
            "model only supports 'system','user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
        assert dialog[-1]["role"] == "user", f"Last message must be from user, got {dialog[-1]['role']}"

        # NOTE(review): format_tokens encodes each turn separately, so any
        # default special tokens (presumably BOS) are added per turn. Here the
        # whole dialog is one string, so they are added only once. Confirm this
        # difference is intended. This also assumes the tokenizer maps the
        # literal eos_token text back to the EOS id.
        parts = [
            f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} " + tokenizer.eos_token
            for prompt, answer in zip(dialog[::2], dialog[1::2])
        ]
        parts.append(f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}")
        batch_dialogs.append("".join(parts))

        # Once the batch is full, encode it and start a new one.
        if len(batch_dialogs) == batch_size:
            batched_prompt_tokens.append(_encode(batch_dialogs))
            batch_dialogs = []

    # Encode any remaining dialogs that didn't make a full batch.
    if batch_dialogs:
        batched_prompt_tokens.append(_encode(batch_dialogs))

    return batched_prompt_tokens

def read_dialogs_from_file(file_path):
    """Load dialogs from a JSON file.

    The file is decoded as UTF-8 explicitly. Relying on the platform's locale
    default would corrupt or reject non-ASCII chat content on systems where
    that default is not UTF-8.

    Args:
        file_path: path to the JSON file (a ``str`` or path-like object).

    Returns:
        The parsed JSON value. Presumably a list of dialogs, each a list of
        ``{"role", "content"}`` dicts; the content is not validated here.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
