# Import the SDK and the client module
import sys, os, shutil
import copy
import re
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from label_studio_sdk.client import LabelStudio
from functools import partial
from enum import Enum, auto
from pathlib import Path

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gpt_api.unigpt import GPT
from gpt_tools.deepseek_tools import DeepseekChat

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from common_prompts import get_topic, qa_elements_ch2eng, ch2eng


# Template for one processed task record. process_origin_tasks_jsons deep-copies it
# for every task and overwrites the fields below; process_one_source then writes the
# record to "<processed_path>/<data_name>/<pub_id>.json".
# The concrete values (paper "1612.01810v3", its Figure6 image) are placeholders only.
# NOTE(review): "meat_data" looks like a typo of "meta_data", but it is an output key
# that downstream consumers may already read, so it is left as-is.
data_struct = {
    "id": {
        "pub_id": "",  # task data's "paper_id", else its "file_name"
        "task_id": "",  # Label Studio task id (stored as int)
        "annotations_id": "",  # id picked by get_correct_annotations_id; -1 when none matched
    },
    # Replaced wholesale by the task's "data" dict during processing.
    # NOTE(review): the parser reads an "image" key from each caption entry, not
    # "image_path" as in this sample; the entry below is illustrative only.
    # After post_process_processed_data, "image_captions" becomes a dict keyed by image_name.
    "origin_data": {
        "image_captions": [
            {
                "caption": "",
                "image_path": "/data/local-files/?d=spiqa/SPIQA_train_val/images/1612.01810v3/1612.01810v3-Figure6-1.png",
                "image_name": "1612.01810v3-Figure6-1.png",
            },
        ],
        "paper_id": "1612.01810v3",
    },
    "meat_data": {
        "topic": [""],  # filled from get_topic(agent, captions)
    },
    # Replaced by transfor_processed_image_data2processed_data: one entry per image, keyed by image name.
    "annotations_result_figures": {
        "1612.01810v3-Figure6-1.png": {  # "from_name": "image_name_0",
            "index": 0,  # image index parsed from the result's to_name ("image_<i>")
            "chart_type": [],  # "from_name": "chart_type_0",
            "if_sub_charts": None,  # "from_name": "sub_chart_0",
            "number_of_sub_chart": 1,  # "from_name": "number_sub_chart_<i>" (e.g. "number_sub_chart_0")
            "caption_translation": "",  # "from_name": "translation_to_caption_0",
        }
    },
    "annotations_result_questions": {
        "if_can_be_labeled": None,  # "from_name": "if_can_be_labeled_question"; True when the choice is "yes"
        "difficulty": 0,  # "from_name": "difficulty_of_the_question"; the rating value
        "question_type": [],  # "from_name": "question_type"; parser stores choices[0] (a single string)
        "answer_type": [],  # "from_name": "answer_type"; parser stores choices[0] (a single string)
        "qa_elements_taxonomy": [],  # "from_name": "qa_elements_taxonomy"; NOTE(review): confirm the exact spelling the Label Studio config emits
        "question_can_be_answered": None,  # "from_name": "can be answered"; True when the choice is "yes"
        "charts_used": {"nums": 0, "names": [""]},  # "from_name": "charts_used"; names become image names after post-processing
        "question": "",  # "from_name": "question"; translated with ch2eng for English sources
        "answer": "",  # "from_name": "answer"; translated with ch2eng for English sources
    },
}
# Per-image working template keyed by image index string ("0" for to_name "image_0").
# process_origin_tasks_jsons deep-copies it and fills entries from "image_<i>" results.
# NOTE(review): the "0" placeholder name/defaults survive unless the task supplies image_0 results.
annotation_figure_struct = {
    "0": {
        "name": "1612.01810v3-Figure6-1.png",
        "chart_type": [],
        "if_sub_charts": None,
        "number_of_sub_chart": 1,
        "caption_translation": "",
    }
}


def remove_image_path_prefix(image_path: str):
    """Strip the Label Studio local-files prefix from one image path.

    ``'/data/local-files/?d=<path>'`` becomes ``'<path>'``. Everything after the
    last occurrence of the prefix is kept; a path without the prefix is returned
    unchanged.

    Parameters
    ----------
    image_path : str
        Image path as stored by Label Studio.

    Returns
    -------
    str
        The path without the ``/data/local-files/?d=`` prefix.
    """
    _head, _sep, tail = image_path.rpartition("/data/local-files/?d=")
    return tail


