# import re
from itertools import combinations
import numpy as np
import torch
from checkrules import CheckRules


def check_attr(attr, matched):
    """Return True when attribute vector ``attr`` agrees with keyword matches.

    ``attr`` is the 8-entry attribute vector (no_damage, attitude, surrender,
    again, young, forgive, indoor, theft — the order of
    ``SentenceAbduction.label2id``); ``matched`` is the keyword-match vector
    built by ``SentenceAbduction.get_matching_re`` (one entry per keyword
    group in ``facter_strs``).  Each rule below describes a combination that
    is considered inconsistent; the rules are checked in order and the first
    one that fires makes the function return False.
    """
    a, m = attr, matched
    # Lambdas keep evaluation lazy and ordered, exactly like an if-chain.
    violations = (
        # damage keyword (without a "责令" order) found, but no_damage unset
        lambda: m[0] == 1 and m[11] == 0 and a[0] == 0,
        # no_damage set without any damage keyword
        lambda: m[0] == 0 and a[0] == 1,
        # attitude must mirror its keyword match exactly
        lambda: m[1] != a[1],
        # surrender set without keyword / keyword (not negated) but unset
        lambda: m[2] == 0 and a[2] == 1,
        lambda: m[2] == 1 and m[10] == 0 and a[2] == 0,
        # again: keyword (not negated) but unset / set without keyword
        lambda: m[3] == 1 and m[10] == 0 and a[3] == 0,
        lambda: m[3] == 0 and a[3] == 1,
        # young: keyword plus closed-trial keyword but unset / set without keyword
        lambda: m[4] == 1 and m[12] == 1 and a[4] == 0,
        lambda: m[4] == 0 and a[4] == 1,
        # forgive must mirror its keyword match exactly
        lambda: m[5] != a[5],
        # room / theft keyword found but the attribute is unset
        lambda: m[7] == 1 and a[6] == 0,
        lambda: m[8] == 1 and a[7] == 0,
        # leniency keyword found but none of these mitigating attributes set
        lambda: m[9] == 1 and a[0] == 0 and a[1] == 0 and a[2] == 0 and a[4] == 0 and a[5] == 0,
    )
    return not any(rule() for rule in violations)


def attr_convert(attr, pos):
    """Return a copy of ``attr`` with the entry at every index in ``pos`` toggled.

    Toggling maps ``v`` to ``1 - v``; ``attr`` itself is left untouched.
    An index that appears twice in ``pos`` is toggled twice.
    """
    toggled = attr.copy()
    for index in pos:
        current = toggled[index]
        toggled[index] = 1 - current
    return toggled

