import torch
from functools import partial
from src.utils.logging_utils import get_logger

# Module-level logger obtained from the project's logging helper.
logger = get_logger(name=__name__)


# Reference: Open Instruct
# Fallback Jinja chat templates, keyed by the name accepted as ``chat_template`` by
# ``create_chat_tokenizer`` (used only when the loaded tokenizer has no template).
# Every template renders the messages in order and appends ``eos_token`` after the
# last message unless ``add_generation_prompt`` is true; none of them emits any
# assistant prefix when ``add_generation_prompt`` is true.
# NOTE(review): because ``eos_token`` follows whichever message is last in the list
# being rendered, rendering a *prefix* of a conversation (as ``encode_sft_example``
# does to locate message boundaries) yields an ``eos_token`` that the full
# conversation does not contain at that position — confirm the resulting label-mask
# offsets are acceptable when these templates are in use.
CHAT_TEMPLATES = {
    # Message contents joined with a single space; no role markers.
    "simple_concat_with_space": (
        "{% for message in messages %}"
        "{{ ' ' if not loop.first else '' }}"
        "{{ message['content'] }}"
        "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
        "{% endfor %}"
    ),
    # Message contents joined with a newline; no role markers. The '\n' is a Python
    # escape, so the Jinja string literal holds an actual newline character.
    "simple_concat_with_new_line": (
        "{% for message in messages %}"
        "{{ '\n' if not loop.first else '' }}"
        "{{ message['content'] }}"
        "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
        "{% endfor %}"
    ),
    # Each message rendered as "<Role>: <content>" (role capitalized), with messages
    # separated by a blank line.
    "simple_chat": (
        "{% for message in messages %}"
        "{{ '\n\n' if not loop.first else '' }}"
        "{{ message['role'].capitalize() + ': ' + message['content'] }}"
        "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
        "{% endfor %}"
    ),
}

def _tokenize_conversation(tokenizer, messages, max_seq_length, add_generation_prompt):
    """Render ``messages`` with the tokenizer's chat template and tokenize them.

    Every call in ``encode_sft_example`` uses these exact settings, so the length
    of any prefix rendered here can be used directly as an index into the
    full-conversation ``input_ids``.

    Args:
        tokenizer: object exposing ``apply_chat_template``.
        messages: list of ``{"role": ..., "content": ...}`` dicts to render.
        max_seq_length: truncation length; a falsy value disables truncation.
        add_generation_prompt: forwarded to the chat template.

    Returns:
        The tensor returned by ``apply_chat_template`` with ``return_tensors="pt"``.
        It is indexed as ``(1, seq_len)`` by the caller.
    """
    return tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=True,
        return_tensors="pt",
        padding=False,
        truncation=bool(max_seq_length),
        max_length=max_seq_length,
        add_generation_prompt=add_generation_prompt,
    )


def encode_sft_example(example, tokenizer, max_seq_length):
    """Encode one chat example into input ids, loss labels and attention mask for SFT.

    The example must have a ``messages`` field holding a list of dicts with
    ``role`` and ``content`` keys. The whole conversation is tokenized with the
    tokenizer's chat template. Afterwards, every token belonging to a
    non-assistant message is set to ``-100`` in ``labels`` so it is excluded from
    the loss. Message boundaries are found by tokenizing conversation prefixes.

    Args:
        example: mapping with a ``messages`` list.
        tokenizer: tokenizer exposing ``apply_chat_template``.
        max_seq_length: maximum token length; a falsy value disables truncation.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``, each flattened
        to a 1-D tensor. ``attention_mask`` is all ones because no padding is applied.

    Raises:
        ValueError: if ``messages`` is empty.
    """
    messages = example["messages"]
    if not messages:
        raise ValueError("messages field is empty.")

    input_ids = _tokenize_conversation(tokenizer, messages, max_seq_length, add_generation_prompt=False)
    labels = input_ids.clone()

    # Mask every non-assistant message so only assistant tokens contribute to the loss.
    for message_idx, message in enumerate(messages):
        if message["role"] == "assistant":
            continue

        # Start of this message = length of everything rendered before it.
        if message_idx == 0:
            message_start_idx = 0
        else:
            message_start_idx = _tokenize_conversation(
                tokenizer, messages[:message_idx], max_seq_length, add_generation_prompt=False
            ).shape[1]

        # When an assistant turn follows, render with the generation prompt so the
        # assistant prefix (e.g. `<|assistant|>`) is masked as well, not trained on.
        next_is_assistant = (
            message_idx + 1 < len(messages) and messages[message_idx + 1]["role"] == "assistant"
        )
        message_end_idx = _tokenize_conversation(
            tokenizer, messages[: message_idx + 1], max_seq_length, add_generation_prompt=next_is_assistant
        ).shape[1]

        labels[:, message_start_idx:message_end_idx] = -100

        # Once a masked span reaches the truncation limit, later messages are cut off.
        if max_seq_length and message_end_idx >= max_seq_length:
            break

    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids.flatten(),
        "labels": labels.flatten(),
        "attention_mask": attention_mask.flatten(),
    }