def remove_images_path_prefix(image_captions: list):
    """Strip the Label Studio local-files prefix from every caption entry's ``"image"``.

    ``'/data/local-files/?d=<path>'`` becomes ``'<path>'`` via
    ``remove_image_path_prefix``. New entry dicts are returned; the caller's list
    and dicts are left untouched. Every other key of an entry is carried over
    unchanged, in its original order.

    Parameters
    ----------
    image_captions : list
        List of caption entry dicts, each with an ``"image"`` key.

    Returns
    -------
    list
        New list of caption entries whose ``"image"`` has no prefix.

    Raises
    ------
    KeyError
        If an entry has no ``"image"`` key.
    """
    # {**entry, "image": ...} overwrites the existing key in place, so key order is preserved.
    return [
        {**entry, "image": remove_image_path_prefix(entry["image"])}
        for entry in image_captions
    ]


def append_processed_image_data(processed_image_data: dict, image_index: str, key: str, value):
    """Store one annotation value for one image, creating the image's entry on first use.

    Performs ``processed_image_data[image_index][key] = value``. The inner dict is
    created if ``image_index`` is not present yet.

    Parameters
    ----------
    processed_image_data : dict
        Per-image data keyed by image index string.
    image_index : str
        Image index parsed from the annotation's ``to_name``.
    key : str
        Field name to set (e.g. ``"chart_type"``).
    value : Any
        Value to store.

    Returns
    -------
    dict
        The same ``processed_image_data`` object, updated in place.
    """
    image_entry = processed_image_data.setdefault(image_index, {})
    image_entry[key] = value
    return processed_image_data


def transfor_processed_image_data2processed_data(processed_image_data):
    """Re-key per-index image annotations by image name.

    Each entry of ``processed_image_data`` (keyed by index string) becomes an entry
    keyed by its ``"name"``. The entry holds ``"index"`` (the int index) and then
    whichever of chart_type, if_sub_charts, number_of_sub_chart and
    caption_translation are present, always in that order. If two indices share
    a name, the later one wins.

    Parameters
    ----------
    processed_image_data : dict
        Image data keyed by index string; each value must contain ``"name"``.

    Returns
    -------
    dict
        annotations_result_figures, keyed by image name.

    Raises
    ------
    ValueError
        If an index key is not an integer string.
    KeyError
        If an entry has no ``"name"``.
    """
    copied_fields = ("chart_type", "if_sub_charts", "number_of_sub_chart", "caption_translation")
    figures = {}
    for index_key, image_info in processed_image_data.items():
        figure = {"index": int(index_key)}
        for field in copied_fields:
            if field in image_info:
                figure[field] = image_info[field]
        figures[image_info["name"]] = figure
    return figures


def post_process_processed_data(processed_data, processed_image_data):
    """Finalize a processed record in place.

    - ``annotations_result_questions.charts_used.names``: every entry that looks
      like a non-negative integer index is replaced by that image's name from
      ``processed_image_data``. Other entries (e.g. the template's ``""``) are dropped.
    - ``origin_data.image_captions``: converted from a list of entries into a dict
      ``{image_name: {"image": ..., "caption": ...}}``.

    Parameters
    ----------
    processed_data : dict
        Output record; modified in place.
    processed_image_data : dict
        Image data keyed by index string, each with a ``"name"``.

    Returns
    -------
    dict
        The same ``processed_data`` object.

    Raises
    ------
    KeyError
        If a used chart index has no entry in ``processed_image_data``.
    """
    charts_used = processed_data["annotations_result_questions"]["charts_used"]
    charts_used["names"] = [
        processed_image_data[f"{chart_ref}"]["name"]
        for chart_ref in charts_used["names"]
        if str(chart_ref).isdigit()
    ]

    origin = processed_data["origin_data"]
    origin["image_captions"] = {
        entry["image_name"]: {"image": entry["image"], "caption": entry["caption"]}
        for entry in origin["image_captions"]
    }
    return processed_data


