

"""Megatron tokenizers."""
import argparse 
from abc import ABC 
from abc import abstractmethod 
import os 
from typing import Optional 

from llama_recipes .utils .distributed import is_rank_0 
from megatron_lm .megatron .core .datasets .megatron_tokenizer import MegatronTokenizer 


def build_tokenizer(args: argparse.Namespace):
    """Construct the tokenizer named by ``args.tokenizer_type``.

    ``SentencePieceTokenizer``, ``GPTSentencePieceTokenizer``,
    ``Llama2Tokenizer`` and ``Llama3Tokenizer`` are built from
    ``args.tokenizer_model`` (which must not be None); ``NullTokenizer`` is
    built from ``args.vocab_size`` (which must not be None). When
    ``args.padded_vocab_size`` is missing or None it is filled in from the
    new tokenizer's ``vocab_size`` via ``_vocab_size_with_padding``.

    Raises:
        NotImplementedError: if ``args.tokenizer_type`` is not recognised.
    """
    tokenizer_type = args.tokenizer_type
    if is_rank_0():
        print("> building {} tokenizer ...".format(tokenizer_type), flush=True)

    # Tokenizer types that are constructed from a single model-file path.
    from_model_file = {
        "SentencePieceTokenizer": lambda path: _SentencePieceTokenizer(
            path, vocab_extra_ids=args.vocab_extra_ids
        ),
        "GPTSentencePieceTokenizer": _GPTSentencePieceTokenizer,
        "Llama2Tokenizer": _Llama2Tokenizer,
        "Llama3Tokenizer": _Llama3Tokenizer,
    }

    if tokenizer_type in from_model_file:
        assert args.tokenizer_model is not None
        tokenizer = from_model_file[tokenizer_type](args.tokenizer_model)
    elif tokenizer_type == "NullTokenizer":
        assert args.vocab_size is not None
        tokenizer = _NullTokenizer(args.vocab_size)
    else:
        raise NotImplementedError(
            "{} tokenizer is not implemented.".format(tokenizer_type)
        )

    # Only compute padding when the caller has not already decided on a size.
    if getattr(args, "padded_vocab_size", None) is None:
        args.padded_vocab_size = _vocab_size_with_padding(
            orig_vocab_size=tokenizer.vocab_size, args=args
        )

    return tokenizer


def _vocab_size_with_padding(orig_vocab_size, args):
    """Round the vocabulary size up to a GPU-friendly multiple.

    Returns the smallest value ``>= orig_vocab_size`` that is divisible by
    ``args.make_vocab_size_divisible_by``. Only that single multiple is used;
    no model-parallel factor is applied here.

    Args:
        orig_vocab_size: the tokenizer's unpadded vocabulary size.
        args: namespace providing ``make_vocab_size_divisible_by``.

    Returns:
        The padded vocabulary size.

    Raises:
        ValueError: if ``args.make_vocab_size_divisible_by`` is not positive.
    """
    multiple = args.make_vocab_size_divisible_by
    if multiple <= 0:
        raise ValueError(
            "make_vocab_size_divisible_by must be positive, got {}".format(multiple)
        )

    # Closed-form round-up instead of incrementing one step at a time.
    remainder = orig_vocab_size % multiple
    if remainder == 0:
        after = orig_vocab_size
    else:
        after = orig_vocab_size + (multiple - remainder)

    if is_rank_0():
        print(
            " > padded vocab (size: {}) with {} dummy tokens (new size: {})".format(
                orig_vocab_size, after - orig_vocab_size, after
            ),
            flush=True,
        )
    return after


class AbstractTokenizer(ABC):
    """Base interface shared by the tokenizers in this module.

    Subclasses must implement ``vocab_size``, ``vocab``, ``inv_vocab`` and
    ``tokenize``. ``detokenize`` and the special-token properties (``cls``,
    ``sep``, ``pad``, ``eod``, ``mask``) raise ``NotImplementedError`` here
    and are overridden by subclasses that support them.
    """

    def __init__(self, name):
        # Human-readable tokenizer name, interpolated into the errors below.
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        """Size of the vocabulary."""

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""

    @abstractmethod
    def tokenize(self, text):
        """Convert ``text`` into token ids."""

    def detokenize(self, token_ids):
        raise NotImplementedError(
            "detokenizer is not implemented for {} tokenizer".format(self.name)
        )

    def _not_provided(self, label):
        """Build the error raised when special token ``label`` is unsupported."""
        return NotImplementedError(
            "{} is not provided for {} tokenizer".format(label, self.name)
        )

    @property
    def cls(self):
        raise self._not_provided("CLS")

    @property
    def sep(self):
        raise self._not_provided("SEP")

    @property
    def pad(self):
        raise self._not_provided("PAD")

    @property
    def eod(self):
        raise self._not_provided("EOD")

    @property
    def mask(self):
        raise self._not_provided("MASK")