def create_chat_tokenizer(*, tokenizer_name: str, chat_template: str | None = None):
    """Load a tokenizer and make sure it carries a chat template.

    Args:
        tokenizer_name: name or path passed to ``AutoTokenizer.from_pretrained``.
        chat_template: key into ``CHAT_TEMPLATES``. It is only a fallback and is
            applied only when the loaded tokenizer has no ``chat_template`` of its
            own; otherwise it is ignored (and a log line says so).

    Returns:
        The loaded tokenizer, with ``chat_template`` set.

    Raises:
        ValueError: if the tokenizer has no chat template and ``chat_template`` is None.
        KeyError: if ``chat_template`` is needed but is not a key of ``CHAT_TEMPLATES``.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    logger.info(f"Loaded tokenizer: {tokenizer_name}")

    if tokenizer.chat_template is None:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if chat_template is None:
            raise ValueError(
                f"chat_template must be provided: tokenizer {tokenizer_name!r} does not define one. "
                f"Available templates: {sorted(CHAT_TEMPLATES)}"
            )
        if chat_template not in CHAT_TEMPLATES:
            raise KeyError(f"Unknown chat_template {chat_template!r}; expected one of {sorted(CHAT_TEMPLATES)}.")
        tokenizer.chat_template = CHAT_TEMPLATES[chat_template]
        logger.info(f"Loaded chat_template: {chat_template}")
    elif chat_template is not None:
        # Make it visible that the caller's choice had no effect.
        logger.info(
            f"Tokenizer {tokenizer_name} already defines a chat template; ignoring chat_template={chat_template!r}"
        )

    return tokenizer



def encode_sft_dataset(
    *,
    dataset,
    tokenizer,
    max_len: int | None,
    num_proc: int = 1,
    overwrite_cache: bool = False,
    cols_to_keep: list[str] = None,
    keep_all_columns: bool = False,
    tokenizer_name: str | None = None, 
    chat_template: str | None = None, 
    **kwargs,
):
    
    if tokenizer is None:
        assert tokenizer_name is not None, "tokenizer_name must be provided if tokenizer is not provided."
        tokenizer = create_chat_tokenizer(tokenizer_name=tokenizer_name, chat_template=chat_template)

    if keep_all_columns:
        assert not cols_to_keep, "cols_to_keep must be empty if keep_all_columns is True."
        cols_to_remove = []
    else:
        cols_to_keep = cols_to_keep or []
        cols_to_remove = [
            name for name in dataset.column_names if not name.startswith("_") and name not in (["input_ids", "labels", "attention_mask"] + cols_to_keep)
        ]

    dataset = dataset.map(
        partial(
            encode_sft_example,
            tokenizer=tokenizer,
            max_seq_length=max_len,
        ),
        batched=False,
        num_proc=num_proc,
        load_from_cache_file=not overwrite_cache,
        remove_columns=cols_to_remove,
        desc="Tokenizing and reformatting instruction data",
    )
    dataset.set_format(type="pt")
    dataset = dataset.filter(
        lambda example: (example["labels"] != -100).any(),
        num_proc=num_proc,
        desc="Filtering out examples with no assistant message",
    )
    return dataset
