from pathlib import Path
from abc import ABC
from abc import abstractmethod
import logging
import re

from auden.utils.byte_utils import byte_encode, smart_byte_decode


def tokenize_by_CJK_char(line: str) -> str:
    """
    Upper-case a line of text and separate every CJK character with spaces.

    The line is stripped and upper-cased, then split on single CJK
    characters. Each CJK character becomes its own token, while every run
    of non-CJK text between them is kept as one token (only its outer
    whitespace is trimmed). Empty pieces are dropped and the remaining
    tokens are joined with a single space.

    Example:
      input = "你好世界是 hello world 的中文"
      output = "你 好 世 界 是 HELLO WORLD 的 中 文"

    Args:
      line:
        The input text.

    Return:
      A new, upper-cased string in which each CJK character is delimited
      by spaces.
    """
    # The CJK ranges are taken from
    # https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
    # The capturing group makes ``split`` keep the matched characters.
    cjk_char = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
    )
    normalized = line.strip().upper()

    pieces = []
    for segment in cjk_char.split(normalized):
        segment = segment.strip()
        if segment:
            pieces.append(segment)
    return " ".join(pieces)


class AbstractAsrTokenizer(ABC):
    """Abstract base class for ASR tokenizers.

    Concrete subclasses must implement ``vocab_size``, ``unk_id``,
    ``add_tokens``, ``encode`` and ``decode``. The special-token
    properties (``cls``, ``sep``, ``pad``, ``eod``, ``mask``) are optional:
    unless a subclass overrides them they raise ``NotImplementedError``.
    """

    def __init__(self, dir_or_name):
        """
        Args:
          dir_or_name:
            The value the subclass passes up (a directory or a name).
            It is stored as ``self.name`` and used in error messages.
        """
        super().__init__()
        # The "not provided" messages in this class (and in subclasses)
        # format ``self.name``; without this assignment they would fail
        # with AttributeError instead of the intended NotImplementedError.
        self.name = dir_or_name

    @property
    @abstractmethod
    def vocab_size(self):
        """Size of the tokenizer vocabulary."""
        pass

    @property
    @abstractmethod
    def unk_id(self):
        """Id of the unknown token."""
        pass

    @abstractmethod
    def add_tokens(self, tokens):
        """Add ``tokens`` to the vocabulary."""
        pass

    @abstractmethod
    def encode(self, text):
        """Convert ``text`` to token ids."""
        pass

    @abstractmethod
    def decode(self, token_ids):
        """Convert ``token_ids`` back to text."""
        pass

    @property
    def cls(self):
        """CLS token; raises NotImplementedError unless overridden."""
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        """SEP token; raises NotImplementedError unless overridden."""
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        """PAD token; raises NotImplementedError unless overridden."""
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        """EOD token; raises NotImplementedError unless overridden."""
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def mask(self):
        """MASK token; raises NotImplementedError unless overridden."""
        raise NotImplementedError('MASK is not provided for {} '
                                  'tokenizer'.format(self.name))