class _SentencePieceTokenizer(AbstractTokenizer):
    """SentencePieceTokenizer-Megatron wrapper.

    Wraps ``sentencepiece.SentencePieceProcessor`` and extends the model's
    vocabulary with these special tokens:

    * ``<CLS>``, ``<SEP>``, ``<EOD>`` and ``<MASK>``.
    * The model's pad/bos/eos pieces. If a model id cannot be mapped to a
      piece, ``<PAD>``, ``<BOS>`` or ``<EOS>`` is used instead.
    * ``vocab_extra_ids`` sentinel tokens named ``<extra_id_N>``.

    Special tokens not already present in the model are appended after the
    existing ids.
    """

    def __init__(self, model_file, vocab_extra_ids=0):
        name = "SentencePieceTokenizer"
        super().__init__(name)

        import sentencepiece

        self.tokenizer = sentencepiece.SentencePieceProcessor(
            model_file=model_file
        )
        self._initialize(vocab_extra_ids)

    def _populate_vocab(self):
        """Copy every piece of the SentencePiece model into ``_vocab``/``_inv_vocab``."""
        self._vocab = {}
        self._inv_vocab = {}

        for i in range(len(self.tokenizer)):
            t = self.tokenizer.id_to_piece(i)
            self._inv_vocab[i] = t
            self._vocab[t] = i

    def _initialize(self, vocab_extra_ids):
        """Build the vocabulary and register all special tokens.

        Args:
            vocab_extra_ids: number of ``<extra_id_N>`` sentinel tokens to add.
        """
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Reuse the model's id when the piece already exists; otherwise
            # append it after the current last id. Returns the token's id.
            if t not in self._vocab:
                next_id = len(self._vocab)
                self._vocab[t] = next_id
                self._inv_vocab[next_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t
            return self._vocab[t]

        def _piece_or(piece_id, fallback):
            # id_to_piece raises IndexError for ids outside the model
            # (presumably a negative id when the model lacks that token).
            try:
                return self.tokenizer.id_to_piece(piece_id)
            except IndexError:
                return fallback

        self._cls_id = _add_special_token("<CLS>")
        self._sep_id = _add_special_token("<SEP>")
        self._eod_id = _add_special_token("<EOD>")
        self._mask_id = _add_special_token("<MASK>")

        self._pad_id = _add_special_token(_piece_or(self.tokenizer.pad_id(), "<PAD>"))
        self._bos_id = _add_special_token(_piece_or(self.tokenizer.bos_id(), "<BOS>"))
        self._eos_id = _add_special_token(_piece_or(self.tokenizer.eos_id(), "<EOS>"))

        for i in range(vocab_extra_ids):
            t = "<extra_id_{}>".format(i)
            _add_special_token(t)
            self._t5_tokens.append(t)

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def decoder(self):
        return self._inv_vocab

    @property
    def encoder(self):
        return self._vocab

    def tokenize(self, text):
        """Encode ``text``, mapping literal special-token strings to their ids.

        Text between special tokens is encoded by SentencePiece. Each special
        token occurrence becomes its single id. When several special tokens
        start at the same position, the one registered first wins.
        """
        ids = []
        idx = 0

        # Next occurrence (absolute offset >= idx) of each special token that
        # still appears in the text. Keys keep _special_tokens' order (entries
        # are only updated or deleted, never re-inserted), which preserves the
        # first-registered tie-break of min(). Positions are searched in place
        # with str.find instead of copying text[idx:] for every token on every
        # iteration, and are only re-searched once the cursor passes them.
        upcoming = {}
        for token in self._special_tokens:
            pos = text.find(token)
            if pos != -1:
                upcoming[token] = pos

        while upcoming:
            next_token = min(upcoming, key=upcoming.get)
            next_idx = upcoming[next_token]

            ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
            ids.append(self._special_tokens[next_token])
            idx = next_idx + len(next_token)

            for token in tuple(upcoming):
                if upcoming[token] < idx:
                    pos = text.find(token, idx)
                    if pos == -1:
                        del upcoming[token]
                    else:
                        upcoming[token] = pos

        ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
        return ids

    def detokenize(self, ids):
        """Decode ``ids``, rendering special-token ids as their literal strings.

        Each special token is surrounded by single spaces relative to the
        SentencePiece-decoded text around it.
        """
        parts = []
        last_i = 0

        for i, token_id in enumerate(ids):
            if token_id in self._inv_special_tokens:
                parts.append(self.tokenizer.decode_ids(ids[last_i:i]) + " ")
                parts.append(self._inv_special_tokens[token_id] + " ")
                last_i = i + 1

        parts.append(self.tokenizer.decode_ids(ids[last_i:]))
        return "".join(parts)

    @property
    def cls(self):
        return self._cls_id

    @property
    def sep(self):
        return self._sep_id

    @property
    def pad(self):
        return self._pad_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def bos(self):
        return self._bos_id

    @property
    def eod(self):
        return self._eod_id

    @property
    def eos_token_id(self):
        return self._eos_id

    @property
    def eos(self):
        return self._eos_id

    @property
    def mask(self):
        return self._mask_id

    @property
    def additional_special_tokens_ids(self):
        return [self.vocab[k] for k in self._t5_tokens]


class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
    """SentencePiece wrapper that exposes only the model's own vocabulary.

    No Megatron special tokens are registered: ``cls``/``sep``/``mask`` report
    -1, ``eod`` is the model's eos id, and text is encoded/decoded directly by
    SentencePiece.
    """

    def __init__(
        self,
        model_file,
    ):
        super().__init__(model_file, vocab_extra_ids=0)

    def _initialize(self, vocab_extra_ids):
        # vocab_extra_ids is part of the parent hook's signature; unused here.
        self._populate_vocab()

        sp = self.tokenizer
        self._pad_id, self._bos_id, self._eos_id = sp.pad_id(), sp.bos_id(), sp.eos_id()

    def tokenize(self, text):
        return self.tokenizer.encode_as_ids(text)

    def detokenize(self, ids):
        return self.tokenizer.decode_ids(ids)

    @property
    def cls(self):
        """Not supported by this tokenizer; always -1."""
        return -1

    # sep and mask share the same "unsupported" sentinel as cls.
    sep = cls
    mask = cls

    @property
    def eod(self):
        return self._eos_id

    @property
    def additional_special_tokens_ids(self):
        return None


class _Llama2Tokenizer(_SentencePieceTokenizer):
    """SentencePiece wrapper for Llama 2 models.

    Uses the model's own bos/eos/pad ids; no Megatron special tokens are
    added. ``cls``/``sep``/``mask`` report -1 and ``eod`` is the eos id.
    """

    def __init__(
        self,
        model_file,
    ):
        super().__init__(model_file, vocab_extra_ids=0)

    def _initialize(self, vocab_extra_ids):
        """Populate the vocabulary and cache the model's special ids.

        ``vocab_extra_ids`` is accepted for signature compatibility and ignored.
        """
        self._populate_vocab()

        self.n_words: int = self.tokenizer.vocab_size()
        self.bos_id: int = self.tokenizer.bos_id()
        self.eos_id: int = self.tokenizer.eos_id()
        self.pad_id: int = self.tokenizer.pad_id()
        # The inherited properties (bos, eos, pad, bos_token_id, eos_token_id)
        # read the underscore attributes; without these they raised
        # AttributeError because only the public names above were set.
        self._bos_id = self.bos_id
        self._eos_id = self.eos_id
        self._pad_id = self.pad_id
        assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size()

    def tokenize(self, s: str, bos=True, eos=False):
        """Default args for text completion, not chat/dialog."""
        assert type(s) is str
        t = self.tokenizer.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def detokenize(self, ids):
        return self.tokenizer.decode_ids(ids)

    @property
    def cls(self):
        return -1

    @property
    def sep(self):
        return -1

    @property
    def mask(self):
        return -1

    @property
    def eod(self):
        return self.eos_id

    @property
    def additional_special_tokens_ids(self):
        return None


class _Llama3Tokenizer(MegatronTokenizer):
    """Hugging Face ``AutoTokenizer`` wrapper for Llama 3 checkpoints.

    The HF tokenizer is loaded from the directory that contains
    ``model_file``. Construction asserts the expected Llama 3 ids:
    * no pad token;
    * bos is 128000;
    * eos is 128001 or 128009;
    * at least 128256 entries.
    """

    def __init__(self, model_file: str, vocab_extra_ids=0) -> None:
        self.name = "Llama3Tokenizer"
        super().__init__(model_file, vocab_extra_ids=vocab_extra_ids)

        from transformers import AutoTokenizer

        checkpoint_dir = os.path.dirname(model_file)
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=checkpoint_dir
        )
        hf = self.tokenizer
        self.bos_id: Optional[int] = hf.bos_token_id
        self.eos_id: Optional[int] = hf.eos_token_id
        self.pad_id: Optional[int] = hf.pad_token_id

        # Sanity checks that this really is a Llama 3 tokenizer.
        assert hf.pad_token_id is None
        assert hf.bos_token_id is not None and hf.bos_token_id == 128000
        assert hf.eos_token_id is not None and hf.eos_token_id in (128001, 128009)
        assert len(hf) >= 128256, f"vocab_size: {len(hf)}"

    def tokenize(self, text: str, bos=True, eos=False):
        '''Default args for text completion, not chat/dialog.'''
        assert type(text) is str
        # HF's own special-token insertion is disabled; bos/eos are added below.
        body = self.tokenizer.encode(text, add_special_tokens=False)
        prefix = [self.bos_id] if bos and self.bos_id is not None else []
        suffix = [self.eos_id] if eos and self.eos_id is not None else []
        return prefix + body + suffix

    def detokenize(self, ids: list[int]):
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    @property
    def cls(self):
        """Not supported by this tokenizer; always -1."""
        return -1

    # sep and mask share the same "unsupported" sentinel as cls.
    sep = cls
    mask = cls

    @property
    def eod(self):
        return self.tokenizer.eos_token_id

    @property
    def additional_special_tokens_ids(self):
        return None

    @property
    def vocab(self):
        return self.tokenizer.get_vocab()

    @property
    def inv_vocab(self):
        return {token_id: token for token, token_id in self.tokenizer.get_vocab().items()}

    @property
    def vocab_size(self):
        return len(self.tokenizer)


