#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.

import logging
from logging import Logger
from typing import Dict, Union, Optional

import transformers
from transformers.testing_utils import CaptureLogger
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from tokenize_data.conversation_template import ConversationTemplate
# Canonical role identifiers used in conversation messages; every role maps to
# its own name.
CONVERSATION_ROLE_NAMES = {
    role: role
    for role in ("system", "user", "assistant", "function", "observation")
}


# Module-level logger for messages emitted by this file.
logger = logging.getLogger(__name__)
# Logger of `transformers.tokenization_utils_base`. The tokenize functions in
# this file wrap it in `CaptureLogger` so they can detect the
# "Token indices sequence length is longer than the ..." warning.
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")


def blocking(
    token_dict: Dict,
    block_size: int,
    model_max_length: int,
    pad_token_id: int,
    padding_side: str,
    truncation_side: str='right',
) -> Dict:
    """Pad or truncate every example in ``token_dict`` to one fixed length.

    The target length is ``min(block_size, model_max_length)``. Examples whose
    ``input_ids`` are longer are truncated; all others are padded up to it.

    Args:
        token_dict: Mapping with the keys ``"input_ids"``, ``"attention_mask"``
            and ``"labels"``, each holding one list per example. The number of
            examples is taken from the first column of the mapping.
        block_size: Requested sequence length.
        model_max_length: Upper bound applied on top of ``block_size``.
        pad_token_id: Id appended/prepended to ``input_ids`` when padding.
            ``attention_mask`` is padded with ``0`` and ``labels`` with ``-100``.
        padding_side: ``'right'`` or ``'left'``.
        truncation_side: ``'right'`` (keep the head) or ``'left'`` (keep the
            tail).

    Returns:
        The same ``token_dict`` object, modified in place.

    Raises:
        ValueError: If an example has to be truncated and ``truncation_side``
            is invalid, if an example reaches the padding branch and
            ``padding_side`` is invalid, or if padding is required but
            ``pad_token_id`` is ``None``.
    """
    first_column = next(iter(token_dict.values()), None)
    if first_column is None:
        # Nothing to pad or truncate in an empty mapping.
        return token_dict

    num_example = len(first_column)
    # Loop-invariant: every example is brought to this length.
    max_length = min(block_size, model_max_length)

    for i in range(num_example):
        pad_length = max_length - len(token_dict["input_ids"][i])
        if pad_length < 0:
            # Truncates too long samples
            for key in ["input_ids", "attention_mask", "labels"]:
                sequence = token_dict[key][i]
                if truncation_side == 'right':
                    token_dict[key][i] = sequence[:max_length]
                elif truncation_side == 'left':
                    # `sequence[-0:]` is the whole sequence, so a zero target
                    # length needs an explicit empty slice.
                    token_dict[key][i] = (
                        sequence[-max_length:] if max_length != 0 else sequence[:0]
                    )
                else:
                    raise ValueError(
                        f"truncation_side should be either 'right' or 'left', got {truncation_side}"
                    )
        else:
            if padding_side not in ('right', 'left'):
                raise ValueError(
                    f"padding_side should be either 'right' or 'left', got {padding_side}"
                )
            if pad_length > 0 and pad_token_id is None:
                # Padding with None would silently corrupt `input_ids`.
                raise ValueError(
                    "pad_token_id is None but padding is required; "
                    "set a pad token on the tokenizer first."
                )

            # Pads too short samples
            pad_ids = [pad_token_id] * pad_length
            pad_mask = [0] * pad_length
            pad_labels = [-100] * pad_length
            if padding_side == 'right':
                token_dict["input_ids"][i].extend(pad_ids)
                token_dict["attention_mask"][i].extend(pad_mask)
                token_dict["labels"][i].extend(pad_labels)
            else:
                token_dict["input_ids"][i] = pad_ids + token_dict["input_ids"][i]
                token_dict["attention_mask"][i] = pad_mask + token_dict["attention_mask"][i]
                token_dict["labels"][i] = pad_labels + token_dict["labels"][i]

    return token_dict


def tokenize_function(
    examples,
    disable_group_texts: bool,
    block_size: int,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    column_names,
    label_columns,
    tokenized_column_order,
    add_special_tokens,
    use_truncation,
) -> Dict:
    """Tokenize a batch of `text_only` / `text2text` examples.

    Every column named in ``tokenized_column_order`` is tokenized on its own
    and its tokens are appended, in that order, to the per-example sequences.
    Columns listed in ``label_columns`` contribute their token ids as labels;
    all other columns contribute ``-100`` for each of their tokens.

    When ``disable_group_texts`` is true the result is passed through
    ``blocking`` using ``block_size`` and the tokenizer's ``model_max_length``,
    ``pad_token_id``, ``padding_side`` and ``truncation_side``.

    Returns:
        A dict with the keys ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``, each holding one list per example.
    """
    batch_size = len(examples[column_names[0]])
    all_input_ids = [[] for _ in range(batch_size)]
    all_attention_mask = [[] for _ in range(batch_size)]
    all_labels = [[] for _ in range(batch_size)]
    token_dict = {
        "input_ids": all_input_ids,
        "attention_mask": all_attention_mask,
        "labels": all_labels,
    }

    with CaptureLogger(tok_logger) as cl:
        for column_name in tokenized_column_order:
            encoding = tokenizer(
                examples[column_name],
                add_special_tokens=add_special_tokens,
                truncation=use_truncation,
            )

            if column_name not in label_columns:
                # Tokens of non-label columns are masked out of the labels.
                column_labels = [
                    [-100] * len(encoding["input_ids"][idx])
                    for idx in range(batch_size)
                ]
            else:
                column_labels = encoding["input_ids"].copy()

            accumulators = zip(all_input_ids, all_attention_mask, all_labels)
            for idx, (ids_acc, mask_acc, labels_acc) in enumerate(accumulators):
                ids_acc.extend(encoding["input_ids"][idx])
                mask_acc.extend(encoding["attention_mask"][idx])
                labels_acc.extend(column_labels[idx])

    if disable_group_texts:
        token_dict = blocking(
            token_dict=token_dict,
            block_size=block_size,
            model_max_length=tokenizer.model_max_length,
            pad_token_id=tokenizer.pad_token_id,
            padding_side=tokenizer.padding_side,
            truncation_side=tokenizer.truncation_side,
        )

    # clm input could be much much longer than block_size
    saw_length_warning = "Token indices sequence length is longer than the" in cl.out
    if saw_length_warning:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )
    return token_dict
    