class AsrSpmTokenizer(AbstractAsrTokenizer):
    """ASR tokenizer backed by a SentencePiece model at ``<dir>/bpe.model``.

    A piece -> id mapping is mirrored into ``self._vocab``. If the model
    has no ``<blk>`` piece, one is appended to that mapping (the
    SentencePiece processor itself is not modified) and its id is exposed
    as ``self.blank_id``.
    """

    def __init__(self, dir):
        """
        Args:
          dir:
            Directory that contains the ``bpe.model`` file.

        Raises:
          FileNotFoundError: if ``<dir>/bpe.model`` does not exist.
        """
        super().__init__(dir)
        import sentencepiece as spm

        model_file = Path(dir) / "bpe.model"
        if not model_file.exists():
            raise FileNotFoundError(f"Model file not found: {model_file}")

        processor = spm.SentencePieceProcessor(model_file=str(model_file))
        self._tokenizer = processor

        # Mirror the model's pieces into a plain dict: piece -> id.
        piece_to_index = {}
        for index in range(processor.get_piece_size()):
            piece_to_index[processor.id_to_piece(index)] = index
        self._vocab = piece_to_index

        blank_missing = "<blk>" not in self._vocab
        if blank_missing:
            self.add_tokens(["<blk>"])
            logging.warning("<blk> not included in the current spm. Now added <blk> to the end.")
        self.blank_id = self._vocab["<blk>"]

        logging.info(
            f"[Tokenizer] Loaded {self.__class__.__name__} (type=sentencepiece) "
            f"from {model_file} | vocab size: {len(self._vocab)} | blank_id: {self.blank_id} | unk_id: {self.unk_id}")

    def add_tokens(self, tokens):
        """Append unseen ``tokens`` (a list) to the Python-side vocabulary.

        Each new token gets the current vocabulary length as its id;
        tokens that are already present are left untouched.
        """
        assert isinstance(tokens, list)
        vocab = self._vocab
        for piece in tokens:
            # ``len(vocab)`` is evaluated before insertion, i.e. the next free id.
            vocab.setdefault(piece, len(vocab))

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary, including added tokens."""
        return len(self._vocab)

    @property
    def unk_id(self):
        """Id that the SentencePiece model assigns to ``<unk>``."""
        return self._tokenizer.piece_to_id('<unk>')

    def encode(self, text):
        """Encode ``text`` with the SentencePiece model, returning int ids."""
        token_ids = self._tokenizer.encode(text, out_type=int)
        return token_ids

    def decode(self, token_ids):
        """Decode ``token_ids`` with the SentencePiece model."""
        text = self._tokenizer.decode(token_ids)
        return text
    
    
class AsrSpmBbpeTokenizer(AsrSpmTokenizer):
    """
    Icefall-like SpmBbpeTokenizer wrapper.

    Before encoding, each text is passed through ``tokenize_by_CJK_char``
    (which upper-cases it and isolates every CJK character with spaces)
    and then through ``byte_encode``. Per the original author's notes,
    ``byte_encode`` turns each CJK character into a sequence of byte
    chars, so a single character becomes a "fake word" for the unigram
    BPE model. Decoding reverses the byte step with ``smart_byte_decode``.

    Both ``encode`` and ``decode`` iterate over their input, so they
    expect a batch (a list of texts / the batched output of the
    SentencePiece decoder), not a single item.
    """

    def __init__(self, dir):
        """
        Args:
          dir:
            Directory that contains the ``bpe.model`` file.
        """
        super().__init__(dir)

    def encode(self, texts):
        """Encode a list of texts into a list of int id sequences."""
        prepared = []
        for raw in texts:
            spaced = tokenize_by_CJK_char(raw)
            prepared.append(byte_encode(spaced))
        return self._tokenizer.encode(prepared, out_type=int)

    def decode(self, token_ids):
        """Decode a batch of id sequences into a list of texts."""
        byte_level_texts = self._tokenizer.decode(token_ids)
        return list(map(smart_byte_decode, byte_level_texts))
    
class AsrIcefallLexiconCharTokenizer(AbstractAsrTokenizer):
    """ASR tokenizer backed by an icefall ``Lexicon``.

    Encoding is delegated to icefall's ``CharCtcTrainingGraphCompiler``;
    decoding maps ids back to symbols through the lexicon's token table.
    Adding tokens is not supported.
    """

    def __init__(self, lang_dir):
        """
        Args:
          lang_dir:
            Language directory handed to ``icefall.lexicon.Lexicon``.
        """
        name = 'IcefallLexiconTokenizer'
        super().__init__(name)
        # ``add_tokens`` (and the optional special-token properties of the
        # base class) format ``self.name`` in their error messages, so it
        # must exist; otherwise they raise AttributeError instead of
        # NotImplementedError.
        self.name = name

        from icefall.lexicon import Lexicon
        from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler

        self._lexicon = Lexicon(lang_dir)
        self._graph_compiler = CharCtcTrainingGraphCompiler(
            lexicon=self._lexicon,
            device='cpu'
        )

    def add_tokens(self, tokens):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('add_tokens is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def vocab_size(self):
        """Largest token id in the lexicon plus one."""
        return max(self._lexicon.tokens) + 1

    @property
    def unk_id(self):
        """Id of ``<unk>`` in the lexicon's token table."""
        return self._lexicon.token_table['<unk>']

    @property
    def blank_id(self):
        """Id of ``<blk>`` in the lexicon's token table."""
        return self._lexicon.token_table["<blk>"]

    def encode(self, text):
        """Convert ``text`` to ids via the graph compiler's ``texts_to_ids``."""
        return self._graph_compiler.texts_to_ids(text)

    def decode(self, token_ids, use_space=False):
        """Map token ids back to text.

        Args:
          token_ids:
            Either a flat list of ids or a list of id lists (a batch).
          use_space:
            If True, join the symbols with a space; otherwise concatenate
            them directly.

        Return:
          A single string for a flat list, or a list of strings for a batch.
        """
        split_char = " " if use_space else ""
        token_table = self._lexicon.token_table

        def ids_to_text(ids):
            return split_char.join([token_table[i] for i in ids])

        # A non-empty input whose first element is a list is treated as a batch.
        if token_ids and isinstance(token_ids[0], list):
            return [ids_to_text(ids) for ids in token_ids]
        return ids_to_text(token_ids)

        
