# Copyright      2024  Yiwen Shao
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
from tqdm.auto import tqdm
import os
import numpy as np

from abc import ABC
from abc import abstractmethod

from .byte_utils import byte_encode, smart_byte_decode
from .icefall_utils import tokenize_by_CJK_char
from .text_normalization import text_normalization

class AbstractTokenizer(ABC):
    """Interface that every tokenizer wrapper in this module implements.

    Subclasses must define ``vocab_size``, ``unk_id``, ``blank_id``,
    ``add_tokens``, ``encode`` and ``decode``.  The special-token properties
    (``cls``, ``sep``, ``pad``, ``eod``, ``mask``) raise
    ``NotImplementedError`` here; a subclass overrides the ones it supports.
    """

    def __init__(self, name):
        """Store ``name``; it is interpolated into the error messages below."""
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        """Vocabulary size; supplied by the subclass."""

    @property
    @abstractmethod
    def unk_id(self):
        """Id of the unknown token; supplied by the subclass."""

    @property
    @abstractmethod
    def blank_id(self):
        """Id of the blank token; supplied by the subclass."""

    @abstractmethod
    def add_tokens(self, tokens):
        """Register extra ``tokens``; supplied by the subclass."""

    @abstractmethod
    def encode(self, text):
        """Turn ``text`` into token ids; supplied by the subclass."""

    @abstractmethod
    def decode(self, token_ids):
        """Turn ``token_ids`` back into text; supplied by the subclass."""

    # Not part of the interface (kept disabled):
    # @property
    # @abstractmethod
    # def vocab(self):
    #     """Dictionary from vocab text token to id token."""
    #     pass

    @property
    def cls(self):
        """CLS token; raises unless a subclass overrides this property."""
        message = 'CLS is not provided for {} tokenizer'.format(self.name)
        raise NotImplementedError(message)

    @property
    def sep(self):
        """SEP token; raises unless a subclass overrides this property."""
        message = 'SEP is not provided for {} tokenizer'.format(self.name)
        raise NotImplementedError(message)

    @property
    def pad(self):
        """PAD token; raises unless a subclass overrides this property."""
        message = 'PAD is not provided for {} tokenizer'.format(self.name)
        raise NotImplementedError(message)

    @property
    def eod(self):
        """EOD token; raises unless a subclass overrides this property."""
        message = 'EOD is not provided for {} tokenizer'.format(self.name)
        raise NotImplementedError(message)

    @property
    def mask(self):
        """MASK token; raises unless a subclass overrides this property."""
        message = 'MASK is not provided for {} tokenizer'.format(self.name)
        raise NotImplementedError(message)


class _SentencePieceTokenizer(AbstractTokenizer):
    """Wrapper around a ``sentencepiece.SentencePieceProcessor``.

    Next to the processor it keeps ``_vocab``, a piece -> id dictionary that
    ``add_tokens`` can extend.  Ids handed out by ``add_tokens`` exist only in
    that dictionary; the processor itself is never modified.
    """

    def __init__(self, model_file, is_bbpe=False):
        """Load the model at ``model_file`` and index its pieces.

        Args:
            model_file: Passed to ``SentencePieceProcessor(model_file=...)``.
            is_bbpe: Stored on the instance; not read anywhere in this class.
        """
        super().__init__('SentencePieceTokenizer')

        # Imported lazily so the module can be imported without sentencepiece.
        import sentencepiece
        processor = sentencepiece.SentencePieceProcessor(model_file=model_file)
        self._tokenizer = processor
        self._vocab = {
            processor.id_to_piece(idx): idx for idx in range(len(processor))
        }
        self.is_bbpe = is_bbpe

    def add_tokens(self, tokens):
        """Append every not-yet-known token in the list ``tokens`` to ``_vocab``.

        Each new token receives the next free id (the current vocabulary size).
        """
        assert isinstance(tokens, list)
        for token in tokens:
            if token in self._vocab:
                continue
            self._vocab[token] = len(self._vocab)

    @property
    def vocab_size(self):
        """Number of entries in ``_vocab`` (model pieces plus added tokens)."""
        return len(self._vocab)

    @property
    def unk_id(self):
        """Id the underlying processor assigns to the piece ``<unk>``."""
        return self._tokenizer.piece_to_id('<unk>')

    @property
    def blank_id(self):
        """Id of ``<blk>``.

        NOTE: if ``<blk>`` is missing it is appended to the vocabulary on first
        access, so ``vocab_size`` read before this property can be one smaller
        than ``vocab_size`` read afterwards.
        """
        if '<blk>' not in self._vocab:
            self.add_tokens(['<blk>'])
            logging.warning('<blk> not included in the current spm. Now add <blk> to the end.')
        return self._vocab['<blk>']

    def encode(self, text):
        """Encode ``text`` into integer ids with the underlying processor."""
        return self._tokenizer.encode(text, out_type=int)

    def decode(self, token_ids):
        """Decode ``token_ids`` with the underlying processor."""
        return self._tokenizer.decode(token_ids)
    
