import json
import os
import statistics


def _task_ids(task):
    """Return ``(pub_id, task_id)`` for *task*; a missing part is ``None``.

    Never raises, so it is safe to call while recording a failure for a
    malformed task (no ``"id"`` entry, or a task that is not a dict).
    """
    ids = task.get("id") if isinstance(task, dict) else None
    if not isinstance(ids, dict):
        return None, None
    return ids.get("pub_id"), ids.get("task_id")


def _summarize(values):
    """Return the max/min/mean/median of a non-empty list of numbers."""
    return {
        "max": max(values),
        "min": min(values),
        "mean": statistics.mean(values),
        "median": statistics.median(values),
    }


def count_results(data_path):
    """Count tasks, charts and sub-charts in an annotation JSON file.

    The file at *data_path* is parsed with ``json.load`` and iterated as a
    sequence of task dicts. Each task is processed on a best-effort basis:
    if anything goes wrong while reading it, its identifiers are appended to
    the ``"failed"`` / ``"failed_id"`` lists and processing continues with the
    next task. Counters updated before the failure are kept. A task can
    appear more than once in the failure lists (once per problematic figure,
    plus once if the task as a whole raised).

    Args:
        data_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A dict with the keys:

        * ``"task"``: number of tasks seen.
        * ``"qa"``: tasks whose ``if_can_be_labeled`` equals ``True``.
        * ``"charts"``: number of figures over all tasks.
        * ``"have_subchart"``: figures whose ``if_sub_charts`` is truthy.
        * ``"num_of_charts_within_have_subchart"``: sum of
          ``number_of_sub_chart`` over those figures.
        * ``"subcharts"``: that same sum plus one for every figure without
          sub-charts (or without an ``if_sub_charts`` entry).
        * ``"subchart_stats"``: max/min/mean/median of the
          ``number_of_sub_chart`` values, or ``{}`` if there were none.
        * ``"relatedchart_stats"``: the same statistics over the
          ``charts_used.nums`` values that are greater than 2, or ``{}``.
        * ``"chart_types"``: occurrences of each ``chart_type`` entry.
        * ``"failed"`` / ``"failed_id"``: ``pub_id`` / ``task_id`` of the
          problematic tasks (``None`` where the identifier is missing).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    subchart_counts = []
    related_chart_counts = []
    count_result = {
        "task": 0,
        "qa": 0,
        "charts": 0,
        "have_subchart": 0,
        "num_of_charts_within_have_subchart": 0,
        "subchart_stats": {},
        "relatedchart_stats": {},
        "subcharts": 0,
        "chart_types": {},
        "failed": [],
        "failed_id": [],
    }
    chart_types = count_result["chart_types"]

    def record_failure(task):
        # Uses the non-raising lookup so that a task without identifiers
        # cannot abort the whole run from inside the error path.
        pub_id, task_id = _task_ids(task)
        count_result["failed"].append(pub_id)
        count_result["failed_id"].append(task_id)

    # Load everything first so the file is closed before processing starts.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for task in data:
        count_result["task"] += 1
        try:
            questions = task["annotations_result_questions"]
            # Deliberately an equality test rather than truthiness: other
            # truthy values (e.g. non-empty strings) must not be counted.
            if questions["if_can_be_labeled"] == True:  # noqa: E712
                count_result["qa"] += 1

            for figure in task["annotations_result_figures"].values():
                count_result["charts"] += 1

                for chart_type in figure["chart_type"]:
                    chart_types[chart_type] = chart_types.get(chart_type, 0) + 1

                if "if_sub_charts" not in figure:
                    # Incomplete annotation: flag it, count as a single chart.
                    record_failure(task)
                    count_result["subcharts"] += 1
                    continue
                if not figure["if_sub_charts"]:
                    # A figure without sub-charts counts as one chart.
                    count_result["subcharts"] += 1
                    continue

                count_result["have_subchart"] += 1
                if "number_of_sub_chart" not in figure:
                    # Marked as having sub-charts but the count is missing.
                    record_failure(task)
                    continue
                sub_count = figure["number_of_sub_chart"]
                count_result["subcharts"] += sub_count
                count_result["num_of_charts_within_have_subchart"] += sub_count
                subchart_counts.append(sub_count)

            # NOTE(review): only tasks using more than two charts contribute
            # to the related-chart statistics - confirm the threshold.
            charts_used = questions["charts_used"]["nums"]
            if charts_used > 2:
                related_chart_counts.append(charts_used)
        except Exception:
            # Best effort: a malformed task is recorded and skipped.
            record_failure(task)

    if subchart_counts:
        count_result["subchart_stats"] = _summarize(subchart_counts)
    if related_chart_counts:
        count_result["relatedchart_stats"] = _summarize(related_chart_counts)
    print(count_result)
    return count_result


if __name__ == "__main__":
    # When executed as a script, run the counter on the default annotation file.
    data_path = "project/chartqa/data/label_studio/data_annotated/processed_annotations/all_annotations.json"
    count_results(data_path=data_path)