class AsrTiktokenTokenizer(AbstractAsrTokenizer):
    """ASR tokenizer backed by a ``tiktoken`` encoding.

    A bytes -> id mapping of the encoding is mirrored into ``self._vocab``
    and a ``<blk>`` entry is appended after the encoding's id range if it
    is not already present; its id is exposed as ``self.blank_id``.
    ``add_tokens`` and ``unk_id`` are not supported.
    """

    def __init__(self, model_name):
        """
        Args:
          model_name:
            Name of the tiktoken encoding, e.g. r50k_base, p50k_base,
            cl100k_base, o200k_base.
        """
        name = 'TiktokenTokenizer'
        super().__init__(name)
        # ``add_tokens`` / ``unk_id`` format ``self.name`` in their error
        # messages, so it must exist; otherwise they raise AttributeError
        # instead of NotImplementedError.
        self.name = name

        import tiktoken
        self._tokenizer = tiktoken.get_encoding(model_name)
        n_vocab = self._tokenizer.n_vocab

        # Per the tiktoken API, ``n_vocab`` is ``max_token_value + 1`` and
        # ``decode_single_token_bytes`` raises KeyError for ids that are not
        # in the vocabulary. Some encodings leave unused ids inside that
        # range, so those ids are skipped instead of aborting construction.
        self._vocab = {}
        for token_id in range(n_vocab):
            try:
                token_bytes = self._tokenizer.decode_single_token_bytes(token_id)
            except KeyError:
                continue
            self._vocab[token_bytes] = token_id

        # Add <blk> if not present
        if b"<blk>" not in self._vocab:
            # Place <blk> past every id of the encoding. When the id range
            # has no gaps this equals ``len(self._vocab)``; with gaps it
            # avoids colliding with an existing token id.
            self._vocab[b"<blk>"] = max(n_vocab, len(self._vocab))
            logging.warning("<blk> token not found in Tiktoken vocabulary; appending manually.")

        self.blank_id = self._vocab[b"<blk>"]

        logging.info(
            f"[Tokenizer] Loaded AsrTiktokenTokenizer (type=tiktoken, model={model_name}) | "
            f"vocab size: {self.vocab_size} | blank_id: {self.blank_id}"
        )

    def encode(self, text):
        """Encode a string, or a list/tuple of strings, into token ids.

        Return:
          A list of ids for a single string, or a list of id lists for a
          batch.

        Raises:
          TypeError: if ``text`` is neither a string nor a list/tuple.
        """
        if isinstance(text, str):
            return self._tokenizer.encode(text)
        if isinstance(text, (list, tuple)):
            return [self._tokenizer.encode(t) for t in text]
        # Previously unsupported inputs silently produced None.
        raise TypeError(
            'encode expects a str or a list of str, got {}'.format(type(text).__name__)
        )

    def decode(self, tokens):
        """Decode a flat list of ids, or a list of id lists, into text.

        Return:
          A single string for a flat list (including an empty one), or a
          list of strings for a batch.
        """
        # Guard the ``tokens[0]`` probe so an empty input does not raise
        # IndexError; ``len`` is used so non-list sequences keep working.
        if len(tokens) > 0 and isinstance(tokens[0], list):
            return [self._tokenizer.decode(token) for token in tokens]
        return self._tokenizer.decode(tokens)

    @property
    def vocab_size(self):
        """Number of ids a model must cover, i.e. the largest id plus one.

        Equals ``len(self._vocab)`` when the encoding's id range has no
        gaps.
        """
        return max(len(self._vocab), self._tokenizer.n_vocab, self.blank_id + 1)

    def add_tokens(self, tokens):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('add_tokens is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def unk_id(self):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('unk_id is not provided for {} '
                                  'tokenizer'.format(self.name))
    
    
    
# class ConcatSentencePieceTokenizer(AbstractAsrTokenizer):
#     """Concatenated SentencePieceTokenizer wrapper"""

#     def __init__(self, model_file_json):
#         name = 'ConcatSentencePieceTokenizer'
#         super().__init__(name)

#         import sentencepiece
#         import json
#         with open(model_file_json) as f:
#             self.tokenizer_dict = json.load(f)
            
#         self.tokenizers = {}
#         for lang_id in self.tokenizer_dict.keys():
#             logging.info(f"Loading sentencepiece model for {lang}")
#             self.tokenizers[lang_id] = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_dict)
       
#         self.vocabulary = []

#         # the tokenizers should produce non-overlapping, ordered token ids
#         # keys are language ids
#         self.token_id_offset = {}