class SentenceAbduction:
    """Abduce corrected attribute vectors from the observed sentence length.

    An attribute vector has 8 binary entries ordered as in ``label2id``:
    no_damage, attitude, surrender, again, young, forgive, indoor, theft.
    Given the money involved, a (possibly wrong) predicted attribute vector
    and the observed sentence in months, the class enumerates vectors that
    differ in up to ``max_change_num`` entries, keeps those whose model
    prediction passes :meth:`validate` (and, with ``word_match``,
    ``check_attr``), and samples one of them, favouring candidates whose
    predicted month is close to the observed month.
    """

    def __init__(self, model, rule_file_path, word_match=False, strong_conf=False):
        """Create the abducer.

        Args:
            model: predictor exposing ``predict(X, attrs)``; see
                :meth:`predict_and_validate`.
            rule_file_path: passed to ``CheckRules``; its ``judge(attrs)``
                score gates :meth:`validate`.
            word_match: if True, candidates must also pass ``check_attr``
                against keyword matches in the original sentence.
            strong_conf: if True, :meth:`abduce` reverts flips that contradict
                a confident prediction (``remove_thresh`` / ``add_thresh``).
        """
        # self.label2id ={"no_damage": 0, "attitude": 1, "surrender": 2, "again": 3, "young": 4, "forgive": 5, "tool": 6, "indoor": 7, "theft": 8}
        self.label2id = {"no_damage": 0, "attitude": 1, "surrender": 2, "again": 3, "young": 4, "forgive": 5, "indoor": 6, "theft": 7}
        self.id2label = {value: key for key, value in self.label2id.items()}
        # Money thresholds used by get_penalty_type (units not visible here;
        # presumably yuan — TODO confirm).
        self.LARGE = 1000
        self.HUGE = 30000
        self.EXTRA_HUGE = 300000

        self.check = CheckRules(rule_file_path)
        self.model = model
        self.word_match = word_match
        self.strong_conf = strong_conf
        # Prediction-probability thresholds used by abduce() under strong_conf.
        self.remove_thresh = 0.9
        self.add_thresh = 0.1

        facter_damage_str = ["追回", "退还", "没有给被害人造成损失", "赔偿", "发还", "退缴", "退赔", "追还", "退赃", "返还", "未给被害人造成经济损失", "归还"]
        facter_attitude_str = ["如实供述", "主动交代", "认罪", "悔罪", "坦白", "如实交待"]
        facter_surrender_str = ["自首"]
        facter_again_str = ["因犯", "曾因", "累犯", "前科"]
        facter_young_str = ["未成年", "未满十八周岁", '未满18周岁', '不满十八周岁', '不满18周岁']
        facter_forgive_str = ["谅解", "原谅"]
        # facter_tool_str = ["作案工具", "盗窃用的工具", "盗窃用工具"]
        facter_tool_str = []  # tool keywords currently disabled
        facter_room_str = ["入室", "入户", "家中", "卧室"]
        # facter_room_str = ["入室", "入户", "病房",r"入[^，。]+家中[^被搜]",r"入[^，。]*屋内",r"入[^，。]*住宅[^罪]",r"(入|在)[^，。]*住房",r"(入|在)[^，。]*宿舍",r"房间.{0,20}睡",r"屋内.{0,20}睡",r"宾馆.{0,20}睡",]
        facter_theft_str = ["扒窃", "扒取", "扒走", "上衣", "口袋", "衣兜"]
        facter_less_str = ["从轻处罚", "减轻处罚"]
        facter_neg_str = ["不予采纳", "不具有", "不符", "不构成", "不属", "不认定", "不予认定"]
        self.order_facter_damage_str = ["责令"]
        facter_closed_str = ["不公开开庭"]
        # The position of each group here is the index of its entry in the
        # vector returned by get_matching_re(); check_attr relies on this order.
        self.facter_strs = [facter_damage_str, facter_attitude_str, facter_surrender_str, facter_again_str, facter_young_str, facter_forgive_str, facter_tool_str, facter_room_str, facter_theft_str, facter_less_str, facter_neg_str, self.order_facter_damage_str, facter_closed_str]

        # Quoted statute text: room/theft keywords found next to it are not case facts.
        self.not_facter_room_theft_str = "多次盗窃、入户盗窃、携带凶器盗窃、扒窃的"
        self.not_facter_surrender_str1 = "犯罪以后自动投案，如实供述自己的罪行的"
        self.not_facter_surrender_str2 = "对于自首的犯罪分子，可以从轻或者减轻处罚"

    def set_predict_model(self, model):
        """Replace the model used by :meth:`predict_and_validate`."""
        self.model = model

    def get_matching_re(self, context):
        """Return a 0/1 vector flagging which keyword groups occur in ``context``.

        Entry ``i`` corresponds to ``self.facter_strs[i]``. Only the first
        occurrence of each keyword is inspected, and every keyword that is
        found (re)assigns the entry, so within a group the last found keyword
        decides the value. A room/theft keyword (groups 7 and 8) yields 0 when
        it lies within 20 characters of the quoted statute text; a damage
        keyword (group 0) yields 0 when immediately preceded by '未'.

        Returns an all-zero vector (after printing a warning) when
        ``context`` is None.
        """
        vec = [0] * len(self.facter_strs)
        if context is None:
            print("Context can not be found!")
            return vec

        # Independent of the keyword, so look it up once.
        statute_loc = context.find(self.not_facter_room_theft_str)
        for i, facter_str in enumerate(self.facter_strs):
            for facter in facter_str:
                loc = context.find(facter)
                if loc == -1:
                    continue
                if i == 7 or i == 8:
                    # A missing statute (find() == -1) must not count as "near";
                    # otherwise every match in the first 20 chars is discarded.
                    near_statute = statute_loc != -1 and abs(statute_loc - loc) <= 20
                    vec[i] = 0 if near_statute else 1
                elif i == 0:
                    # At loc == 0 there is no preceding char; context[-1] would
                    # wrongly inspect the last character of the text.
                    negated = loc > 0 and context[loc - 1] == '未'
                    vec[i] = 0 if negated else 1
                else:
                    vec[i] = 1
        return vec

    def get_penalty_type(self, money, attrs):
        """Classify ``money``/``attrs`` into a penalty bracket.

        Returns 0, 1 or 2 (the month ranges each bracket accepts are checked
        in :meth:`validate`), or -1 when the amount is below ``LARGE`` and
        neither indoor/theft nor (money >= 500 and again) applies.
        ``attrs`` must have exactly 8 entries (the unpacking enforces it).
        """
        [no_damage, attitude, surrender, again, young, forgive, room, theft] = attrs
        if money < self.LARGE:
            if room == 1 or theft == 1:
                return 0
            if money >= 500 and again == 1:
                return 0
        elif money < self.HUGE:
            return 0
        elif money < self.EXTRA_HUGE:
            return 1
        else:
            return 2
        return -1

    def validate(self, money, attrs, month):
        """Return True if ``month`` is an acceptable sentence for ``money``/``attrs``.

        Rejects vectors whose ``self.check.judge`` score is below 1e-6 and
        cases for which :meth:`get_penalty_type` returns -1. Otherwise the
        month must fit the bracket: bracket 0 up to 36 (36-120 with indoor and
        money >= 15000); bracket 1 from 36 to 120 (up to 36 with surrender,
        from 120 with indoor and money >= 150000); bracket 2 from 120
        (36-120 with surrender or young).
        """
        [no_damage, attitude, surrender, again, young, forgive, room, theft] = attrs
        prob = self.check.judge(attrs)
        if prob < 1e-6:
            return False

        penalty_type = self.get_penalty_type(money, attrs)
        if penalty_type == -1:
            return False
        if penalty_type == 0:
            if month <= 3 * 12:
                return True
            elif month >= 3 * 12 and month <= 10 * 12 and room == 1 and money >= 15000:
                return True
            else:
                return False
        if penalty_type == 1:
            if month >= 3 * 12 and month <= 10 * 12:
                return True
            if month <= 3 * 12 and surrender == 1:
                return True
            if month >= 10 * 12 and room == 1 and money >= 150000:
                return True
        if penalty_type == 2:
            if month >= 10 * 12:
                return True
            if month <= 10 * 12 and month >= 3 * 12 and (surrender == 1 or young == 1):
                return True
        return False

    def predict_and_validate(self, X, attrs, target_month):
        """Predict months for ``X``/``attrs``; mark invalid predictions with -1.

        Each row of ``X`` holds a single value (the money). A prediction is
        replaced by -1 when :meth:`validate` rejects the row's money,
        attributes and *target* month. The model's output is modified in place
        and returned.
        """
        Y = self.model.predict(X, attrs)
        for i in range(len(X)):
            [money] = X[i]
            if not self.validate(money, attrs[i], target_month[i]):
                Y[i] = -1
        return Y

    def select_abduced_result(self, attrs, months, target_month):
        """Sample one candidate, preferring months close to ``target_month``.

        Candidates whose month is negative are discarded. The rest are drawn
        with probability ``softmax(-|month - target_month|)``, so the
        candidate with the closest predicted month is the most likely pick.

        Returns the chosen attribute vector, or None when no valid candidate
        remains.
        """
        candidates = [(abs(month - target_month), attr)
                      for attr, month in zip(attrs, months) if month >= 0]
        if not candidates:
            return None
        candidates.sort()
        # Float dtype: softmax is not implemented for integer tensors.
        distances = torch.tensor([dist for dist, _ in candidates], dtype=torch.float64)
        # Negate so that a smaller distance gets a larger weight.
        # TODO softmax temperature
        select_prob = torch.softmax(-distances, -1).numpy()
        select_idx = int(np.random.choice(len(candidates), p=select_prob))
        return candidates[select_idx][1]

    def abduce_npos(self, money, attr, target_month, gt, n, match_res, debug):
        """Return all vectors at Hamming distance ``n`` from ``attr`` that pass the checks.

        A flipped vector is kept when it passes ``check_attr`` against
        ``match_res`` (only if ``word_match`` is set) and its prediction is not
        rejected by :meth:`predict_and_validate`. With ``debug``, prints the
        reason when the ground-truth vector ``gt`` itself is rejected.

        Returns ``(candidates, abduced_months)`` as parallel lists.
        """
        candidates, abduced_months = [], []
        for pos in combinations(range(len(attr)), n):
            new_attr = attr_convert(attr, pos)
            if self.word_match and not check_attr(new_attr, match_res):
                if debug and gt == new_attr:
                    print("check attr rej", new_attr)
                continue
            [predicted_month] = self.predict_and_validate([money], [new_attr], [target_month])
            if predicted_month == -1:
                if debug and gt == new_attr:
                    print("valid attr rej", new_attr)
                continue
            candidates.append(new_attr.copy())
            abduced_months.append(predicted_month)
        return candidates, abduced_months

    def __abduce(self, money, attr, target_month, context, gt, max_change_num, lvl, debug):
        """Search up to ``max_change_num`` flips and sample one valid vector.

        Candidates with more than ``lvl`` attributes set are dropped.
        Returns ``(abduced_attr, change)`` where ``abduced_attr`` may be None
        and ``change`` is True only when a vector different from ``attr`` was
        selected.
        """
        abduced_attrs, abduced_months = [], []
        if self.word_match:
            match_res = self.get_matching_re(context)
            if debug:
                print("match", match_res)
        else:
            match_res = None
        for change_num in range(max_change_num + 1):
            abd_attrs, abd_months = self.abduce_npos(money, attr, target_month, gt, change_num, match_res, debug)
            for a, m in zip(abd_attrs, abd_months):
                assert m != -1
                if np.sum(a) > lvl:  # Level abduction
                    continue
                abduced_attrs.append(a)
                abduced_months.append(m)
        abduced_attr = self.select_abduced_result(abduced_attrs, abduced_months, target_month)
        change = not (abduced_attr is None or attr == abduced_attr)
        return abduced_attr, change

    def abduce(self, money, attr, month, prob, context, gt, max_change_num, lvl, debug=False):
        """Abduce a corrected attribute vector for one example.

        ``prob`` holds the model's per-attribute probabilities (same length as
        ``attr``). With ``strong_conf``, a removal is undone when the old
        probability exceeds ``remove_thresh`` and an addition is undone when it
        is below ``add_thresh``.

        Returns the abduced attribute vector, or None if none was found.
        """
        assert len(attr) == len(prob)
        abduced_attr, change = self.__abduce(money, attr, month, context, gt, max_change_num, lvl, debug)

        if self.strong_conf and change:
            assert len(abduced_attr) == len(attr)
            for i, old_a in enumerate(attr):
                old_prob = prob[i]
                new_a = abduced_attr[i]
                if old_a == new_a:
                    continue
                if new_a == 0:  # remove
                    if old_prob > self.remove_thresh:  # old prob too large
                        abduced_attr[i] = 1  # do not remove
                elif new_a == 1:  # add
                    if old_prob < self.add_thresh:  # old prob too small
                        abduced_attr[i] = 0  # do not add
                else:
                    assert False
        return abduced_attr

    def abduce_batch(self, ahs, moneys, attrs, months, probs, origin_contexts, gts, max_change_num, lvl):
        """Run :meth:`abduce` over parallel lists of examples.

        When ``word_match`` is on and no vector is found, the example is
        retried once without the keyword constraint (with debug output).
        Examples that still yield nothing are skipped.

        Returns ``(abduced_ahs, abduced_attrs)`` as parallel lists.
        """
        abduced_ahs, abduced_attrs = [], []
        for ah, money, attr, month, prob, context, gt in zip(ahs, moneys, attrs, months, probs, origin_contexts, gts):
            abduced_attr = self.abduce(money, attr, month, prob, context, gt, max_change_num, lvl)
            if abduced_attr is None and self.word_match:
                self.word_match = False
                try:
                    abduced_attr = self.abduce(money, attr, month, prob, context, gt, max_change_num, lvl, True)
                finally:
                    # Restore the flag even if the retry raises.
                    self.word_match = True
                if abduced_attr is None:
                    continue
            abduced_ahs.append(ah)
            abduced_attrs.append(abduced_attr)
        return abduced_ahs, abduced_attrs

    def abduce_data(self, data, train_ahs, max_change_num, lvl):
        """Abduce attributes for the datapoints of ``data`` whose "ah" is in ``train_ahs``.

        Each datapoint is a dict read via the keys "ah", "money", "pred_attr",
        "month", "pred_prob", "sentence" and "attr". Datapoints that get a
        result have it stored under "abduced_attr" (in place).

        Returns the list of "ah" values that were abduced.
        """
        train_set = set(train_ahs)
        ahs, moneys, attrs, months, probs, contexts, gts = [], [], [], [], [], [], []
        for datapoint in data:
            if datapoint["ah"] not in train_set:
                continue
            ahs.append(datapoint["ah"])
            moneys.append([datapoint["money"]])
            attrs.append(datapoint["pred_attr"])
            months.append(datapoint["month"])
            probs.append(datapoint["pred_prob"])
            contexts.append(datapoint["sentence"])  # original sentence
            gts.append(datapoint["attr"])
        abduced_ahs, abduced_attrs = self.abduce_batch(ahs, moneys, attrs, months, probs, contexts, gts, max_change_num, lvl)

        # Dict lookup instead of scanning the abduced_ahs list per datapoint.
        abduced = dict(zip(abduced_ahs, abduced_attrs))
        for datapoint in data:
            if datapoint["ah"] in abduced:
                datapoint["abduced_attr"] = abduced[datapoint["ah"]]
        return abduced_ahs