class _NullTokenizer :
    def __init__ (self ,vocab_size ):
        vocab_size =int (vocab_size )
        self ._eos_id =vocab_size 
        self .vocab_size =vocab_size +1 

    def tokenize (self ,text ):
        return [int (x )for x in text .split (" ")]

    def detokenize (self ,ids ):
        text =[str (x )for x in ids ]
        return " ".join (text )

    @property 
    def cls (self ):
        return -1 

    @property 
    def sep (self ):
        return -1 

    @property 
    def mask (self ):
        return -1 

    @property 
    def eod (self ):
        return self ._eos_id 

    @property 
    def additional_special_tokens_ids (self ):
        return None 


class HFPreTrainedTokenizer(AbstractTokenizer):
    """
    Huggingface Format Pretrained Tokenizer

    Loads a tokenizer with ``transformers.AutoTokenizer``. It then adds a
    ``<pad>`` pad token if the tokenizer has none, and registers an ``<EOD>``
    end-of-document token that ``eod`` reports.
    """

    # Literal token appended to the vocabulary to mark end-of-document.
    _EOD_TOKEN = "<EOD>"

    def __init__(self, model_file: str, trust_remote_code: bool = False) -> None:
        name = "HFPreTrainedTokenizer"
        super().__init__(name)

        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_file,
            trust_remote_code=trust_remote_code,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "<pad>"})

        # "eod_token" is not one of HF's recognised special-token attributes,
        # so add_special_tokens({"eod_token": ...}) is rejected (assertion or
        # ValueError, depending on the transformers version). Register the
        # marker as an added special token instead.
        self.tokenizer.add_tokens([self._EOD_TOKEN], special_tokens=True)

    @property
    def vocab_size(self) -> int:
        # len() includes added tokens such as <pad>/<EOD>; HF's
        # ``vocab_size`` attribute counts only the base vocabulary, which
        # would leave those ids out of range.
        return len(self.tokenizer)

    @property
    def vocab(self):
        return self.tokenizer.get_vocab()

    @property
    def inv_vocab(self):
        # Built from get_vocab(): ``tokenizer.decoder`` is not an id->token
        # dict on every tokenizer class.
        return {token_id: token for token, token_id in self.tokenizer.get_vocab().items()}

    def tokenize(self, text: str):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids) -> str:
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.tokenizer.convert_tokens_to_ids(self._EOD_TOKEN)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id
