import os, sys, json
import regex
from pathlib import Path


def contains_chinese(s):
    r"""Report whether *s* contains at least one Han-script character.

    Relies on the third-party ``regex`` module, whose ``\p{Han}`` class matches
    characters with the Unicode Script property "Han". That script is shared
    with e.g. Japanese kanji, so this is a heuristic for "Chinese text".

    Returns:
        True if any Han character is present, otherwise False.
    """
    return regex.search(r"\p{Han}", s) is not None


def count_result(base_dir: str, model_name: str, output_dir: str):
    """Tally right/wrong evaluation outcomes for one model and write them as JSON.

    Every entry of ``<base_dir>/<model_name>`` is tried as a JSON result file.
    Entries are skipped when:

    - they cannot be opened or parsed;
    - the annotation is not ``if_can_be_labeled``;
    - ``qa.eval_result`` contains neither "true" nor "false".

    Each remaining result is counted under:

    - its ``answer_type``;
    - every difficulty (last element of each ``qa_elements_taxonomy`` entry);
    - "unanswerable" when ``question_can_be_answered`` is falsy;
    - "chinese" or "english" depending on the question text.

    Every counter in the written file has the shape ``{"right": int, "wrong": int}``.

    Args:
        base_dir: Directory containing one sub-directory per model.
        model_name: Name of the model sub-directory to read.
        output_dir: Directory that receives ``<model_name>.json``; created if missing.

    Raises:
        KeyError: If a parsed, labelable result lacks one of the expected keys.
    """
    output_path = os.path.join(output_dir, f"{model_name}.json")
    result_p = Path(os.path.join(base_dir, model_name))

    def new_counter() -> dict:
        return {"right": 0, "wrong": 0}

    # Fixed buckets are pre-seeded with BOTH counters so that every entry in the
    # output has the same {"right", "wrong"} shape, even if it never increments.
    eval_result = {
        "answer_type": {},
        "difficulties": {"unanswerable": new_counter()},
        "Language": {"chinese": new_counter(), "english": new_counter()},
    }

    def add_key_or_count(key: str, sub_eval_result: dict, right: bool) -> dict:
        """Increment the right/wrong counter for *key*, creating it on first use."""
        counters = sub_eval_result.setdefault(key, new_counter())
        counters["right" if right else "wrong"] += 1
        return sub_eval_result

    for result_path in result_p.iterdir():
        try:
            with open(result_path, "r", encoding="utf-8") as f:
                result_json = json.load(f)
        except (OSError, ValueError):
            # Best effort: sub-directories, unreadable files and invalid JSON
            # (JSONDecodeError / UnicodeDecodeError are ValueErrors) are skipped.
            continue

        annotations = result_json["annotations_result_questions"]
        if not annotations["if_can_be_labeled"]:
            continue
        question = annotations["question"]
        answer_type = annotations["answer_type"]
        difficulties = annotations["qa_elements_taxonomy"]
        answerable = annotations["question_can_be_answered"]
        if_chinese = contains_chinese(question)

        eval_text = result_json["qa"]["eval_result"]
        # "true" is checked first, so a verdict mentioning both counts as right.
        if "true" in eval_text:
            result = True
        elif "false" in eval_text:
            result = False
        else:
            continue

        add_key_or_count(answer_type, eval_result["answer_type"], result)
        for difficulty in difficulties:
            # The last element of each taxonomy path is the specific difficulty.
            add_key_or_count(difficulty[-1], eval_result["difficulties"], result)
        if not answerable:
            add_key_or_count("unanswerable", eval_result["difficulties"], result)
        lang_key = "chinese" if if_chinese else "english"
        add_key_or_count(lang_key, eval_result["Language"], result)

    # open(..., "w") does not create parent directories; "" means the cwd,
    # for which os.makedirs would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(eval_result, f, ensure_ascii=False, indent=2)


def sort_errors(base_dir: str) -> dict:
    """Rank error categories by how many failed results fall into each.

    Scans every sub-directory of ``base_dir`` except "user" and "cot" (and
    except non-directories), and within it every ``*.json`` file. A result is
    treated as a failure when its ``qa.eval_result`` contains "false" and its
    annotation is marked ``if_can_be_labeled``. Unreadable or invalid files
    are skipped.

    Each failure is logged under:

    - its truthy ``answer_type``;
    - every difficulty (last element of each ``qa_elements_taxonomy`` path);
    - "unanswerable" when ``question_can_be_answered`` is falsy;
    - "chinese" or "english" by question text.

    NOTE(review): a missing ``question_can_be_answered`` key also counts as
    "unanswerable" because ``.get`` yields None — confirm that is intended.

    Returns:
        ``{"answer_type": ..., "difficulties": ..., "language": ...}``, where each
        category maps key -> ``{"error_count": int, "errors": [{"model", "file"}]}``,
        ordered by descending ``error_count``.
    """
    buckets = {"answer_type": {}, "difficulties": {}, "language": {}}

    def _bump(bucket: dict, key: str, model: str, fname: str):
        """Create *key* in *bucket* on first sight, then record one more error."""
        entry = bucket.setdefault(key, {"error_count": 0, "errors": []})
        entry["error_count"] += 1
        entry["errors"].append({"model": model, "file": fname})

    for model_dir in Path(base_dir).iterdir():
        if not model_dir.is_dir() or model_dir.name in ["user", "cot"]:
            continue
        for result_file in model_dir.iterdir():
            fname = result_file.name
            if not fname.endswith(".json"):
                continue
            try:
                with open(result_file, "r", encoding="utf-8") as handle:
                    payload = json.load(handle)
            except Exception:
                continue

            if "false" not in payload.get("qa", {}).get("eval_result", ""):
                continue
            meta = payload.get("annotations_result_questions", {})
            if not meta.get("if_can_be_labeled"):
                continue

            answer_type = meta.get("answer_type")
            if answer_type:
                _bump(buckets["answer_type"], answer_type, model_dir.name, fname)
            for taxonomy_path in meta.get("qa_elements_taxonomy", []):
                # The last element is the specific difficulty key.
                _bump(buckets["difficulties"], taxonomy_path[-1], model_dir.name, fname)
            if not meta.get("question_can_be_answered"):
                _bump(buckets["difficulties"], "unanswerable", model_dir.name, fname)
            language = "chinese" if contains_chinese(meta.get("question", "")) else "english"
            _bump(buckets["language"], language, model_dir.name, fname)

    def _count(pair):
        return pair[1]["error_count"]

    return {
        name: dict(sorted(bucket.items(), key=_count, reverse=True))
        for name, bucket in buckets.items()
    }


