'''
Author: Jiashu Li
Date: 2022-06-15 15:49:32
LastEditTime: 2022-08-27 18:29:41
LastEditors: Jiashu Li
Description: 
FilePath: /asr_lm/utils/text.py
'''
import os
import sys
import re
import random
import sentencepiece as spm
from typing import List, Dict, Optional, Tuple, Any

sys.path.append(os.getcwd())

from utils.common import add_cn_en_symbol, has_tag


def tokenize(text: str,
             token_type: str='word_eng',
             bpe_model: Optional[str]=None,
             add_cs_symbol: bool=False,
             to_lower: bool=False) -> Tuple[List, List, List]:
    """
    Description:
        Convert one line of mixed Chinese/English text to token lists,
        for text filtering, text processing, etc.
    Args:
        text: one line of text
        token_type: 'word_eng' keeps every space-separated non-Chinese word
            as one token; 'bpe' splits every such word into SentencePiece
            pieces
        bpe_model: path of the SentencePiece model file; required when
            token_type is 'bpe', ignored otherwise
        add_cs_symbol: when True, `tokens` is passed through
            add_cn_en_symbol before being returned
            (NOTE(review): helper lives in utils.common; presumably it adds
            Chinese/English switch symbols - confirm there)
        to_lower: forwarded to add_cn_en_symbol; unused when add_cs_symbol
            is False
    Return:
        tokens: Chinese characters and English tokens, in text order
        cn_tokens: only the Chinese characters (range U+4E00 - U+9FFF)
        en_tokens: only the non-Chinese tokens (words or BPE pieces)
    Raises:
        ValueError: if token_type is unsupported, or if token_type is 'bpe'
            and no bpe_model is given
    """
    # Explicit validation instead of `assert`, which is stripped under -O.
    if token_type not in ('word_eng', 'bpe'):
        raise ValueError(
            "unsupported token_type %r, expected 'word_eng' or 'bpe'" % (token_type,))

    sp = None
    if token_type == 'bpe':
        if bpe_model is None:
            raise ValueError("bpe_model is required when token_type is 'bpe'")
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    # The capture group makes split() keep every Chinese character as its own
    # element, interleaved with the non-Chinese text between them.
    pattern_cn = re.compile("([\u4e00-\u9fff])")
    tokens = []
    cn_tokens = []
    en_tokens = []
    for segment in pattern_cn.split(text):
        if not segment.strip():
            continue
        if pattern_cn.fullmatch(segment) is not None:
            tokens.append(segment)
            cn_tokens.append(segment)
            continue
        for word in segment.strip().split(" "):
            if not word:
                # consecutive spaces produce empty strings
                continue
            pieces = [word] if sp is None else sp.encode_as_pieces(word)
            tokens.extend(pieces)
            en_tokens.extend(pieces)

    if add_cs_symbol:
        tokens = add_cn_en_symbol(tokens, to_lower=to_lower)

    return tokens, cn_tokens, en_tokens