def conversation_tokenize_function(
    examples,
    disable_group_texts: bool,
    block_size: int,
    train_on_prompt: bool,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    column_names,
    conversation_template: Union[ConversationTemplate, str]
) -> Dict:
    """Tokenize a batch of conversation examples.

    ``examples`` must provide a ``"messages"`` column and may provide
    ``"system"`` and ``"tools"`` columns; a missing optional column is treated
    as ``None`` for every example.

    Two template kinds are supported:

    * ``str``: treated as a jinja chat template and handed to
      ``tokenizer.apply_chat_template`` with
      ``return_assistant_tokens_mask=True``. Unless ``train_on_prompt`` is set,
      labels are ``-100`` wherever the returned ``assistant_masks`` entry is
      not ``1``.
    * otherwise: the object's ``encode_conversation`` is called and its result
      is iterated as ``(user_input, assistant_result)`` pairs of token ids.
      Unless ``train_on_prompt`` is set, the ``user_input`` part of each pair
      is labelled ``-100``. Conversations with fewer than two messages, or
      whose first message is not from the user, are skipped with a warning
      and their entries stay empty; a trailing unpaired message is dropped.

    When ``disable_group_texts`` is true the result is passed through
    ``blocking`` using ``block_size`` and the tokenizer's ``model_max_length``,
    ``pad_token_id``, ``padding_side`` and ``truncation_side``.

    Returns:
        A dict with the keys ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``, each holding one list per example.
    """
    num_example = len(examples[column_names[0]])
    token_dict = {
        "input_ids": [[] for _ in range(num_example)],
        "attention_mask": [[] for _ in range(num_example)],
        "labels": [[] for _ in range(num_example)],
    }
    with CaptureLogger(tok_logger) as cl:
        all_messages = examples["messages"]
        # Resolve the optional columns once instead of building a fresh
        # `[None] * num_example` default for every example.
        missing_column = [None] * num_example
        all_systems = examples.get("system", missing_column)
        all_tools = examples.get("tools", missing_column)

        for i, messages in enumerate(all_messages):
            system = all_systems[i]
            tools = all_tools[i]

            if isinstance(conversation_template, str): # jinja template
                if system is not None:
                    conversation = [{"role": "system", "content": system}]
                else:
                    conversation = []
                conversation.extend(messages)
                encoded_conversation = tokenizer.apply_chat_template(
                    conversation=conversation,
                    tools=tools,
                    chat_template=conversation_template,
                    return_assistant_tokens_mask=True,
                    return_dict=True,
                )

                if train_on_prompt:
                    labels = encoded_conversation["input_ids"]
                else:
                    # Only tokens flagged as assistant output are trained on.
                    labels = [
                        encoded_conversation["input_ids"][index] if mask == 1 else -100
                        for index, mask in enumerate(encoded_conversation["assistant_masks"])
                    ]

                token_dict['input_ids'][i].extend(encoded_conversation['input_ids'])
                token_dict['attention_mask'][i].extend(encoded_conversation['attention_mask'])
                token_dict['labels'][i].extend(labels)

            else: # lmflow `conversation_template`
                if len(messages) < 2 or messages[0]['role'] != CONVERSATION_ROLE_NAMES['user']:
                    tok_logger.warning(
                        "Invalid instance encountered. Either the conversation has less than "
                        "one round or the first message is not from the user."
                    )
                    # The example keeps its empty lists in `token_dict`.
                    continue

                if len(messages) % 2 != 0:
                    logger.warning(
                        "The number of messages is not even, the last message will be ignored."
                    )
                    messages = messages[:-1]

                encoded_conversation = conversation_template.encode_conversation(
                    tokenizer=tokenizer,
                    messages=messages,
                    system=system,
                    tools=tools,
                )

                input_ids, labels = [], []
                for user_input, assistant_result in encoded_conversation:
                    input_ids += user_input + assistant_result

                    if train_on_prompt:
                        labels += user_input + assistant_result
                    else:
                        # Mask the user side of the turn out of the loss.
                        labels += [-100] * len(user_input) + assistant_result

                token_dict["input_ids"][i].extend(input_ids)
                token_dict["attention_mask"][i].extend([1] * len(input_ids))
                token_dict["labels"][i].extend(labels)

    if disable_group_texts:
        token_dict = blocking(
            token_dict=token_dict,
            block_size=block_size,
            model_max_length=tokenizer.model_max_length,
            pad_token_id=tokenizer.pad_token_id,
            padding_side=tokenizer.padding_side,
            truncation_side=tokenizer.truncation_side,
        )

    # clm input could be much much longer than block_size
    if "Token indices sequence length is longer than the" in cl.out:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )
    return token_dict