import os, json
from transformers import AutoTokenizer


def _tally_task(task, tokenizer, count_result, context_list):
    """Fold one annotation task into ``count_result`` and ``context_list``.

    Both containers are mutated in place, statement by statement, so an
    exception raised part-way through leaves the earlier increments applied.
    The caller is responsible for catching the exception and recording the
    task as failed.
    """
    question = ""
    answer = ""
    caption = ""

    qa = task["annotations_result_questions"]
    # Deliberately an equality test against True rather than plain truthiness:
    # a truthy non-boolean value (e.g. a non-empty string) is not counted.
    if qa["if_can_be_labeled"] == True:  # noqa: E712
        count_result["qa"] += 1
        question = qa["question"]
        if qa["question_can_be_answered"]:
            answer = qa["answer"]
        else:
            count_result["unanswerable"] += 1
            answer = "Unanswerable"

    # Concatenate every caption of the task; only non-empty ones count as
    # context entries.
    for entry in task["origin_data"]["image_captions"].values():
        text = entry.get("caption", "")
        caption += text
        if text != "":
            count_result["context"] += 1

    # Tokenize each text exactly once and reuse the length for both the
    # running total and the per-task list.
    lengths = {}
    for field, text in (
        ("question_tokens", question),
        ("answer_tokens", answer),
        ("context_tokens", caption),
    ):
        lengths[field] = len(tokenizer.encode(text))
        count_result[field] += lengths[field]
    for field, length in lengths.items():
        context_list[field].append(length)

    # A figure entry without "number_of_sub_chart" counts as a single chart.
    figures = task["annotations_result_figures"]
    for chart_name in qa["charts_used"]["names"]:
        count_result["qa_sub_charts"] += figures[chart_name].get("number_of_sub_chart", 1)


def count_results(data_path, tokenizer=None, model_path="models/Qwen2.5-VL-72B-Instruct"):
    """Aggregate counts and token lengths over an annotation JSON file.

    Args:
        data_path: Path to a JSON file whose top level is an iterable of task
            objects.
        tokenizer: Optional object exposing ``encode(text)`` that returns a
            sized sequence. When omitted, a tokenizer is loaded with
            ``AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)``.
        model_path: Location passed to ``AutoTokenizer.from_pretrained`` when
            no ``tokenizer`` is supplied.

    Returns:
        A ``(count_result, context_list)`` tuple. ``count_result`` holds the
        running totals plus the ``failed`` / ``failed_id`` lists of tasks that
        raised while being processed. ``context_list`` holds the per-task
        token lengths for questions, answers and concatenated captions.

    Every task is counted in ``count_result["task"]``, including failed ones.
    A failed task keeps whatever increments were applied before the exception
    was raised. ``count_result`` is also printed to stdout before returning.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    count_result = {
        "task": 0,
        "qa": 0,
        "qa_sub_charts": 0,
        "unanswerable": 0,
        "context": 0,
        "question_tokens": 0,
        "answer_tokens": 0,
        "context_tokens": 0,
        "failed": [],
        "failed_id": [],
    }
    context_list = {
        "question_tokens": [],
        "answer_tokens": [],
        "context_tokens": [],
    }

    # Read as UTF-8 explicitly instead of relying on the platform default,
    # and release the file handle before the (slow) tokenization loop.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for task in data:  # each task
        count_result["task"] += 1
        try:
            _tally_task(task, tokenizer, count_result, context_list)
        except Exception:  # noqa: BLE001 - best effort: record the task and continue
            # Look the identifiers up defensively so that a task without a
            # usable "id" mapping is still recorded (as None) instead of
            # raising from inside the handler and aborting the whole run.
            ident = task.get("id") if isinstance(task, dict) else None
            if not isinstance(ident, dict):
                ident = {}
            # Append to both lists together so they stay index-aligned.
            count_result["failed"].append(ident.get("pub_id"))
            count_result["failed_id"].append(ident.get("task_id"))

    print(count_result)
    return count_result, context_list


if __name__ == "__main__":
    # Script entry point: summarise the annotated export and persist the
    # per-task token lengths next to the analysis code.
    data_path = "project/chartqa/data/label_studio/data_annotated/processed_annotations/all_annotations.json"
    output_path = "project/chartqa/src/analysis/context_list.json"

    # count_results prints the aggregate counters itself; only the per-task
    # lists are written to disk here.
    count_result, context_list = count_results(data_path)

    serialized = json.dumps(context_list, ensure_ascii=False, indent=2)
    with open(output_path, "w") as out_file:
        out_file.write(serialized)