def sentence_generalization(txt: str,
                            sub_dict: Dict,
                            remove_dup: bool=False,
                            sub_num: Optional[int]=None) -> List[str]:
    """
    Sentence-pattern generalization: replace words of `txt` with other words
    of the same class and return every resulting sentence.

    Args:
        txt: input text
        sub_dict: substitution dictionary, key = class name,
            value = list of strings belonging to that class.
            It is never modified by this function.
        remove_dup: when True, drop results in which the same replacement
            string is used for more than one slot, e.g. for
                从南京寄快递到南通邮费一般多少
            the result
                从南京寄快递到南京邮费一般多少
            is dropped.
        sub_num: when truthy, only `sub_num` randomly chosen strings of a
            class are used for each slot; None or 0 uses the whole class.
    Return:
        List of generalized sentences, or `[txt]` when no string of
        `sub_dict` occurs in `txt`.

    NOTE: when a "p_direction" string (e.g. 东) and an "o_direction" string
    (e.g. 东北) both occur, only the "o_direction" one is replaced. This
    relies on "o_direction" being listed before "p_direction" in `sub_dict`.

    NOTE: only the first occurrence of each string is replaced; a string
    whose occurrence is already covered by an earlier replacement and that
    does not occur again later in `txt` is ignored.

    Example:
        txt = "到了七期g栋楼吗"
        sub_dict = {
            "char": [chr(x) for x in range(ord('a'), ord('z') + 1)],
            "o_direction": ["西北", "东北", "西南", "东南"],
            "p_direction": ["东", "南", "西", "北"],
            "number": list("二三四五六七八九十"),
            "location": ["地铁站", "商场"]
        }
        returns ['到了二期a栋楼吗', ..., '到了九期z栋楼吗', '到了十期z栋楼吗']
    """
    # 1. Locate the first occurrence of every candidate string.
    matches = []
    matched_classes = []
    for sub_cls, sub_strs in sub_dict.items():
        for sub_str in sub_strs:
            if not sub_str:
                # an empty string "occurs" everywhere and cannot be a slot
                continue
            sub_idx = txt.find(sub_str)
            if sub_idx < 0:
                continue
            if sub_cls not in matched_classes:
                matched_classes.append(sub_cls)
            # TODO: generalize this special case to arbitrary classes
            if "o_direction" in matched_classes and sub_cls == "p_direction":
                continue
            matches.append((sub_cls, sub_idx, sub_str))

    if not matches:
        return [txt]

    # 2. Cut txt into the fixed pieces around the slots, left to right.
    #    At equal positions the longer string goes first so that it wins.
    matches.sort(key=lambda m: (m[1], -len(m[2])))
    pieces = []
    slot_classes = []
    cursor = 0
    for sub_cls, _, sub_str in matches:
        # Search only behind the previous slot and cut at the first hit, so a
        # repeated string cannot produce extra pieces.
        start = txt.find(sub_str, cursor)
        if start < 0:
            # swallowed by an earlier, overlapping slot
            continue
        pieces.append(txt[cursor:start])
        slot_classes.append(sub_cls)
        cursor = start + len(sub_str)
    pieces.append(txt[cursor:])

    # 3. Build every combination of replacements, one entry per slot.
    #    The last slot varies slowest, which fixes the order of the result.
    combos = [[]]
    for sub_cls in slot_classes:
        # Work on a copy: shuffling must not reorder the caller's list.
        candidates = list(sub_dict[sub_cls])
        if sub_num:
            random.shuffle(candidates)
            candidates = candidates[:sub_num]

        extended = []
        for candidate in candidates:
            for combo in combos:
                new_combo = combo + [candidate]
                if remove_dup and len(set(new_combo)) != len(new_combo):
                    continue
                extended.append(new_combo)
        combos = extended

    # 4. Interleave the fixed pieces with each combination.
    res = []
    for combo in combos:
        parts = []
        for piece, replacement in zip(pieces, combo):
            parts.append(piece)
            parts.append(replacement)
        parts.append(pieces[-1])
        res.append("".join(parts))
    return res


