import torch
from chat_template import ChatTemplate


def _remap_dialog_roles(batch, dialog_format):
    """Rewrite every dialog in ``batch`` into the ``{"messages": [...]}`` layout.

    ``dialog_format`` is a comma-separated spec of the form
    ``"<dialog_key>,<system>,<user>,<assistant>,<role_key>,<content_key>"``:
    the first field names the list of turns inside each example, the middle
    fields are the dataset's own names for the system/user/assistant roles,
    and the last two name the role and content fields of each turn.

    Returns a new list of ``{"messages": [{"role": ..., "content": ...}]}``
    dicts. A turn whose role is not listed in ``dialog_format`` raises
    ``KeyError``.
    """
    fields = dialog_format.split(",")
    dialog_key, role_key, content_key = fields[0], fields[-2], fields[-1]
    # Map the dataset's role names onto the canonical chat-template roles.
    role_mapping = dict(zip(fields[1:-2], ["system", "user", "assistant"]))
    return [
        {
            "messages": [
                {
                    "role": role_mapping[message[role_key]],
                    "content": message[content_key],
                }
                for message in dialog[dialog_key]
            ]
        }
        for dialog in batch
    ]


def _extract_texts(batch):
    """Return the ``"text"`` values of ``batch``.

    Accepts either a column-oriented mapping (``{"text": [...]}``) or a
    sequence of example dicts (``[{"text": ...}, ...]``), the latter being
    what a DataLoader ``collate_fn`` is handed.
    """
    from collections.abc import Mapping

    if isinstance(batch, Mapping):
        return batch["text"]
    return [example["text"] for example in batch]


def tokenize_collate_func(
    batch,
    tokenizer=None,
    max_length=None,
    chat_template=None,
    dialog_format="messages,system,user,assistant,role,content",
):
    """Collate a batch of examples into padded, truncated token tensors.

    Args:
        batch: The examples to collate. With ``chat_template`` set, a sequence
            of dialog dicts. Without it, either a sequence of ``{"text": ...}``
            dicts or a mapping ``{"text": [...]}``.
        tokenizer: Tokenizer object. It must be callable on a list of strings
            and provide ``apply_chat_template`` (Hugging Face-style API).
        max_length: Truncation length forwarded to the tokenizer.
        chat_template: Name of a ``ChatTemplate`` member. When given, each
            dialog is rendered to text with that template before tokenizing.
        dialog_format: Comma-separated description of the dialog schema (see
            ``_remap_dialog_roles``). Only consulted when ``chat_template`` is
            set and the format differs from the canonical one.

    Returns:
        The tokenizer output (``return_tensors="pt"``), extended with a
        ``"position_ids"`` tensor of shape ``(1, seq_len)``.

    Raises:
        ValueError: If ``tokenizer`` is None.
    """
    if tokenizer is None:
        raise ValueError("tokenize_collate_func requires a tokenizer")

    if chat_template is not None:
        # Bring non-canonical dialog schemas into the role/content layout the
        # tokenizer's chat template expects.
        if dialog_format != "messages,system,user,assistant,role,content":
            batch = _remap_dialog_roles(batch, dialog_format)

        texts = tokenizer.apply_chat_template(
            [example["messages"] for example in batch],
            tokenize=False,
            add_generation_prompt=False,
            chat_template=ChatTemplate[chat_template].value,
        )
    else:
        texts = _extract_texts(batch)

    # Special tokens are not added here; the rendered chat template is
    # expected to contain whatever markers it needs.
    inputs = tokenizer(
        texts,
        padding="longest",
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
        return_tensors="pt",
    )

    # NOTE: required for the layer input. A single (1, seq_len) row is
    # produced; presumably it broadcasts over the batch dimension downstream.
    # NOTE(review): these positions ignore padding, so they are only exact for
    # right-padded inputs — confirm the tokenizer's padding side.
    inputs["position_ids"] = torch.arange(inputs["input_ids"].shape[1]).unsqueeze(0)

    return inputs


def concat_turns_collate_func(
    batch,
    tokenizer=None,
    max_length=None,
    dialog_format="messages,system,user,assistant,role,content",
    padding="longest",
    mutual_add_eos_token=False,
    tokenize=True,
):
    """Flatten each dialog into one string by space-joining its turn contents.

    Args:
        batch: Sequence of example dicts, each holding a list of turns.
        tokenizer: Tokenizer object. It is called on the joined strings when
            ``tokenize`` is true, and supplies ``eos_token`` when
            ``mutual_add_eos_token`` is true.
        max_length: Truncation length forwarded to the tokenizer.
        dialog_format: Comma-separated schema spec. Only its first field (the
            key of the turn list) and its last field (the key of each turn's
            content) are used here; roles are ignored.
        padding: Padding strategy forwarded to the tokenizer.
        mutual_add_eos_token: Append ``tokenizer.eos_token`` to every string.
        tokenize: When false, return the list of strings instead of tokenizing.

    Returns:
        The tokenizer output (``return_tensors="pt"``), or the list of joined
        strings when ``tokenize`` is false.
    """
    fields = dialog_format.split(",")
    dialog_field = fields[0]
    content_field = fields[-1]

    joined = [
        " ".join(turn[content_field] for turn in sample[dialog_field])
        for sample in batch
    ]

    if mutual_add_eos_token:
        joined = [text + tokenizer.eos_token for text in joined]

    if not tokenize:
        return joined

    return tokenizer(
        joined,
        padding=padding,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