def get_correct_annotations_id(annotations: list) -> tuple[int, int]:
    """Find the human annotation among a task's annotations.

    A task can carry auto-generated annotations next to the real one. The real one
    is taken to be the first annotation that has a result whose ``from_name`` is
    ``"if_can_be_labeled_question"`` or ``"question"``.

    Parameters
    ----------
    annotations : list
        The task's ``annotations`` list; each item has ``"id"`` and ``"result"``.

    Returns
    -------
    tuple[int, int]
        ``(annotations_id, annotations_index)`` of the first match, where the index
        is the position in ``annotations``. Returns ``(-1, -1)`` when nothing matches.
    """
    question_from_names = ("if_can_be_labeled_question", "question")
    for annotations_index, annotation in enumerate(annotations):
        for result in annotation["result"]:
            # Some results (e.g. ones without a control tag) carry no "from_name".
            if result.get("from_name") in question_from_names:
                return annotation["id"], annotations_index
    # Not found: report -1 for the index too, rather than a valid-looking last position.
    return -1, -1


def process_origin_tasks_jsons(
    origin_tasks: dict,
    data_struct: dict,
    annotation_figure_struct: dict,
    agent,
    annotation_root: str,
    if_english_dir: bool,
):
    """Convert one exported Label Studio task into the output dataset format.

    Parameters
    ----------
    origin_tasks : dict
        One task from the Label Studio JSON export (``id``, ``data``, ``annotations``).
        It is deep-copied and not mutated.
    data_struct : dict
        Output record template; deep-copied for every call.
    annotation_figure_struct : dict
        Per-image annotation template keyed by image index string; deep-copied for every call.
    agent :
        LLM client passed to ``get_topic`` and, for English sources, to ``ch2eng``.
    annotation_root : str
        Output directory for this source. The task is skipped if
        ``<annotation_root>/<pub_id>.json`` already exists.
    if_english_dir : bool
        When true, question and answer texts are translated with ``ch2eng``.

    Returns
    -------
    tuple
        - ``(processed_data, True)`` on success.
        - ``(task_id, True)`` when the task is skipped, either because its output
          already exists or because no human annotation was found.
        - ``(task_id, False)`` when processing raised; the error is logged with ``tqdm.write``.
    """
    # Work on deep copies so the shared templates and the caller's task are never mutated.
    processed_data = copy.deepcopy(data_struct)
    # NOTE(review): this starts from the template, so index "0" keeps the template's placeholder
    # name/defaults unless the task supplies image_0 results — confirm every task has an image_0.
    processed_image_data = copy.deepcopy(annotation_figure_struct)
    annotation = copy.deepcopy(origin_tasks)

    ### id
    if "paper_id" in annotation["data"]:
        processed_data["id"]["pub_id"] = annotation["data"]["paper_id"]
    elif "file_name" in annotation["data"]:
        processed_data["id"]["pub_id"] = annotation["data"]["file_name"]
    processed_data["id"]["task_id"] = int(annotation["id"])
    # A task may also hold auto-generated annotations; pick the human one.
    processed_data["id"]["annotations_id"], annotation_index = get_correct_annotations_id(annotation["annotations"])

    ### NOTE skip tasks that were already written or have no human annotation
    output_path = os.path.join(annotation_root, processed_data["id"]["pub_id"] + ".json")
    if os.path.exists(output_path) or processed_data["id"]["annotations_id"] == -1:
        return processed_data["id"]["task_id"], True

    # Per-task boundary: any failure is logged and reported to the caller instead of aborting the batch.
    try:
        ### origin_data: the task's data (image paths and captions), Label Studio path prefix removed
        processed_data["origin_data"] = annotation["data"]
        processed_data["origin_data"]["image_captions"] = remove_images_path_prefix(
            annotation["data"]["image_captions"]
        )

        ### meta data (the output key is spelled "meat_data")
        processed_data["meat_data"]["topic"] = get_topic(
            agent=agent, captions=processed_data["origin_data"]["image_captions"]
        )

        questions = processed_data["annotations_result_questions"]
        for result in annotation["annotations"][annotation_index]["result"]:
            ### annotations_result_questions: task-level answers
            if result["to_name"] == "global_task_context":
                if result["from_name"] == "if_can_be_labeled_question":
                    questions["if_can_be_labeled"] = result["value"]["choices"][0].lower() == "yes"
                elif result["from_name"] == "question_type":
                    questions["question_type"] = result["value"]["choices"][0]
                elif result["from_name"] == "answer_type":
                    questions["answer_type"] = result["value"]["choices"][0]
                elif result["from_name"] in ("qa_elements_taxonomy", "qa_lements_taxonomy"):
                    # The template documents "qa_elements_taxonomy"; the misspelled
                    # "qa_lements_taxonomy" is accepted too so the field is never silently dropped.
                    questions["qa_elements_taxonomy"] = qa_elements_ch2eng(result["value"]["taxonomy"])
                elif result["from_name"] == "can be answered":
                    questions["question_can_be_answered"] = result["value"]["choices"][0].lower() == "yes"
                elif result["from_name"] == "difficulty_of_the_question":
                    questions["difficulty"] = result["value"]["rating"]
                elif result["from_name"] == "charts_used":
                    # Image indices separated by ASCII/full-width commas or periods, e.g. "0,2" or "0，2".
                    # Empty pieces (e.g. a trailing separator) are skipped instead of failing int("").
                    pieces = re.split(r"[,，.。]", result["value"]["text"][0])
                    chart_indices = [int(piece) for piece in pieces if piece.strip()]
                    questions["charts_used"] = {
                        "nums": len(chart_indices),
                        "names": chart_indices,
                    }
                elif result["from_name"] == "question":
                    question_text = result["value"]["text"][-1]
                    if if_english_dir:
                        questions["question"] = ch2eng(agent=agent, origin_text=question_text)
                    else:
                        questions["question"] = question_text
                elif result["from_name"] == "answer":
                    answer_text = result["value"]["text"][-1]
                    if if_english_dir:
                        questions["answer"] = ch2eng(agent=agent, origin_text=answer_text)
                    else:
                        questions["answer"] = answer_text

            ### annotations_result_figures: per-image results, to_name is "image_<index>"
            elif result["to_name"].startswith("image"):
                image_index = result["to_name"].split("_")[-1]
                # chart type
                if result["from_name"] == ("chart_type_" + image_index):
                    processed_image_data = append_processed_image_data(
                        processed_image_data, image_index, "chart_type", result["value"]["choices"]
                    )
                # whether the figure has sub charts
                elif result["from_name"] == ("sub_chart_" + image_index):
                    processed_image_data = append_processed_image_data(
                        processed_image_data,
                        image_index,
                        "if_sub_charts",
                        result["value"]["choices"][-1].lower() == "true",
                    )
                # image name
                elif result["from_name"] == ("image_name_" + image_index):
                    processed_image_data = append_processed_image_data(
                        processed_image_data, image_index, "name", result["value"]["choices"][0]
                    )
                # number of sub charts
                elif result["from_name"] == ("number_sub_chart_" + image_index):
                    processed_image_data = append_processed_image_data(
                        processed_image_data, image_index, "number_of_sub_chart", int(result["value"]["text"][-1])
                    )
                # translation of the caption
                elif result["from_name"] == ("translation_to_caption_" + image_index):
                    processed_image_data = append_processed_image_data(
                        processed_image_data, image_index, "caption_translation", result["value"]["text"][-1]
                    )
        # merge per-figure annotations into the output record
        processed_data["annotations_result_figures"] = transfor_processed_image_data2processed_data(
            processed_image_data
        )
        processed_data = post_process_processed_data(processed_data, processed_image_data)
        return processed_data, True
    except Exception as e:
        tqdm.write(f"error: {e}")
        return processed_data["id"]["task_id"], False