def convert_to_token_with_cls(txt: str,
                              word2idx: Dict,
                              word2idx_cls: Dict,
                              min_len: int,
                              max_len: int,
                              bpe_sp,
                              token_type: str='chara',
                              to_lower: bool=True,
                              cn_en_symbol: bool=True
                              ) -> Optional[Dict[str, Any]]:
    """
    Tokenize `txt` and convert it to token ids plus per-token class-tag ids.

    Args:
        txt: input text
        word2idx: token -> id dict for the next-word prediction task; must
            contain the unknown token ("<unk>" / "<UNK>" depending on
            to_lower)
        word2idx_cls: token -> id dict for the next-class prediction task,
            may be {}. Tokens that are keys of this dict are treated as
            class-tag markers and removed from the output.
        min_len: texts with fewer tokens than this return None
        max_len: texts with more tokens than this return None
            (both limits are checked before switch symbols are added)
        bpe_sp: initialized SentencePiece model, only used for 'bpe'
        token_type: type of tokenization, one of 'chara', 'word_en', 'bpe'
            (default 'chara'; 'chara' does not support class training)
        to_lower: True converts text and special tokens to lower case,
            False converts them to upper case
        cn_en_symbol: whether to pass the tokens through add_cn_en_symbol
            (NOTE(review): presumably adds Chinese/English switch symbols -
            confirm in utils.common)
    Return:
        None when the token count is outside [min_len, max_len], otherwise a
        dict with the keys
            "tokens": list of token ids
            "class_tags": list of class-tag ids (one per token), or None
                when no class tag was produced
            "use_class_training": True if any class-tag marker was seen
    Raises:
        ValueError: if token_type is not supported
    """
    if token_type not in ('chara', 'word_en', 'bpe'):
        # previously an unknown type crashed later with UnboundLocalError
        raise ValueError(
            "unsupported token_type %r, expected 'chara', 'word_en' or 'bpe'"
            % (token_type,))

    pattern_cn = re.compile("([\u4e00-\u9fff])")

    # Text, special tokens and the extra word list all share one casing.
    change_case = str.lower if to_lower else str.upper
    txt = change_case(txt)
    unk = change_case("<UNK>")
    not_class_tag = change_case("<other>")
    added_dict = {change_case(_p) for _p in extra_dict()}

    if token_type == 'chara':
        # not supported for class training
        tokens = [w for w in txt if w != ' ']
    else:
        tokens = []
        for cn_or_en in pattern_cn.split(txt):
            if not cn_or_en.strip():
                continue
            if pattern_cn.fullmatch(cn_or_en) is not None:
                tokens.append(cn_or_en)
                continue
            for w in cn_or_en.strip().split(' '):
                if not w:
                    # consecutive spaces produce empty strings
                    continue
                if token_type == 'word_en':
                    tokens.append(w)
                elif has_tag(w):
                    # class tags do not take part in BPE tokenization
                    tokens.append(w)
                elif w in added_dict:
                    # special words are kept whole, written like a BPE piece
                    tokens.append("▁"+w)
                else:
                    tokens.extend(bpe_sp.encode_as_pieces(w))

    if len(tokens) < min_len or len(tokens) > max_len:
        return None

    if cn_en_symbol:
        tokens = add_cn_en_symbol(tokens, to_lower=to_lower)

    tokens_id = []
    classes_tag_id = []
    use_class_training = False

    not_class_tag_id = word2idx_cls.get(not_class_tag)
    class_tag_id = not_class_tag_id

    cache_class_tag = -1

    # NOTE: BIO tagging scheme; class_convert_tag switches from B to I,
    # where I = B + 1.
    class_convert_tag = False

    for token in tokens:
        # Build the class-prediction labels of the multi-task setup (they
        # still have to be shifted by one position later) and filter out the
        # class-tag markers themselves.
        # example: 打电话给 <pname> pixel </pname>  --> 打电话给 pixel
        if token in word2idx_cls:
            use_class_training = True
            marker_id = word2idx_cls[token]
            if class_tag_id == marker_id or class_tag_id == cache_class_tag:
                # closing marker: back to the "other" class
                class_tag_id = not_class_tag_id
                class_convert_tag = False
            else:
                # opening marker: the next token gets the B label
                class_tag_id = marker_id
                cache_class_tag = class_tag_id
                class_convert_tag = True
            continue

        if not_class_tag_id is not None:
            classes_tag_id.append(class_tag_id)
            if class_tag_id != not_class_tag_id and class_convert_tag:
                # B was just emitted; following tokens get I = B + 1
                class_tag_id += 1
                class_convert_tag = False

        # Build the input token id. Test for a missing key explicitly so that
        # a token whose id is 0 is not mistaken for an unknown token.
        token_id = word2idx.get(token)
        if token_id is None:
            token_id = word2idx[unk]
        tokens_id.append(int(token_id))

    if not classes_tag_id:
        classes_tag_id = None

    return {
        "tokens": tokens_id,
        "class_tags": classes_tag_id,
        "use_class_training": use_class_training
    }
    

