# Import the SDK and the client module
from label_studio_sdk.client import LabelStudio
import json
import copy
import sys, os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

from labelstudio_api import update_annotation_translation_images

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from gpt_api.unigpt import GPT
from gpt_tools.qwen_vl_plus import QwenVLPlus

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from common_prompts import translate, LABEL_STUDIO_URL, API_KEY


def update_file_translation_caption(processed_data, file_path, output_path):
    """Add caption translations from a Label Studio task to its json task file.

    Loads ``<file_path>/<pub_id>.json``, sets ``caption_translation`` on every
    entry of ``data.image_captions`` whose image has a translation in the task
    annotations, and writes the result to ``<output_path>/<pub_id>.json``.
    All other content of the original json is carried over unchanged.

    If the source json does not exist, a warning is issued and nothing is
    written (the output file is neither created nor truncated).

    Parameters
    ----------
    processed_data : dict
        the dict which is parsed from label studio tasks; it must provide
        ``["id"]["pub_id"]`` and ``["annotations_result_figures"]``, a mapping
        from image name to that figure's annotation dict
    file_path : str
        directory that contains the original json task file
    output_path : str
        directory the updated json task file is written to
    """
    import warnings

    annotations_result_figures = processed_data["annotations_result_figures"]
    file_name = processed_data["id"]["pub_id"] + ".json"
    json_path = os.path.join(file_path, file_name)
    output_json_path = os.path.join(output_path, file_name)

    # Guard first: without the source there is nothing to write, and opening
    # the output for writing would leave an empty file behind.
    if not os.path.exists(json_path):
        warnings.warn(f"source task file not found, skipped: {json_path}")
        return

    # Read with the same encoding the output is written in, so non-ASCII
    # captions survive regardless of the platform's default encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        processed_json = json.load(f)

    # The freshly loaded dict is private to this call, so it can be updated
    # in place; only the "caption_translation" key is added.
    for caption in processed_json["data"]["image_captions"]:
        # An image without any annotation entry is skipped, the same way an
        # entry without a translation is.
        figure = annotations_result_figures.get(caption["image_name"], {})
        if "caption_translation" in figure:
            caption["caption_translation"] = figure["caption_translation"]

    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(processed_json, f, ensure_ascii=False, indent=4)
    tqdm.write(f"saved file: {output_json_path}")


class PD:
    """Run ``update_annotation_translation_images`` over every Label Studio task.

    Tasks of all projects are collected first and then processed concurrently
    in a thread pool. Tasks that report failure or raise are collected in
    ``failed_ids`` and dumped to ``failed_json_path`` once all tasks are done.
    """

    def __init__(self, MAX_WORKERS=10) -> None:
        """Create the Label Studio client and the two model agents.

        Parameters
        ----------
        MAX_WORKERS : int
            number of worker threads used by ``process_projects``
        """
        self.failed_json_path = "chartqa/data/failed/failed_image_translate.json"
        self.failed_ids = []
        # Worker threads append to failed_ids concurrently.
        self.failed_ids_lock = Lock()
        self.MAX_WORKERS = MAX_WORKERS

        self.client = LabelStudio(base_url=LABEL_STUDIO_URL, api_key=API_KEY)
        self.text_qwen_chat = GPT(model="qwen3_32b", vendor="", stream=False, temperature=0.2)
        self.vl_qwen_chat = QwenVLPlus(engine="qwen-vl-plus-2025-05-07")

    def process_task(self, task):
        """Process a single task and record it in ``failed_ids`` on failure.

        Parameters
        ----------
        task : object
            a task as returned by ``client.tasks.list``; its ``id`` and
            ``project`` attributes are used for error reporting
        """
        try:
            task_id, project_id, is_success = update_annotation_translation_images(
                task=task, client=self.client, vl_agent=self.vl_qwen_chat, text_agent=self.text_qwen_chat
            )
        except Exception as e:
            # Boundary of a worker: one bad task must not stop the others, so
            # the error is reported and stored instead of propagated.
            tqdm.write(f"ERROR processing task {task.id}: {e}")
            with self.failed_ids_lock:
                self.failed_ids.append({"task_id": task.id, "project_id": task.project, "error": str(e)})
            return

        if is_success:
            tqdm.write(f"FINISHED, project: {project_id}, task: {task_id}")
        else:
            with self.failed_ids_lock:
                self.failed_ids.append({"task_id": task_id, "project_id": project_id})

    def process_projects(self):
        """Process the tasks of all projects and save the list of failures."""
        all_tasks = []
        projects = self.client.projects.list()
        for project in tqdm(list(projects), desc="Gathering Projects"):
            tasks_in_project = self.client.tasks.list(project=project.id)  # , include="id")
            all_tasks.extend(tasks_in_project)

        with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
            futures = {executor.submit(self.process_task, task) for task in all_tasks}

            for future in tqdm(as_completed(futures), total=len(all_tasks), desc="Processing Tasks"):
                try:
                    future.result()
                except Exception as e:
                    # Only reached when process_task's own error handling fails.
                    tqdm.write(f"A worker thread encountered a critical error: {e}")

        self._save_failed_ids()

    def _save_failed_ids(self):
        """Write ``failed_ids`` as json to ``failed_json_path``."""
        # The report is written after all the work is done; create the target
        # directory so a missing folder cannot discard the collected failures.
        parent_dir = os.path.dirname(self.failed_json_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII characters (e.g. in error
        # messages), so the file encoding must be explicit.
        with open(self.failed_json_path, "w", encoding="utf-8") as f:
            json.dump(self.failed_ids, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # Script entry point: process the tasks of all projects with 5 worker threads.
    PD(MAX_WORKERS=5).process_projects()

    # projects = pd.client.projects.list()
    # for project in tqdm(list(projects), desc=f"process projects"):
    #     # tasks = pd.client.tasks.list(project=project.id, include="id")
    #     tasks = pd.client.tasks.list(project=3, include="id")  # WARN project id = 3
    #     for task in tqdm(list(tasks), desc=f"process tasks"):
    #         task_id = task.id
    #         task_id, project_id, is_success = update_annotation_translation_images(
    #             task=task, client=pd.client, vl_agent=pd.vl_qwen_chat, text_agent=pd.text_qwen_chat
    #         )
    #         if not is_success:
    #             pd.failed_ids.append({"task_id": task_id, "project_id": project_id})

    #         tqdm.write(f"finished: {task_id}")
    #         exit()  # HACK
    # with open(pd.failed_json_path, "w") as f:
    #     json.dump(pd.failed_ids, f, ensure_ascii=False, indent=2)