def sort_errors_by_file(base_dir: str, *, excluded_dirs: tuple = ("user", "cot")) -> dict:
    """Aggregate failed evaluations per result file across all models.

    Walks every model sub-directory of ``base_dir`` in sorted order, skipping
    non-directories and names listed in ``excluded_dirs``. Within each, it reads
    every ``*.json`` file in sorted order. A file counts as failed for a model
    when it parses as JSON, its ``qa.eval_result`` contains "false", and its
    annotation is marked ``if_can_be_labeled``. Unreadable or invalid files are
    skipped.

    Args:
        base_dir: Directory containing one sub-directory per model.
        excluded_dirs: Sub-directory names that are not models. The default is the
            previously hard-coded ("user", "cot").

    Returns:
        Dict keyed by result file name, ordered by descending ``error_count``.
        Each value holds annotation metadata taken from the first failing result
        seen for that file, plus ``errors``, the list of model names that failed it.
    """
    file_error_analysis = {}
    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order.
    # Sorting makes three things reproducible: the "errors" lists, which model
    # supplies the metadata, and the order of files with equal error counts.
    for model_p in sorted(Path(base_dir).iterdir()):
        if not model_p.is_dir() or model_p.name in excluded_dirs:
            continue
        model_name = model_p.name
        for result_file_p in sorted(model_p.iterdir()):
            if not result_file_p.name.endswith(".json"):
                continue
            try:
                with open(result_file_p, "r", encoding="utf-8") as f:
                    result_json = json.load(f)
            except (OSError, ValueError):
                # Best effort: unreadable entries and invalid JSON
                # (JSONDecodeError / UnicodeDecodeError are ValueErrors) are skipped.
                continue
            if "false" not in result_json.get("qa", {}).get("eval_result", ""):
                continue

            annotations = result_json.get("annotations_result_questions", {})
            if not annotations.get("if_can_be_labeled"):
                continue
            file_name = result_file_p.name
            if file_name not in file_error_analysis:
                ids = result_json.get("id", {})
                file_error_analysis[file_name] = {
                    "error_count": 0,
                    "task_id": ids.get("task_id"),
                    "annotations_id": ids.get("annotations_id"),
                    "answer_type": annotations.get("answer_type"),
                    # The last element of each taxonomy path is the specific difficulty.
                    "difficulties": [path[-1] for path in annotations.get("qa_elements_taxonomy", [])],
                    "is_answerable": annotations.get("question_can_be_answered"),
                    "language": "chinese" if contains_chinese(annotations.get("question", "")) else "english",
                    "errors": [],
                }
            file_error_analysis[file_name]["error_count"] += 1
            file_error_analysis[file_name]["errors"].append(model_name)
    sorted_items = sorted(file_error_analysis.items(), key=lambda item: item[1]["error_count"], reverse=True)
    return dict(sorted_items)


if __name__ == "__main__":
    result_dir = "project/chartqa/result/cot"
    # Output locations used by the alternative (disabled) modes below:
    # output_dir = "project/chartqa/result/user/eval_result"
    # output_dir = "project/chartqa/result/user/cot"
    output_dir = "project/chartqa/result/user/error_analysis"

    # open(..., "w") does not create parent directories, so make sure the
    # report directory exists before any mode writes into it.
    os.makedirs(output_dir, exist_ok=True)

    # Disabled mode: per-model right/wrong tallies.
    # all_annotation_result_path = "project/chartqa/data/label_studio/data_annotated/processed_annotations/all_annotations.json"
    # result_p = Path(result_dir)
    # for model_name in result_p.iterdir():
    #     if model_name.is_dir() and model_name.name not in ["user", "cot"]:
    #         count_result(base_dir=result_dir, model_name=model_name.name, output_dir=output_dir)

    # Disabled mode: errors ranked per category.
    # output_path = os.path.join(output_dir, "sorted_error.json")
    # sorted_errors = sort_errors(result_dir)
    # with open(output_path, "w", encoding="utf-8") as f:
    #     json.dump(sorted_errors, f, ensure_ascii=False, indent=2)

    # Active mode: failed results ranked per file across all models.
    output_path = os.path.join(output_dir, "file_error.json")
    sorted_errors = sort_errors_by_file(result_dir)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(sorted_errors, f, ensure_ascii=False, indent=2)