def convert_to_token(txt: str,
                     word2idx: Dict,
                     min_len: int,
                     max_len: int,
                     bpe_sp,
                     token_type: str='chara',
                     to_lower: bool=True,
                     cn_en_symbol: bool=True
                     ) -> Optional[Dict[str, Any]]:
    """
    Tokenize `txt` and convert it to a list of token ids.

    Args:
        txt: input text
        word2idx: token -> id dict for the next-word prediction task; must
            contain the unknown token ("<unk>" / "<UNK>" depending on
            to_lower)
        min_len: texts with `min_len` or fewer tokens return None
            (checked before switch symbols are added)
        max_len: texts with `max_len` or more tokens return None
            (checked after switch symbols are added)
        bpe_sp: initialized SentencePiece model, only used for 'bpe'
        token_type: type of tokenization, one of 'chara', 'word_en', 'bpe'
            (default 'chara')
        to_lower: True converts text and special tokens to lower case,
            False converts them to upper case
        cn_en_symbol: whether to pass the tokens through add_cn_en_symbol
            (NOTE(review): presumably adds Chinese/English switch symbols -
            confirm in utils.common)
    Return:
        None when the text is rejected by the length limits, otherwise a
        dict with the key
            "tokens": list of token ids
    Raises:
        ValueError: if token_type is not supported
    """
    if token_type not in ('chara', 'word_en', 'bpe'):
        # previously an unknown type crashed later with UnboundLocalError
        raise ValueError(
            "unsupported token_type %r, expected 'chara', 'word_en' or 'bpe'"
            % (token_type,))

    pattern_cn = re.compile("([\u4e00-\u9fff])")

    # Text, unknown token and the extra word list all share one casing.
    change_case = str.lower if to_lower else str.upper
    txt = change_case(txt)
    unk = change_case("<UNK>")
    added_dict = {change_case(_p) for _p in extra_dict()}

    if token_type == 'chara':
        tokens = [w for w in txt if w != ' ']
    else:
        tokens = []
        for cn_or_en in pattern_cn.split(txt):
            if not cn_or_en.strip():
                continue
            if pattern_cn.fullmatch(cn_or_en) is not None:
                tokens.append(cn_or_en)
                continue
            for w in cn_or_en.strip().split(' '):
                if not w:
                    # consecutive spaces produce empty strings
                    continue
                if token_type == 'word_en':
                    tokens.append(w)
                elif has_tag(w):
                    # filter class tags, e.g. <pname>
                    continue
                elif w in added_dict:
                    # special words are kept whole, written like a BPE piece
                    tokens.append("▁"+w)
                else:
                    tokens.extend(bpe_sp.encode_as_pieces(w))

    if len(tokens) <= min_len:
        return None

    if cn_en_symbol:
        tokens = add_cn_en_symbol(tokens, to_lower=to_lower)

    if len(tokens) >= max_len:
        return None

    tokens_id = []
    for token in tokens:
        # Test for a missing key explicitly so that a token whose id is 0 is
        # not mistaken for an unknown token.
        token_id = word2idx.get(token)
        if token_id is None:
            token_id = word2idx[unk]
        tokens_id.append(int(token_id))

    return {
        "tokens": tokens_id
    }


def extra_dict():
    """
    Return the list of special words that the 'bpe' branch of the
    converters in this file emits as one whole token (prefixed with "▁")
    instead of splitting them with the BPE model.

    A new list is built on every call. All entries are lower case.
    """
    base_words = (
        "magic", "magicbook", "plus", "earbuds", "yoyo", "nips2",
        "iphone", "ipad", "airpods",
        "mate", "matebook", "note", "oppo", "vivo", "galaxy",
        "wifi", "cheese",
    )
    more_words = (
        "answered", "applied", "barring", "busy", "calls", "code",
        "connected", "dial", "dialed", "directly", "idd", "incoming",
        "mobile", "operator", "overdue", "redial", "registrations",
        "renew", "subscriber", "suspended", "telephone", "zero",
    )
    return [*base_words, *more_words]


if __name__ == "__main__":
    # Manual check: load the BPE model and the dictionary from disk, run
    # convert_to_token on one sample sentence and print the result.
    import sys
    sys.path.append("./asr_lm")

    from utils.common import load_dict

    lower_case = False

    bpe_processor = spm.SentencePieceProcessor()
    bpe_processor.load("data/bpe_model/spm_giga_xmly_1500.model")

    # load_dict returns (idx2word, word2idx); only word2idx is needed here.
    _, vocab = load_dict(
        "data/dict/lang_cn_char_en_spm_1500_lower_extra2.txt",
        to_lower=lower_case,
    )

    sample = "我想听<music> see you again </music>"
    demo_args = dict(
        txt=sample,
        word2idx=vocab,
        min_len=0,
        max_len=150,
        bpe_sp=bpe_processor,
        token_type="bpe",
        to_lower=lower_case,
        cn_en_symbol=True,
    )

    print("res: ", convert_to_token(**demo_args))