class Annotations:
    """Batch pipeline over Label Studio exports.

    The stages are: export files -> per-task processed JSON -> filtered ->
    merged file -> summary counts.
    """

    def __init__(self) -> None:
        # Label Studio export files, one "<data_name>.json" per source
        self.export_path = "project/chartqa/data/label_studio/data_annotated/origin_annotations"
        # processed records, written as "<data_name>/<pub_id>.json"; also holds all_annotations.json
        self.processed_path = "project/chartqa/data/label_studio/data_annotated/processed_annotations"
        # not referenced by this class's methods
        self.image_root_path = "project/chartqa/data/label_studio/data_to_annotate"
        # per-source {"success": [...], "failed": [...]} summaries
        self.processed_result_path = "project/chartqa/data/label_studio/data_annotated/process_result"
        # records without a human annotation (annotations_id == -1) are moved here
        self.nonannotation_path = "project/chartqa/data/label_studio/data_annotated/-1"

        self.data_dirs = [
            "CNNIC",
            "guizhou",
            "nation",
            "oecd",
            "spiqa",
            "statista",
            "spiqa1",
            "pewresearch",
        ]  # all data needs to process
        self.english_data_dirs = [
            "oecd",
            "spiqa",
            "statista",
            "spiqa1",
            "pewresearch",
        ]  # data needs extra translation (ch2eng)

    def process_one_source(self, data_name: str, agent):
        """Process every task of one exported source and write one JSON file per task.

        Parameters
        ----------
        data_name : str
            Source name; reads ``<export_path>/<data_name>.json`` and writes into
            ``<processed_path>/<data_name>/``.
        agent :
            LLM client forwarded to ``process_origin_tasks_jsons``.

        Returns
        -------
        dict
            ``{"success": [task_id, ...], "failed": [task_id, ...]}``. Skipped tasks count as success.
        """
        data_path = os.path.join(self.export_path, data_name + ".json")
        annotation_root = os.path.join(self.processed_path, data_name)
        os.makedirs(annotation_root, exist_ok=True)
        if_english_dir = data_name in self.english_data_dirs

        # Load the whole export and close the file before the slow per-task loop.
        with open(data_path, "r", encoding="utf-8") as export_file:
            annotation_datas = json.load(export_file)

        result = {"success": [], "failed": []}
        for annotation in tqdm(annotation_datas, desc=f"Processing {data_name}"):
            processed_data, if_success = process_origin_tasks_jsons(
                origin_tasks=annotation,
                data_struct=data_struct,
                annotation_figure_struct=annotation_figure_struct,
                agent=agent,
                annotation_root=annotation_root,
                if_english_dir=if_english_dir,
            )

            if not if_success:
                # on failure only the task id comes back
                task_id = processed_data
                tqdm.write(f"failed durling: {processed_data}")
                result["failed"].append(task_id)
                continue

            if isinstance(processed_data, dict):
                output_path = os.path.join(annotation_root, processed_data["id"]["pub_id"] + ".json")
                with open(output_path, "w", encoding="utf-8") as out_file:
                    json.dump(processed_data, out_file, indent=4, ensure_ascii=False)
                task_id = processed_data["id"]["task_id"]
                tqdm.write(f"processed: {task_id}")
            else:
                # skipped task (already written or no human annotation): only its id comes back
                task_id = processed_data
                tqdm.write(f"{task_id} already processed")
            result["success"].append(task_id)
        return result

    def process_all_sources(self, data_dirs, max_workers=5):
        """Process several sources in parallel threads and save a summary file per source.

        Parameters
        ----------
        data_dirs : list of str
            Source names to process.
        max_workers : int
            Thread pool size.
        """
        text_qwen_chat = GPT(model="qwen3_32b", vendor="", stream=False, temperature=0.2)
        # NOTE(review): one GPT client is shared by all worker threads — confirm it is thread-safe.
        worker_func = partial(self.process_one_source, agent=text_qwen_chat)
        # Nothing else creates the summary directory; without this the first summary write fails.
        os.makedirs(self.processed_result_path, exist_ok=True)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_data = {executor.submit(worker_func, name): name for name in data_dirs}
            for future in tqdm(future_to_data, total=len(data_dirs), desc="Overall Progress"):
                data_name = future_to_data[future]
                result_summary = future.result()
                summary_path = os.path.join(self.processed_result_path, data_name + ".json")
                with open(summary_path, "w", encoding="utf-8") as f:
                    json.dump(result_summary, f, ensure_ascii=False, indent=4)
                tqdm.write(f"Finished processing {data_name}: {result_summary}")

    def filter_processed(self, data_dirs):
        """Move processed records whose ``annotations_id`` is -1 to ``<nonannotation_path>/<dir>``.

        Parameters
        ----------
        data_dirs : list of str
            Source sub-directories of ``processed_path`` to scan.
        """
        for dir_name in data_dirs:
            data_dir = Path(self.processed_path) / dir_name
            nonannotation_dir = os.path.join(self.nonannotation_path, dir_name)
            os.makedirs(nonannotation_dir, exist_ok=True)

            # Materialize the listing first: files are moved out of data_dir while looping,
            # and results of a live directory iterator are unspecified under modification.
            record_files = [p for p in data_dir.iterdir() if p.is_file() and p.suffix == ".json"]
            for file_path in record_files:
                with open(file_path, "r", encoding="utf-8") as f:
                    annotations_id = json.load(f)["id"]["annotations_id"]
                if annotations_id == -1:
                    shutil.move(str(file_path), os.path.join(nonannotation_dir, file_path.name))

    def collect_annotations(self, data_dirs):
        """Merge all processed records of ``data_dirs`` into ``<processed_path>/all_annotations.json``.

        Parameters
        ----------
        data_dirs : list of str
            Source sub-directories of ``processed_path`` to merge.

        Returns
        -------
        list
            The merged records. Order is by source, then by file name.
        """
        annotations = []
        for dir_name in data_dirs:
            data_dir = Path(self.processed_path) / dir_name
            # sorted(): make the merged file's order reproducible across runs and filesystems
            for file_path in sorted(p for p in data_dir.iterdir() if p.is_file() and p.suffix == ".json"):
                with open(file_path, "r", encoding="utf-8") as f:
                    annotations.append(json.load(f))
        all_annotations_path = os.path.join(self.processed_path, "all_annotations.json")
        # ensure_ascii=False writes raw non-ASCII text, so the encoding must be explicit
        with open(all_annotations_path, "w", encoding="utf-8") as f:
            json.dump(annotations, f, indent=2, ensure_ascii=False)
        return annotations

    def count_results(self, data_path):
        """Count tasks, labelable QAs, charts and sub charts in a merged annotations file.

        A figure without an ``if_sub_charts`` flag counts as one sub chart. So does a
        figure whose flag is falsy. A figure flagged True contributes its
        ``number_of_sub_chart``. A missing flag, or a True flag without a count,
        records the task in ``failed`` / ``failed_id``.

        Parameters
        ----------
        data_path : str
            Path to the merged JSON list written by ``collect_annotations``.

        Returns
        -------
        dict
            Keys ``task``, ``qa``, ``charts``, ``subcharts``, ``failed`` (pub ids) and
            ``failed_id`` (task ids). The dict is also printed.
        """
        count_result = {"task": 0, "qa": 0, "charts": 0, "subcharts": 0, "failed": [], "failed_id": []}
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        def mark_failed(task):
            count_result["failed"].append(task["id"]["pub_id"])
            count_result["failed_id"].append(task["id"]["task_id"])

        for task in data:
            count_result["task"] += 1
            if task["annotations_result_questions"]["if_can_be_labeled"] is True:
                count_result["qa"] += 1
            for figure in task["annotations_result_figures"].values():
                count_result["charts"] += 1
                if "if_sub_charts" not in figure:
                    mark_failed(task)
                    count_result["subcharts"] += 1
                elif not figure["if_sub_charts"]:
                    count_result["subcharts"] += 1
                elif "number_of_sub_chart" in figure:
                    count_result["subcharts"] += figure["number_of_sub_chart"]
                else:
                    mark_failed(task)
        print(count_result)
        return count_result


if __name__ == "__main__":
    annotation = Annotations()

    # The merged file written by collect_annotations(); derived from processed_path
    # instead of hard-coded so the two paths cannot drift apart.
    data_path = os.path.join(annotation.processed_path, "all_annotations.json")

    # export -> per-task processed JSON (uses the LLM for topics/translation)
    annotation.process_all_sources(annotation.data_dirs, max_workers=5)
    # move records without a human annotation aside
    annotation.filter_processed(annotation.data_dirs)
    # merge the remaining records into data_path
    annotation.collect_annotations(annotation.data_dirs)
    # Last recorded run: {'task': 102, 'qa': 89, 'charts': 428, 'subcharts': 1538,
    # 'failed': ['1706.03353v3', '1805.08204v4', '1807.04801v3', '1807.04801v3'],
    # 'failed_id': [134, 213, 236, 236]}
    annotation.count_results(data_path)
