import sys, os, shutil
import copy
import re
import json
from tqdm import tqdm
from enum import Enum, auto
from pathlib import Path
import collections
from collections import defaultdict
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Global matplotlib defaults for any figures this module draws:
# - use the SimHei font for sans-serif text (presumably so Chinese labels render);
# - render minus signs with the ASCII hyphen-minus instead of the Unicode minus
#   glyph (NOTE(review): likely because SimHei lacks that glyph — confirm).
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False


def count_answer_type_difficulty(data):
    """Tally questions by (answer type, taxonomy path) and print marginal totals.

    For each record, the answer type is read from
    ``item["annotations_result_questions"]["answer_type"]``. If it is a list,
    only its first element is used, and an empty list skips the record. Every
    entry of ``qa_elements_taxonomy`` is turned into a label and counted once
    against that answer type. A sequence of level names is joined with
    " - "; a bare string is used as-is.

    Args:
        data: Iterable of annotation records (dicts) as loaded from JSON.

    Returns:
        A DataFrame indexed by the sorted answer types, with one column per
        sorted taxonomy label, holding the counts (0 where a pair never
        occurs). Column totals and row totals are printed before returning.
        Returns ``None`` after printing "no qa" when nothing was counted.
    """
    counts = defaultdict(int)
    answer_types = set()
    difficulties = set()
    for item in data:
        question_info = item["annotations_result_questions"]
        raw_ans_type = question_info["answer_type"]

        if isinstance(raw_ans_type, list):
            if not raw_ans_type:
                continue  # skip records whose answer-type list is empty
            ans_type = raw_ans_type[0]
        else:
            ans_type = raw_ans_type

        # Registered even if the record has no taxonomy entries, so the type
        # still shows up as an all-zero row.
        answer_types.add(ans_type)

        for taxonomy_path in question_info["qa_elements_taxonomy"]:
            # " - ".join on a plain str would interleave its characters
            # ("abc" -> "a - b - c"), so a string path is kept verbatim.
            if isinstance(taxonomy_path, str):
                difficulty = taxonomy_path
            else:
                difficulty = " - ".join(taxonomy_path)
            difficulties.add(difficulty)
            counts[(ans_type, difficulty)] += 1

    if not counts:
        print("no qa")
        return None

    heatmap_data = pd.DataFrame(0, index=sorted(answer_types), columns=sorted(difficulties))
    for (ans_type, difficulty), count in counts.items():
        heatmap_data.loc[ans_type, difficulty] = count

    difficulty_sums = heatmap_data.sum(axis=0)  # per taxonomy label, over all answer types
    answer_type_sums = heatmap_data.sum(axis=1)  # per answer type, over all labels
    print(difficulty_sums)
    print(answer_type_sums)
    return heatmap_data


def count_difficulty_level(data):
    """Print how many labelable questions fall into each difficulty level.

    A record contributes only if its ``annotations_result_questions`` dict
    (treated as empty when absent) has a truthy ``if_can_be_labeled`` and a
    ``difficulty`` that is not ``None``. The per-level counts are printed in
    index order. If no record qualifies, a message is printed instead
    ("no valid difficulty-level data found, cannot generate the chart").
    """
    annotation_dicts = (entry.get("annotations_result_questions", {}) for entry in data)
    labelable = (ann for ann in annotation_dicts if ann.get("if_can_be_labeled"))
    levels = [ann.get("difficulty") for ann in labelable if ann.get("difficulty") is not None]

    if not levels:
        print("没有找到有效的难度等级数据，无法生成图表。")
        return

    level_counts = pd.Series(levels).value_counts()
    level_counts = level_counts.sort_index()
    print(level_counts)


if __name__ == "__main__":
    # Default input file; pass a different path as the first CLI argument to override it.
    default_data_path = "project/chartqa/data/label_studio/data_annotated/processed_annotations/all_annotations.json"
    data_path = sys.argv[1] if len(sys.argv) > 1 else default_data_path

    # Load fully and close the file before running the (print-only) analyses.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    count_answer_type_difficulty(data)
    count_difficulty_level(data)