#         # keys are tokenizer numbers
#         self.token_id_offset_by_tokenizer_num = {}
#         offset = 0
#         i = 0
#         for lang, tokenizer in self.tokenizers.items():
#             self.token_id_offset[lang] = offset
#             self.token_id_offset_by_tokenizer_num[i] = offset
#             offset += len(tokenizer.vocab)
#             i += 1

#         for tokenizer in self.tokenizers.values():
#             self.vocabulary.extend(tokenizer.vocab)

#         self.vocab_size = len(self.vocabulary)
#         logging.info(f'Concat vocab size: {self.vocab_size}')

#         # lookup tables to speed up token to text operations
#         # if there are two tokenizers, [0,1], ['en', 'es'], each with 128 tokens, the aggregate tokenizer
#         # token range will be [0,255]. The below method provides three look up tables:
#         # one, to convert the incoming token id -- e.g. 200 into its real id (200-127 = 73)
#         # second, to compute the tokenizer id that should process that token (1)
#         # third, the compute the lang id for that token ('es')
#         offset_token_ids_by_token_id, tokenizers_by_token_id, langs_by_token_id = self._calculate_offsets()

#         self.offset_token_ids_by_token_id = offset_token_ids_by_token_id
#         self.tokenizers_by_token_id = tokenizers_by_token_id
#         self.langs_by_token_id = langs_by_token_id

#     def _calculate_offsets(self):
#         offsets = {}
#         tokenizers = {}
#         langs = {}
#         cur_num = 0
#         tot = len(self.tokenizers)
#         for id in range(len(self.vocabulary)):
#             off_id = id - list(self.token_id_offset.values())[cur_num]
#             if cur_num + 1 < tot:
#                 if id >= list(self.token_id_offset.values())[cur_num + 1]:
#                     cur_num += 1
#                     off_id = id - list(self.token_id_offset.values())[cur_num]
#             offsets[id] = off_id
#             tokenizers[id] = list(self.tokenizers.values())[cur_num]
#             langs[id] = list(self.tokenizers.keys())[cur_num]

#         return offsets, tokenizers, langs

#     @property
#     def vocab_size(self):
#         return len(self.vocabulary)
    
#     @property
#     def unk_id(self):
#         raise NotImplementedError('unk_id is not provided for {} '
#                                   'tokenizer'.format(self.name))
    
#     @property
#     def blank_id(self):
#         if '<blk>' in self.vocab.keys():
#             return self.vocabulary['<blk>']
#         else:
#             self.add_tokens(['<blk>'])
#             logging.warning(f'<blk> not included in the current spm. Now add <blk> to the end.')
#             return self.vocabulary['<blk>']
    
#     def encode(self, text, lang_id=None):
#         raise NotImplementedError('encode is not provided for {} '
#                                   'tokenizer'.format(self.name))
#         tokenizer = self.tokenizers[lang_id]
#         token_ids = tokenizer.text_to_ids(text)
#         token_ids[:] = [t + self.token_id_offset[lang_id] for t in token_ids]

#         return token_ids
    
#     def decode(self, tokens):
#         if isinstance(ids, np.ndarray):
#             ids = ids.tolist()

#         if not isinstance(ids[0], list):
#             ids = [ids]
#         texts = []
#         for id_seq in ids: 
#             tokens = []
#             for id in id_seq:
#                 offset_id = self.offset_token_ids_by_token_id[id]
#                 tokenizer = self.tokenizers_by_token_id[id]
#                 tokens.extend(tokenizer.id_to_piece(offset_id))
#             text = ''.join(tokens).replace('▁', ' ')
#             texts.append(text)

#         return texts

# if __name__ == '__main__':
#     logging.basicConfig(
#     level=logging.INFO,        # Show info and above
#     format='[%(levelname)s] %(message)s')
#     from auden.auto.auto_tokenizer import AutoTokenizer
#     # tokenizer = AsrSpmTokenizer('/apdcephfs_cq10/share_1603164/user/yiwenyshao/independent/auden/egs/asr/data/lang_owsm')
#     # tokenizer = AsrTiktokenTokenizer('r50k_base')
#     tokenizer = AutoTokenizer.from_pretrained('asr-spm-bbpe', '/apdcephfs_cq10/share_1603164/user/yiwenyshao/independent/auden/egs/asr/data/lang_bbpe_2000')
#     # tokenizer = AutoTokenizer.from_pretrained('asr-spm', '/apdcephfs_cq10/share_1603164/user/yiwenyshao/independent/auden/egs/asr/data/lang_owsm')
#     # tokenizer = AutoTokenizer.from_pretrained('asr-tiktoken', 'r50k_base')
#     ids = tokenizer.encode(['你好世界, hello world'])
#     print(ids)
#     # print([tokenizer._tokenizer.id_to_piece(id) for id in ids])