from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


# Local checkpoints tried, in order, when the tokenizer cannot be loaded from
# `model_name` itself. Each entry pairs a lowercase marker that must occur in
# the model name with the candidate checkpoint paths for that model family.
# Used to support LlamaMoE models. This will be removed in the future.
_FALLBACK_TOKENIZER_PATHS = (
    ('llama-2', (
        '/workspace/models/Llama-2-7b-hf',
        '/home/yxhong4nchc/Dataspace/models/Llama-2-7b-hf',
    )),
    ('llama-3.2', (
        '/workspace/models/Llama-3.2-3B',
        '/home/yxhong4nchc/Dataspace/models/Llama-3.2-3B',
    )),
)


def _get_fallback_tokenizer_paths(model_name: str) -> tuple:
    """ Returns the fallback checkpoint paths of the first matching family.

    Args:
        model_name (str): The model name.

    Returns:
        tuple: The candidate paths, or an empty tuple if no marker matches.
    """

    lowered_name = model_name.lower()
    for marker, paths in _FALLBACK_TOKENIZER_PATHS:
        if marker in lowered_name:
            return paths

    return ()


def _load_fallback_tokenizer(paths: tuple) -> tuple:
    """ Loads the tokenizer from the first usable fallback path.

    Args:
        paths (tuple): The non-empty candidate paths, in priority order.

    Returns:
        tuple: The tokenizer and the path it was loaded from.

    Raises:
        Exception: Whatever the last candidate raises if every path fails.
    """

    *earlier_paths, last_path = paths

    for path in earlier_paths:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=path)
        except Exception:
            # This candidate is unusable; move on to the next one.
            continue

        return tokenizer, path

    # The last candidate is not guarded so that its error reaches the caller.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=last_path)

    return tokenizer, last_path


def _log_or_print(logger, source, message: str) -> None:
    """ Logs the message, or prints it if logging is not possible.

    Args:
        logger: The logger, or None if it could not be created.
        source: The source label passed to the logger.
        message (str): The message.
    """

    if logger is not None:
        try:
            logger.log(
                message=message,
                source=source,
            )

            return
        except Exception:
            # Logging is best-effort; fall through to printing.
            pass

    print(message)


def initialize_tokenizer(
    model: PreTrainedModel,
    model_name: str,
) -> PreTrainedTokenizerBase:
    """ Initializes the tokenizer.

    Loads the tokenizer for `model_name`, falling back to the local
    checkpoints in `_FALLBACK_TOKENIZER_PATHS` if that fails and the name
    matches one of their markers. Then sets the pad token to the eos token if
    it is missing, adds the sep token if the model has a `mosle_config` that
    defines one, resizes the token embeddings of the model to the tokenizer
    length, and ties the weights if the model has `tie_weights`.

    The model is modified in place. Progress messages go to the project logger
    when it is available and are printed otherwise.

    Args:
        model (PreTrainedModel): The model.
        model_name (str): The model name.

    Returns:
        PreTrainedTokenizerBase: The tokenizer.

    Raises:
        Exception: The original loading error if no fallback applies, or the
            error of the last fallback candidate if all of them fail.
    """

    # Initialize the multi-process logger.
    # This is best-effort: if anything fails, the messages are printed instead.
    logger = None
    source = None
    try:
        from ..logging import get_logger
        from ..utils import get_path

        logger = get_logger()
        source = f'{get_path(source_file=__file__)}.{initialize_tokenizer.__name__}'
    except Exception:
        # Do not keep a logger without a usable source label.
        logger = None

    # Initialize the tokenizer.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_name)
    except Exception:
        fallback_paths = _get_fallback_tokenizer_paths(model_name)
        if not fallback_paths:
            # Bare raise keeps the original exception and traceback.
            raise

        tokenizer, fallback_path = _load_fallback_tokenizer(fallback_paths)

        # Make the substitution visible instead of doing it silently.
        _log_or_print(
            logger,
            source,
            f'The tokenizer could not be loaded from "{model_name}" and has '
            f'been loaded from "{fallback_path}" instead.',
        )

    # Set the pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

        _log_or_print(
            logger,
            source,
            'The pad token has been set to the eos token.',
        )

    # Add the sep token (only for MoSLE model).
    if hasattr(model, 'mosle_config') and tokenizer.sep_token is None:
        sep_token = model.mosle_config.special_tokens['sep']

        if sep_token['value'] is not None:
            tokenizer.add_special_tokens(
                special_tokens_dict={sep_token['key']: sep_token['value']})

            _log_or_print(
                logger,
                source,
                'The sep token has been added to the tokenizer.',
            )

    # Resize the token embeddings.
    model.resize_token_embeddings(new_num_tokens=len(tokenizer))

    # Tie the input and output embeddings if the model supports it.
    if hasattr(model, 'tie_weights'):
        model.tie_weights()

        _log_or_print(
            logger,
            source,
            'The input and output embeddings have been tied.',
        )

    return tokenizer
