"""Utilities related to tokenizers."""
from transformers import (
    AutoTokenizer,
    GPT2TokenizerFast,
    GPTNeoXTokenizerFast,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
)


###############################################################################


def _special_token_id(tokenizer: PreTrainedTokenizerBase, token: str) -> int:
    """Return the vocabulary id of ``token``, failing loudly if it is absent.

    Raises:
        KeyError: If ``token`` is not in ``tokenizer.get_vocab()``. The message
            names the token and tokenizer class so the failure is diagnosable.
    """
    vocab = tokenizer.get_vocab()
    try:
        return vocab[token]
    except KeyError:
        raise KeyError(
            f"special token {token!r} not found in the vocabulary of "
            f"{type(tokenizer).__name__}"
        ) from None


def from_pretrained(name: str, **kwargs) -> PreTrainedTokenizerBase:
    """Load a tokenizer with ``AutoTokenizer.from_pretrained`` and patch known gaps.

    Some tokenizer families come without usable special-token ids. For the
    families below, missing ids are filled in; every other tokenizer is
    returned untouched:

    * ``GPTNeoXTokenizerFast``: a missing ``pad_token_id`` is set to the id of
      ``<|padding|>``, and a missing ``eos_token_id`` to the id of
      ``<|endoftext|>``.
    * ``GPT2TokenizerFast``: a missing ``pad_token_id`` is set to the id of
      ``<empty_output>``.

    Args:
        name: Model id or local path passed to ``AutoTokenizer.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``AutoTokenizer.from_pretrained`` (e.g. ``revision``, ``cache_dir``).

    Returns:
        The loaded (and possibly patched) tokenizer. ``AutoTokenizer`` can
        return slow or fast tokenizers, so the declared type is their common
        base, ``PreTrainedTokenizerBase``.

    Raises:
        KeyError: If a special token used for patching is not in the
            tokenizer's vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)

    if isinstance(tokenizer, GPTNeoXTokenizerFast):
        # Observed in practice: these can come back with no pad token id even
        # though a pad token appears to exist.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = _special_token_id(tokenizer, '<|padding|>')
        if tokenizer.eos_token_id is None:
            tokenizer.eos_token_id = _special_token_id(tokenizer, '<|endoftext|>')

    elif isinstance(tokenizer, GPT2TokenizerFast):
        # GPT-2 tokenizers have no pad token. Assigning a negative
        # pad_token_id was observed to raise, so an existing vocabulary token
        # is borrowed instead. Its intended purpose is unknown; it is assumed
        # not to matter for padding here (NOTE(review): confirm per use case).
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = _special_token_id(tokenizer, '<empty_output>')

    return tokenizer
