import json

class Judger:
    # Initialize Judger, with the path of tag list
    def __init__(self, tag_path):
        self.tag_dic = {}
        f = open(tag_path, "r", encoding='utf-8')
        self.task_cnt = 0
        for line in f:
            # print(line)
            self.task_cnt += 1
            self.tag_dic[line[:-1]] = self.task_cnt
        # print(self.tag_dic)

    # Format the result generated by the Predictor class
    @staticmethod
    def format_result(result):
        rex = {"tags": []}
        res_art = []
        for x in result["tags"]:
            if not (x is None):
                res_art.append(int(x))
        rex["tags"] = res_art

        return rex

    # Gen new results according to the truth and users output
    def gen_new_result(self, result, truth, label):
        s1 = set()
        for tag in label:
            s1.add(self.tag_dic.setdefault(tag.replace(' ', ''), None))
        s2 = set()
        for name in truth:
            s2.add(self.tag_dic.setdefault(name.replace(' ', ''), None))

        for a in range(0, self.task_cnt):
            in1 = (a + 1) in s1
            in2 = (a + 1) in s2
            if in1:
                if in2:
                    result[0][a]["TP"] += 1
                else:
                    result[0][a]["FP"] += 1
            else:
                if in2:
                    result[0][a]["FN"] += 1
                else:
                    result[0][a]["TN"] += 1

        return result

    # Calculate precision, recall and f1 value
    # According to https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
    @staticmethod
    def get_value(res):
        if res["TP"] == 0:
            if res["FP"] == 0 and res["FN"] == 0:
                precision = 1.0
                recall = 1.0
                f1 = 1.0
            else:
                precision = 0.0
                recall = 0.0
                f1 = 0.0
        else:
            precision = 1.0 * res["TP"] / (res["TP"] + res["FP"])
            recall = 1.0 * res["TP"] / (res["TP"] + res["FN"])
            f1 = 2 * precision * recall / (precision + recall)

        return precision, recall, f1

    # Generate score
    def gen_score(self, arr):
        sumf = 0
        y = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
        for x in arr[0]:
            p, r, f = self.get_value(x)
            sumf += f
            for z in x.keys():
                y[z] += x[z]

        _, __, f_ = self.get_value(y)
        macro_f = sumf * 1.0 / len(arr[0])
        micro_f = f_

        return {"macro" : macro_f, "micro" : micro_f}

    # Test with ground truth path and the user's output path
    def test(self, truth_path, output_path):
        cnt = 0
        result = [[]]
        for a in range(0, self.task_cnt):
            result[0].append({"TP": 0, "FP": 0, "TN": 0, "FN": 0})

        # with open(truth_path, "r", encoding='utf-8') as inf, open(output_path, "r", encoding='utf-8') as ouf:
        ground_doc_dict = {}
        with open(truth_path, "r", encoding='utf-8') as inf:
            for line in inf:
                ground_doc = json.loads(line)
                ah = ground_doc[0]['ah']
                ground_doc_dict[ah] = ground_doc
                
        with open(output_path, "r", encoding='utf-8') as inf:
            for line in inf:
                user_doc = json.loads(line)
                ah = user_doc[0]['ah']
                if ah in ground_doc_dict:
                    ground_doc = ground_doc_dict[ah]
                else:
                    print("WARNING: ah", ah, "is not in ground truth file")
                    continue
                for ind in range(len(ground_doc)):
                    ground_truth = ground_doc[ind]['label']
                    try:
                        user_output = user_doc[ind]['label']
                    except:
                        print(user_doc[ind])
                    cnt += 1
                    result = self.gen_new_result(result, ground_truth, user_output)

        return result

# Generate the final score report
def get_score(truth_path_labor, output_path_labor, tag_path_labor):
    """Evaluate a labor-tag prediction file and build a score report.

    Args:
        truth_path_labor: Ground-truth JSON Lines file.
        output_path_labor: Prediction JSON Lines file.
        tag_path_labor: Tag list file, one tag per line.

    Returns:
        ``(scores, report)``. ``scores`` is a list of ``(name, value)`` pairs:
        ``"<tag>_f1"`` (per-tag F1) for each tag, followed by
        ``"total_macro_f1"`` and ``"total_micro_f1"``. ``report`` is the same
        information formatted one entry per line.
    """
    judger_labor = Judger(tag_path=tag_path_labor)
    result_labor = judger_labor.test(truth_path=truth_path_labor,
                                     output_path=output_path_labor)

    # Map tag id -> tag name explicitly instead of relying on dict insertion
    # order lining up with ids (it does not if the tag file has duplicates or
    # if unknown labels were added to tag_dic with a None id).
    id_to_tag = {tag_id: name for name, tag_id in judger_labor.tag_dic.items()
                 if tag_id is not None}

    ret = []
    lines = []

    def add(name, value):
        # Record one score both as a (name, value) pair and as a report line.
        lines.append("%14s : %10f" % (name, value))
        ret.append((name, value))

    for idx, counts in enumerate(result_labor[0]):
        f1 = judger_labor.gen_score([[counts]])
        # Ids are 1-based; fall back to the numeric id if no name is known.
        add("%s_f1" % id_to_tag.get(idx + 1, idx + 1), f1["macro"])

    total_f1 = judger_labor.gen_score(result_labor)
    for flag in ["macro", "micro"]:
        add("total_%s_f1" % flag, total_f1[flag])

    return ret, "\n".join(lines)

if __name__ == '__main__':
    # Score the bundled sample files; print the raw (scores, report) tuple,
    # then the human-readable report on its own.
    scores, report = get_score(
        truth_path_labor='test_data/10_new.json',
        output_path_labor='test_data/abl_predict_0.json',
        tag_path_labor='test_data/tags_for_test.txt',
    )
    print((scores, report))
    print(report)
