import os
from typing import Union
from transformers import GPT2TokenizerFast, GPT2Tokenizer


# Wrappers around Hugging Face tokenizers to make sure that special tokens like pad are added
class SpecialToksGPT2Tokenizer(GPT2Tokenizer):
    """``GPT2Tokenizer`` whose ``from_pretrained`` also registers special tokens.

    After the parent loader returns, the pad, sep and cls tokens are all set
    to the string ``"<|endoftext|>"`` via ``add_special_tokens``.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *init_inputs,
        **kwargs,
    ):
        """Load a tokenizer, then set its pad/sep/cls tokens to ``<|endoftext|>``.

        Args:
            pretrained_model_name_or_path: Forwarded unchanged to the parent
                class's ``from_pretrained``.
            *init_inputs: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            The object produced by the parent ``from_pretrained``, after
            ``add_special_tokens`` has been called on it.
        """
        loaded = super().from_pretrained(
            pretrained_model_name_or_path,
            *init_inputs,
            **kwargs,
        )
        # Every role maps to the same string, so derive the mapping from the
        # role names. This runs on every load; it is not conditional on kwargs.
        roles = ("pad_token", "sep_token", "cls_token")
        loaded.add_special_tokens(dict.fromkeys(roles, "<|endoftext|>"))
        return loaded
    
class SpecialToksGPT2TokenizerFast(GPT2TokenizerFast):
    """``GPT2TokenizerFast`` whose ``from_pretrained`` also registers special tokens.

    After the parent loader returns, the pad, sep, cls, bos and eos tokens are
    all set to the string ``"<|endoftext|>"`` via ``add_special_tokens``.
    """

    # Points at the slow wrapper defined in this module rather than the stock
    # GPT2Tokenizer.
    slow_tokenizer_class = SpecialToksGPT2Tokenizer

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *init_inputs,
        **kwargs,
    ):
        """Load a fast tokenizer, then set five special tokens to ``<|endoftext|>``.

        Args:
            pretrained_model_name_or_path: Forwarded unchanged to the parent
                class's ``from_pretrained``.
            *init_inputs: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            The object produced by the parent ``from_pretrained``, after
            ``add_special_tokens`` has been called on it.
        """
        loaded = super().from_pretrained(
            pretrained_model_name_or_path,
            *init_inputs,
            **kwargs,
        )
        # Every role maps to the same string, so derive the mapping from the
        # role names. This runs on every load; it is not conditional on kwargs.
        roles = ("pad_token", "sep_token", "cls_token", "bos_token", "eos_token")
        loaded.add_special_tokens(dict.fromkeys(roles, "<|endoftext|>"))
        return loaded