class _IcefallLexiconCharTokenizer(AbstractTokenizer):
    """Tokenizer backed by an icefall ``Lexicon`` and its char graph compiler."""

    def __init__(self, lang_dir):
        """Build the lexicon from ``lang_dir`` and a CPU graph compiler on top of it."""
        super().__init__('IcefallLexiconTokenizer')

        # Imported lazily so the module can be imported without icefall.
        from icefall.lexicon import Lexicon
        from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler

        lexicon = Lexicon(lang_dir)
        self._lexicon = lexicon
        self._graph_compiler = CharCtcTrainingGraphCompiler(
            lexicon=lexicon,
            device='cpu',
        )

    def add_tokens(self, tokens):
        """Not supported: the vocabulary is fixed by the lexicon."""
        message = 'add_tokens is not provided for {} tokenizer'.format(self.name)
        raise NotImplementedError(message)

    @property
    def vocab_size(self):
        """One more than the largest id in ``Lexicon.tokens``."""
        return 1 + max(self._lexicon.tokens)

    @property
    def unk_id(self):
        """Id of ``<unk>`` in the lexicon's token table."""
        return self._lexicon.token_table['<unk>']

    @property
    def blank_id(self):
        """Id of ``<blk>`` in the lexicon's token table."""
        return self._lexicon.token_table["<blk>"]

    def encode(self, text):
        """Delegate to the graph compiler's ``texts_to_ids``."""
        return self._graph_compiler.texts_to_ids(text)

    def decode(self, token_ids, use_space=False):
        """Map ids back to symbols through the token table and join them.

        Args:
            token_ids: Either one id sequence or a list of id sequences
                (detected by the first element being a ``list``).
            use_space: Join symbols with a single space instead of nothing.

        Returns:
            One string for a single sequence, a list of strings for a batch.
        """
        separator = " " if use_space else ""

        def _join(ids):
            # Look the table up per call so an empty input never touches it.
            return separator.join([self._lexicon.token_table[i] for i in ids])

        is_batch = token_ids and isinstance(token_ids[0], list)
        if not is_batch:
            return _join(token_ids)
        return [_join(sequence) for sequence in token_ids]


