from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os
# Record describing one classified edit, as built by Classifier.__call__:
#   op   -- error-type label (e.g. "W", "R:NOUN", "M:OTHER", "S:SPELL")
#   toks -- target-side text of the edit, or "-NONE-" for deletions
#   inds -- (start, end) token indices of the edit in the source sentence
Correction = namedtuple("Correction", ("op", "toks", "inds"))
# Root directory of the cached evaluation data. When COMPASS_DATA_CACHE is
# unset the empty default makes the asset path below a relative path.
cache_dir = os.getenv('COMPASS_DATA_CACHE', '')
# Module-wide character helper built from char_meta.txt at import time;
# check_spell_error queries it for shape and pronunciation similarity.
char_smi = CharFuncs(
    os.path.join(cache_dir, "data", "lawbench", "eval_assets", "char_meta.txt")
)

def check_spell_error(src_span: str,
                      tgt_span: str,
                      threshold: float = 0.8) -> bool:
    """Decide whether replacing ``src_span`` with ``tgt_span`` is a spelling error.

    The two spans must have the same length. They count as a spelling error
    when they contain the same characters in a different order, or when every
    differing character pair is "close": either the sum of the shape and
    pronunciation similarity reported by ``char_smi`` reaches ``threshold``,
    or the two characters share at least one tone-less pinyin reading.

    :param src_span: erroneous text, with token separators already removed
    :param tgt_span: corrected text, with token separators already removed
    :param threshold: minimum combined shape + pronunciation similarity
    :return: True if the substitution is considered a spelling error
    """
    if len(src_span) != len(tgt_span):
        return False

    # Same multiset of characters: characters were transposed inside the word.
    if sorted(src_span) == sorted(tgt_span):
        return True

    for wrong, right in zip(src_span, tgt_span):
        if wrong == right:
            continue
        # A character missing from the similarity table cannot be compared.
        if not (wrong in char_smi.data and right in char_smi.data):
            return False
        shape = char_smi.shape_similarity(wrong, right)
        sound = char_smi.pronunciation_similarity(wrong, right)
        if shape + sound < threshold:
            # Below the threshold: fall back to comparing all known readings
            # (heteronyms included) before rejecting the pair.
            wrong_readings = set(pinyin(wrong, style=Style.NORMAL, heteronym=True)[0])
            right_readings = set(pinyin(right, style=Style.NORMAL, heteronym=True)[0])
            if not wrong_readings & right_readings:
                return False
    return True

class Classifier:
    """
    Error-type classifier for edit operations.

    Given a source (erroneous) sentence, a target (corrected) sentence and the
    edits between them, assigns each edit an error-type label and wraps it in
    a ``Correction``. At ``"word"`` granularity labels are refined with the
    part of speech of the affected token; at any other granularity only the
    bare operation label is produced.
    """

    # Placeholder token meaning "missing component"; an edit whose target is
    # exactly this string is labelled MC instead of a part of speech.
    _MISSING_COMPONENT = "[缺失成分]"

    # Fine-grained POS tag -> coarse category. Tags not listed map to "OTHER".
    # NOTE(review): the tag names look like the LTP tag set - confirm.
    # TODO: the suffix tag "k" is too rare to get its own "SUFFIX" category,
    #       so for now it falls into "OTHER".
    _POS_TYPES = {
        "n": "NOUN",
        "nd": "NOUN",
        "nh": "NOUN-NE",
        "ni": "NOUN-NE",
        "nl": "NOUN-NE",
        "ns": "NOUN-NE",
        "nt": "NOUN-NE",
        "nz": "NOUN-NE",
        "v": "VERB",
        "a": "ADJ",
        "b": "ADJ",
        "c": "CONJ",
        "r": "PRON",
        "d": "ADV",
        "u": "AUX",
        "m": "NUM",
        "p": "PREP",
        "q": "QUAN",
        "wp": "PUNCT",
    }

    def __init__(self,
                 granularity: str = "word"):
        """
        :param granularity: ``"word"`` enables POS-refined labels; any other
            value yields operation-only labels
        """
        self.granularity = granularity

    @staticmethod
    def get_pos_type(pos):
        """
        Map a fine-grained POS tag to its coarse category.

        :param pos: POS tag of a token
        :return: coarse category name, ``"OTHER"`` for unlisted tags
        """
        return Classifier._POS_TYPES.get(pos, "OTHER")

    def _word_level_op(self, prefix, sent, start, end, tgt_span):
        """
        Build the word-level label ``"<prefix>:<category>"`` for one edit.

        :param prefix: operation prefix (``"R"``, ``"M"`` or ``"S"``)
        :param sent: sentence info whose items carry the POS tag at index 1
        :param start: start index of the edited span in ``sent``
        :param end: end index (exclusive) of the edited span in ``sent``
        :param tgt_span: target-side text of the edit
        :return: the refined operation label
        """
        if end - start > 1:
            # Multi-token spans are not refined further for now.
            return "{:s}:OTHER".format(prefix)
        pos = self.get_pos_type(sent[start][1])
        if tgt_span == self._MISSING_COMPONENT:
            pos = "MC"
        elif pos == "NOUN-NE":
            # Named entities are not distinguished from common nouns.
            pos = "NOUN"
        return "{:s}:{:s}".format(prefix, pos)

    def __call__(self,
                 src,
                 tgt,
                 edits,
                 verbose: bool = False):
        """
        Assign an error type to every edit operation.

        :param src: erroneous sentence info; each item holds the token at
            index 0 and its POS tag at index 1
        :param tgt: corrected sentence info, same layout as ``src``
        :param edits: edit operations, each indexable as
            ``(op, src_start, src_end, tgt_start, tgt_end)``; only the first
            character of ``op`` is inspected (T, D, I or S)
        :param verbose: whether to print the resulting corrections
        :return: list with one entry per edit: a ``Correction``, or ``None``
            when the edit's operation is not one of T/D/I/S
        """
        results = []
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        word_level = self.granularity == "word"
        for edit in edits:
            op = edit[0][0]
            src_start, src_end = edit[1], edit[2]
            tgt_start, tgt_end = edit[3], edit[4]
            src_span = " ".join(src_tokens[src_start: src_end])
            tgt_span = " ".join(tgt_tokens[tgt_start: tgt_end])
            inds = (src_start, src_end)
            cor = None
            if op == "T":
                # Transposition: word-order error.
                cor = Correction("W", tgt_span, inds)
            elif op == "D":
                # Deletion: the source span is redundant.
                label = "R"
                if word_level:
                    label = self._word_level_op("R", src, src_start, src_end, tgt_span)
                cor = Correction(label, "-NONE-", inds)
            elif op == "I":
                # Insertion: the target span is missing from the source.
                label = "M"
                if word_level:
                    label = self._word_level_op("M", tgt, tgt_start, tgt_end, tgt_span)
                cor = Correction(label, tgt_span, inds)
            elif op == "S":
                # Substitution: spelling error if the spans are close enough,
                # otherwise refined by the POS of the target token.
                if not word_level:
                    label = "S"
                elif check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
                    # TODO: named-entity spelling errors are not yet separated
                    #       from common-word ones (S:SPELL:NE / S:SPELL:COMMON).
                    label = "S:SPELL"
                else:
                    label = self._word_level_op("S", tgt, tgt_start, tgt_end, tgt_span)
                cor = Correction(label, tgt_span, inds)
            results.append(cor)
        if verbose:
            print("========== Corrections ==========")
            for cor in results:
                if cor is None:
                    # Unrecognised operation: nothing to report for this edit.
                    continue
                print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
        return results

# print(pinyin("朝", style=Style.NORMAL))