class _IcefallBbpeTokenizer(AbstractTokenizer):
    """SentencePiece wrapper that byte-encodes text before tokenizing.

    Text is pre-processed with ``tokenize_by_CJK_char`` and ``byte_encode``
    before it reaches the SentencePiece model, and decoded strings are passed
    through ``smart_byte_decode``.  ``_vocab`` is a piece -> id dictionary that
    ``add_tokens`` can extend; the underlying processor is never modified.
    """

    def __init__(self, model_file):
        """Load the SentencePiece model at ``model_file`` and index its pieces."""
        name = 'IcefallBbpeTokenizer'
        super().__init__(name)

        # Imported lazily so the module can be imported without sentencepiece.
        import sentencepiece
        self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
        self._vocab = {
            self._tokenizer.id_to_piece(i): i for i in range(len(self._tokenizer))
        }

    def add_tokens(self, tokens):
        """Append every not-yet-known token in the list ``tokens`` to ``_vocab``.

        Each new token receives the next free id (the current vocabulary size).
        """
        assert isinstance(tokens, list)
        for t in tokens:
            if t not in self._vocab:
                self._vocab[t] = len(self._vocab)

    @property
    def vocab_size(self):
        """Number of entries in ``_vocab`` (model pieces plus added tokens)."""
        return len(self._vocab)

    @property
    def unk_id(self):
        """Id the underlying processor assigns to the piece ``<unk>``."""
        return self._tokenizer.piece_to_id('<unk>')

    @property
    def blank_id(self):
        """Id of ``<blk>``.

        NOTE: if ``<blk>`` is missing it is appended to the vocabulary on first
        access, so ``vocab_size`` read before this property can be one smaller
        than ``vocab_size`` read afterwards.
        """
        if '<blk>' not in self._vocab:
            self.add_tokens(['<blk>'])
            logging.warning('<blk> not included in the current spm. Now add <blk> to the end.')
        return self._vocab['<blk>']

    def encode(self, texts):
        """Byte-encode and tokenize ``texts``.

        Args:
            texts: A list of strings, or a single string.

        Returns:
            A list of id lists for a list input; one id list for a single
            string.
        """
        if isinstance(texts, str):
            # Without this branch a bare string would be iterated character by
            # character and each character encoded as its own "sentence".
            return self._tokenizer.encode(
                byte_encode(tokenize_by_CJK_char(texts)), out_type=int
            )
        texts = [byte_encode(tokenize_by_CJK_char(text)) for text in texts]
        return self._tokenizer.encode(texts, out_type=int)

    def decode(self, token_ids):
        """Decode ids and undo the byte encoding.

        Args:
            token_ids: A list of id sequences, or one flat list of ints.

        Returns:
            A list of strings for a batch; one string for a flat id list.
            An empty input yields an empty list.
        """
        is_flat = (
            isinstance(token_ids, list)
            and len(token_ids) > 0
            and isinstance(token_ids[0], int)
        )
        if is_flat:
            # The processor returns one string here; iterating it (as the
            # batch path does) would byte-decode each character separately.
            return smart_byte_decode(self._tokenizer.decode(token_ids))
        texts = self._tokenizer.decode(token_ids)
        return [smart_byte_decode(text) for text in texts]

        
class _TiktokenTokenizer(AbstractTokenizer):
    """Wrapper around a tiktoken encoding with one extra id reserved for blank.

    The blank id is ``n_vocab`` of the underlying encoding, i.e. the first id
    past the encoding's own range, which is why ``vocab_size`` is
    ``n_vocab + 1``.
    """

    def __init__(self, model_name):
        """Load the tiktoken encoding called ``model_name``.

        Example encoding names: r50k_base, p50k_base, cl100k_base, o200k_base.
        """
        name = 'TiktokenTokenizer'
        super().__init__(name)

        # Imported lazily so the module can be imported without tiktoken.
        import tiktoken
        self._tokenizer = tiktoken.get_encoding(model_name)

    def encode(self, text):
        """Encode one string or a list/tuple of strings.

        Returns:
            One id list for a string; a list of id lists for a list/tuple.

        Raises:
            TypeError: if ``text`` is neither a string nor a list/tuple, instead
                of silently returning ``None``.
        """
        if isinstance(text, str):
            return self._tokenizer.encode(text)
        if isinstance(text, (list, tuple)):
            return [self._tokenizer.encode(t) for t in text]
        raise TypeError(
            'encode expects a str or a list of str, got {}'.format(type(text).__name__)
        )

    def decode(self, tokens):
        """Decode one id sequence or a list of id sequences.

        An empty ``tokens`` is treated as an empty single sequence and handed
        to the encoding as-is rather than failing on ``tokens[0]``.

        Returns:
            One string for a single sequence; a list of strings for a batch.
        """
        # len() rather than truthiness so array-like inputs keep working.
        if len(tokens) > 0 and isinstance(tokens[0], list):
            return [self._tokenizer.decode(token) for token in tokens]
        return self._tokenizer.decode(tokens)

    @property
    def vocab_size(self):
        """Size of the encoding's vocabulary plus one slot for ``<blk>``."""
        return self._tokenizer.n_vocab + 1

    def add_tokens(self, tokens):
        """Not supported for this tokenizer."""
        raise NotImplementedError('add_tokens is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def unk_id(self):
        """Not supported for this tokenizer."""
        raise NotImplementedError('unk_id is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def blank_id(self):
        """Blank id: the first id past the encoding's own vocabulary."""
        return self._tokenizer.n_vocab
    
    
    
class _ConcatSentencePieceTokenizer(AbstractTokenizer):
    """Several SentencePiece models exposed through one shared id space.

    The models are laid out back to back: the first model keeps its ids, the
    second model's ids are shifted by the size of the first, and so on.
    ``vocabulary`` is the list of pieces in that global order, so
    ``vocabulary[i]`` is the piece for global id ``i``.  Tokens registered via
    ``add_tokens`` are appended after all model pieces.
    """

    def __init__(self, model_file_json):
        """Load one SentencePiece model per language and build the joint vocabulary.

        Args:
            model_file_json: Path to a JSON file.  NOTE(review): it is read as
                an object mapping language id -> SentencePiece model file
                path, which is inferred from how the dict is consumed here;
                confirm against the files actually used.
        """
        name = 'ConcatSentencePieceTokenizer'
        super().__init__(name)

        # Imported lazily so the module can be imported without sentencepiece.
        import json
        import sentencepiece

        with open(model_file_json, encoding='utf-8') as f:
            self.tokenizer_dict = json.load(f)

        self.tokenizers = {}
        for lang_id, model_file in self.tokenizer_dict.items():
            logging.info(f"Loading sentencepiece model for {lang_id}")
            self.tokenizers[lang_id] = sentencepiece.SentencePieceProcessor(
                model_file=model_file
            )

        self.vocabulary = []

        # the tokenizers produce non-overlapping, ordered token ids
        # keys are language ids
        self.token_id_offset = {}

        # keys are tokenizer numbers
        self.token_id_offset_by_tokenizer_num = {}

        offset = 0
        for num, (lang_id, tokenizer) in enumerate(self.tokenizers.items()):
            self.token_id_offset[lang_id] = offset
            self.token_id_offset_by_tokenizer_num[num] = offset
            pieces = [tokenizer.id_to_piece(i) for i in range(len(tokenizer))]
            self.vocabulary.extend(pieces)
            offset += len(pieces)

        # piece -> global id; when several models share a piece the first
        # (lowest) global id wins.  Used by add_tokens and blank_id.
        self._token_to_id = {}
        for token_id, piece in enumerate(self.vocabulary):
            self._token_to_id.setdefault(piece, token_id)

        logging.info(f'Concat vocab size: {self.vocab_size}')

        # Lookup tables to speed up token to text operations.
        # With two tokenizers, [0, 1] / ['en', 'es'], of 128 tokens each, the
        # aggregate id range is [0, 255].  For an incoming global id such as
        # 200 the three tables give:
        #   1. its id inside the owning tokenizer (200 - 128 = 72),
        #   2. the tokenizer that should process it (the second one),
        #   3. the language id of that tokenizer ('es').
        (
            self.offset_token_ids_by_token_id,
            self.tokenizers_by_token_id,
            self.langs_by_token_id,
        ) = self._calculate_offsets()

    def _calculate_offsets(self):
        """Build the global-id lookup tables for all model pieces.

        Returns:
            Three dicts keyed by global id: the id local to the owning
            tokenizer, the owning tokenizer object, and its language id.
            Tokens appended through ``add_tokens`` are not included.
        """
        offsets = {}
        tokenizers = {}
        langs = {}
        for lang_id, tokenizer in self.tokenizers.items():
            base = self.token_id_offset[lang_id]
            for local_id in range(len(tokenizer)):
                token_id = base + local_id
                offsets[token_id] = local_id
                tokenizers[token_id] = tokenizer
                langs[token_id] = lang_id

        return offsets, tokenizers, langs

    def add_tokens(self, tokens):
        """Append every not-yet-known token in the list ``tokens`` to ``vocabulary``.

        Each new token receives the next free global id.
        """
        assert isinstance(tokens, list)
        for t in tokens:
            if t not in self._token_to_id:
                self._token_to_id[t] = len(self.vocabulary)
                self.vocabulary.append(t)

    @property
    def vocab_size(self):
        """Length of ``vocabulary`` (all model pieces plus added tokens)."""
        return len(self.vocabulary)

    @property
    def unk_id(self):
        """Not supported for this tokenizer."""
        raise NotImplementedError('unk_id is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def blank_id(self):
        """Global id of ``<blk>``.

        If some model already contains ``<blk>`` its first global id is
        returned.  NOTE: otherwise ``<blk>`` is appended on first access, so
        ``vocab_size`` grows by one at that point.
        """
        if '<blk>' not in self._token_to_id:
            self.add_tokens(['<blk>'])
            logging.warning('<blk> not included in the current spm. Now add <blk> to the end.')
        return self._token_to_id['<blk>']

    def encode(self, text, lang_id=None):
        """Encode ``text`` with the tokenizer of ``lang_id`` into global ids.

        Args:
            text: One string or a list of strings.
            lang_id: Key of the language tokenizer to use.

        Returns:
            One list of global ids for a string; a list of such lists for a
            list of strings.

        Raises:
            NotImplementedError: if ``lang_id`` is not given, because there is
                no way to pick a model without it.
            KeyError: if ``lang_id`` is not one of the loaded languages.
        """
        if lang_id is None:
            raise NotImplementedError('encode without a lang_id is not provided for {} '
                                      'tokenizer'.format(self.name))
        tokenizer = self.tokenizers[lang_id]
        offset = self.token_id_offset[lang_id]

        if isinstance(text, str):
            return [t + offset for t in tokenizer.encode(text, out_type=int)]
        batches = tokenizer.encode(list(text), out_type=int)
        return [[t + offset for t in token_ids] for token_ids in batches]

    def decode(self, tokens):
        """Decode global ids back into text.

        Args:
            tokens: One id sequence, a list of id sequences, or a numpy array
                of either shape.

        Returns:
            Always a list of strings, one per sequence (a single sequence
            gives a one-element list; an empty input gives an empty list).

        Raises:
            KeyError: if an id is outside the known range.
        """
        ids = tokens.tolist() if isinstance(tokens, np.ndarray) else tokens
        if len(ids) == 0:
            return []

        if not isinstance(ids[0], list):
            ids = [ids]

        num_known = len(self.vocabulary)
        texts = []
        for id_seq in ids:
            pieces = []
            for token_id in id_seq:
                tokenizer = self.tokenizers_by_token_id.get(token_id)
                if tokenizer is not None:
                    local_id = self.offset_token_ids_by_token_id[token_id]
                    pieces.append(tokenizer.id_to_piece(local_id))
                elif len(self.tokenizers_by_token_id) <= token_id < num_known:
                    # Token appended through add_tokens: no model owns it.
                    pieces.append(self.vocabulary[token_id])
                else:
                    raise KeyError(token_id)
            texts.append(''.join(pieces).replace('▁', ' '))

        return texts



class TokenizerModule:
    """Build a tokenizer from command-line arguments.

    Typical use: register the options with ``add_arguments``, parse, then call
    ``TokenizerModule(args).tokenizer()``.
    """

    def __init__(self, args: argparse.Namespace):
        """Keep the parsed ``args``; they are read when ``tokenizer()`` is called."""
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Add the ``--tokenizer-model-path`` and ``--tokenizer-type`` options to ``parser``."""
        group = parser.add_argument_group(
            title="Tokenizer Initilization",
            description="These options are used for the building "
            "different types of tokenizers "
        )

        group.add_argument(
            "--tokenizer-model-path",
            type=str,
            help="Path to a tokenizer model file or a icefall lang dir",
        )

        group.add_argument(
            "--tokenizer-type",
            type=str,
            default='spm',
            help="tokenizer type"
        )

    def tokenizer(self):
        """Instantiate the tokenizer selected by ``args.tokenizer_type``.

        Supported types are ``spm``, ``icefall_bbpe``, ``lexicon`` and
        ``tiktoken``; ``args.tokenizer_model_path`` is passed to the chosen
        class as its only argument.

        Raises:
            ValueError: if ``args.tokenizer_type`` is not one of the supported
                types (previously this surfaced as an ``UnboundLocalError``).
        """
        tokenizer_type = self.args.tokenizer_type
        # A single if/elif chain: exactly one branch runs, and anything
        # unrecognised is rejected explicitly.
        if tokenizer_type == 'spm':
            logging.info("Initializing Sentencepiece Tokenizer")
            tokenizer = _SentencePieceTokenizer(self.args.tokenizer_model_path)
        elif tokenizer_type == 'icefall_bbpe':
            logging.info("Initializing Icefall BBPE Tokenizer")
            tokenizer = _IcefallBbpeTokenizer(self.args.tokenizer_model_path)
        elif tokenizer_type == 'lexicon':
            logging.info("Initializing Icefall Lexicon")
            tokenizer = _IcefallLexiconCharTokenizer(self.args.tokenizer_model_path)
        elif tokenizer_type == 'tiktoken':
            logging.info("Initializing Tiktoken Tokenizer")
            tokenizer = _TiktokenTokenizer(self.args.tokenizer_model_path)
        else:
            raise ValueError(
                f"Unsupported tokenizer type: {tokenizer_type!r} "
                "(expected one of 'spm', 'icefall_bbpe', 'lexicon', 'tiktoken')"
            )

        return tokenizer
        
        
if __name__ == "__main__":
    # Manual smoke test: build the tokenizer selected on the command line and
    # round-trip one sample sentence, first as-is and then after text
    # normalization.
    cli_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    TokenizerModule.add_arguments(cli_parser)
    args = cli_parser.parse_args()
    tokenizer = TokenizerModule(args).tokenizer()
    #tokenizer.add_tokens(['<blk>'])
    print(f"vocab size: {tokenizer.vocab_size}")
    # print(f"unk id: {tokenizer.unk_id}")
    print(f"blank id: {tokenizer.blank_id}")

    # Alternative sample inputs kept for manual experiments:
    # words = ['ทำให้ มี ความมั่นใจ มากยิ่งขึ้น ขึ้น']
    #words = ['HOW DO WE GET HERE   HOW DO WE ALWAYS GET HERE WHEN WE WE KIND OF ARE TALKING   HOW DO WE ALWAYS GET TO THIS PLACE']
    #words = ["你好, 李华", "你 好, Hello WORLD李华!"]
    words = ["湖北一公司以公司名义寨"]

    # Pass 1: the raw text.
    raw_ids = tokenizer.encode(words)
    print(f"input text: {words}")
    raw_lengths = [len(sequence) for sequence in raw_ids]
    print(f"encode ids: {raw_ids}, token length: {raw_lengths}")
    print(f"decode result: {tokenizer.decode(raw_ids)}")
    if args.tokenizer_type == 'spm':
        # Only the plain SentencePiece wrapper is asked for its pieces here.
        raw_pieces = [tokenizer._tokenizer.id_to_piece(sequence) for sequence in raw_ids]
        print(f"word pieces: {raw_pieces}")

    # Pass 2: the same text after normalization.
    tn_words = [
        text_normalization(sentence, case='lower', space_between_cjk=False)
        for sentence in words
    ]
    tn_ids = tokenizer.encode(tn_words)
    print(f"normalized text: {tn_words}")
    tn_lengths = [len(sequence) for sequence in tn_ids]
    print(f"encode tn ids: {tn_ids}, token length: {tn_lengths}")
    print(f"decode result: {tokenizer.decode(tn_ids)}")
    if args.tokenizer_type == 'spm':
        tn_pieces = [tokenizer._tokenizer.id_to_piece(sequence) for sequence in tn_ids]
        print(f"word pieces: {tn_pieces}")