# Import the SDK and the client module
# from label_studio_sdk.client import LabelStudio
from collections import deque
import json
import copy
import sys, os
import shutil
from loguru import logger
import random
from pathlib import Path
from transformers import AutoTokenizer
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

from common_prompts import translate as translate
from common_prompts import LABEL_STUDIO_URL as LABEL_STUDIO_URL
from common_prompts import API_KEY as API_KEY

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gpt_api.unigpt import GPT
from gpt_tools.deepseek_tools import DeepseekChat

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from markdown_it.common_tools import if_chart as if_chart
from markdown_it.common_tools import extract_context as extract_context


class PD:
    """Build Label Studio annotation tasks from extracted chart data.

    ``process_data_dir`` turns the per-document chart JSON files of one data
    source (under ``json_root_dir``) into Label Studio task files plus copies of
    the referenced images (under ``output_dir``). ``decrease_nums_of_charts``
    writes randomly down-sampled copies of those task files, and
    ``process_all_data_dirs_threaded`` runs a per-source step over every source
    with a thread pool.
    """

    def __init__(self) -> None:
        # Template for one Label Studio task; deep-copied per document. Each
        # element appended to "image_captions" is a dict with the keys
        # "caption", "image", "image_name" and "caption_translation".
        self.struct_data = {
            "data": {
                "file_name": "",
                "image_captions": [],
            }
        }
        # Vision-language client; not referenced by any method of this class.
        self.vl_qwen_chat = GPT(model="qwen2.5-vl-7b", vendor="", stream=False, temperature=0.2)
        # Text client used for context condensation and translation.
        self.text_qwen_chat = GPT(model="qwen3_32b", vendor="", stream=False, temperature=0.2)
        # Kept for reference only; no tokenizer is loaded from it.
        self.local_tokenizer_path = Path("models/Qwen2.5-VL-7B-Instruct")

        # Data sources grouped by language; captions from english_dirs also
        # get a "caption_translation".
        self.chinese_dirs = ["CNNIC", "Guizhou", "NationalBureauOfStatistics", "telecomworld"]
        self.english_dirs = ["GlobalTerrorismDatabase", "oced", "worldbank"]
        self.data_dirs = self.chinese_dirs + self.english_dirs
        self.json_root_dir = "data/chartdata/filtered_images"
        self.output_dir = "data/chartQA/label_studio_formal/data_to_annotate"
        # Per-source lists of files that process_data_dir failed on go here.
        self.failed_dir = "data/chartQA/label_studio/label_studio_tools/failed"
        # Label Studio local-files URL prefix; "<data_dir>/images/<doc>/<image>" is appended.
        self.label_studio_image_prefix = "/data/local-files/?d="

    def get_image_path(self, dataset_name: str, data_name: str, image_related_path: str) -> tuple[str, str]:
        """Locate a chart image in the MinerU output tree.

        Returns:
            ``(image_name, image_path)``: the basename of ``image_related_path``
            and the path
            ``data/chartdata/processed/mineru_python/<dataset_name>/<data_name>/auto/<image_related_path>``.
        """
        image_root_dir = "data/chartdata/processed/mineru_python"
        image_name = os.path.basename(image_related_path)
        image_path = os.path.join(image_root_dir, dataset_name, data_name, "auto", image_related_path)
        return image_name, image_path

    def get_concise_context(self, image_name, context) -> str:
        """Return a condensed form of ``context`` for the image ``image_name``.

        Delegates to ``extract_context`` with the text model client; the actual
        condensation logic lives there.
        """
        return extract_context(self.text_qwen_chat, image_name, context)

    def _build_image_caption(self, data_dir, data_name, entry_json, image_url_prefix, output_images_path):
        """Build one caption record and copy its image into the export tree.

        ``entry_json`` must provide "image_name", "context" and "image_path".
        The image is copied to ``<output_images_path>/<data_name>/<image_name>``.
        """
        caption = self.get_concise_context(entry_json["image_name"], entry_json["context"])
        tqdm.write(f"concised context: {caption}")
        image_name, image_path = self.get_image_path(data_dir, data_name, entry_json["image_path"])
        image_caption = {
            "caption": caption,
            "image": image_url_prefix + image_name,
            "image_name": image_name,
            "caption_translation": "",
        }
        if data_dir in self.english_dirs:
            image_caption["caption_translation"] = translate(self.text_qwen_chat, caption)
            tqdm.write(f"translate: {image_caption['caption_translation']}")
        output_image_dir = os.path.join(output_images_path, data_name)
        os.makedirs(output_image_dir, exist_ok=True)
        shutil.copy(image_path, os.path.join(output_image_dir, image_name))
        return image_caption

    def _convert_entry(self, data_dir, file_name, json_dir, output_images_path, output_json_path):
        """Convert one source JSON file into a Label Studio task file.

        Files that list fewer than two charts are skipped and nothing is written.
        Any exception propagates to the caller.
        """
        data_name = file_name.removesuffix(".json")  # document title
        image_url_prefix = self.label_studio_image_prefix + data_dir + "/images/" + data_name + "/"
        # Close the source file before the (slow) per-chart model calls.
        with open(os.path.join(json_dir, file_name), "r", encoding="utf-8") as f:
            entry_jsons = json.load(f)
        if len(entry_jsons) < 2:  # need at least two charts per document
            return
        ls_json = copy.deepcopy(self.struct_data)
        ls_json["data"]["file_name"] = file_name
        for entry_json in entry_jsons:
            ls_json["data"]["image_captions"].append(
                self._build_image_caption(data_dir, data_name, entry_json, image_url_prefix, output_images_path)
            )
        with open(output_json_path, "w", encoding="utf-8") as f:
            json.dump(ls_json, f, ensure_ascii=False, indent=4)
        tqdm.write(f"saved file: {output_json_path}")

    def process_data_dir(self, data_dir):
        """Convert every ``*.json`` file of one data source into Label Studio tasks.

        Output goes to ``<output_dir>/<data_dir>/{jsons,images}``. Files whose
        output JSON already exists are skipped, so an interrupted run can be
        resumed. Names of files that raised an error are written as a JSON list
        to ``<failed_dir>/<data_dir>.json``.
        """
        output_path = os.path.join(self.output_dir, data_dir)
        output_images_path = os.path.join(output_path, "images")
        output_json_dir = os.path.join(output_path, "jsons")
        for path in (output_path, output_images_path, output_json_dir):
            os.makedirs(path, exist_ok=True)

        json_dir = os.path.join(self.json_root_dir, data_dir)
        with os.scandir(json_dir) as entries:
            json_names = [entry.name for entry in entries if entry.name.endswith(".json")]

        failed_filenames = []
        for file_name in tqdm(json_names, desc=f"Processing {data_dir}"):
            output_json_path = os.path.join(output_json_dir, file_name)
            if os.path.exists(output_json_path):  # already converted in an earlier run
                tqdm.write(f"file exist: {output_json_path}")
                continue
            try:
                self._convert_entry(data_dir, file_name, json_dir, output_images_path, output_json_path)
            except Exception as e:  # best-effort per file: record and keep going
                failed_filenames.append(file_name)
                logger.error(f"{file_name}: {e}")

        # The directory may not exist yet; without this the failure list is lost.
        os.makedirs(self.failed_dir, exist_ok=True)
        with open(os.path.join(self.failed_dir, data_dir + ".json"), "w", encoding="utf-8") as f:
            json.dump(failed_filenames, f, ensure_ascii=False, indent=4)

    @staticmethod
    def _sample_captions(captions, rng=None, min_charts=4, max_keep=9):
        """Return a random, order-preserving subset of ``captions``.

        Returns ``None`` when there are fewer than ``min_charts`` captions.
        Otherwise draws ``k`` uniformly from ``[1, max_keep]``. If there are more
        than ``k`` captions, ``k`` of them are sampled without replacement;
        otherwise all are kept.

        Args:
            captions: sequence of caption records.
            rng: object with ``randint``/``sample`` (e.g. ``random.Random``);
                defaults to the ``random`` module.
            min_charts: minimum number of captions required.
            max_keep: upper bound for ``k``.
        """
        rng = random if rng is None else rng
        if len(captions) < min_charts:
            return None
        k = rng.randint(1, max_keep)
        if len(captions) <= k:
            return list(captions)
        keep = set(rng.sample(range(len(captions)), k))
        return [caption for idx, caption in enumerate(captions) if idx in keep]

    def decrease_nums_of_charts(self, data_dir, *, min_charts=4, max_keep=9, rng=None):
        """Write randomly down-sampled copies of one data source's task files.

        Reads ``<output_dir>/<data_dir>/jsons/*.json`` and writes to
        ``<output_dir>/filtered_num_of_charts/<data_dir>/jsons/``. Files with
        fewer than ``min_charts`` captions are not written. The number of kept
        captions is chosen by ``_sample_captions`` using ``max_keep`` and ``rng``.
        """
        origin_json_dir = os.path.join(self.output_dir, data_dir, "jsons")
        output_json_dir = os.path.join(self.output_dir, "filtered_num_of_charts", data_dir, "jsons")
        os.makedirs(output_json_dir, exist_ok=True)
        with os.scandir(origin_json_dir) as entries:
            json_names = [entry.name for entry in entries if entry.name.endswith(".json")]

        for file_name in tqdm(json_names, desc=f"Processing {data_dir}"):
            # These files are written with ensure_ascii=False, so decode as UTF-8.
            with open(os.path.join(origin_json_dir, file_name), "r", encoding="utf-8") as f:
                task = json.load(f)
            sampled = self._sample_captions(task["data"]["image_captions"], rng, min_charts, max_keep)
            if sampled is None:
                continue
            # Same task with only the caption list replaced; `task` itself is not modified.
            filtered = {**task, "data": {**task["data"], "image_captions": sampled}}
            output_json_path = os.path.join(output_json_dir, file_name)
            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(filtered, f, ensure_ascii=False, indent=4)
            tqdm.write(f"saved file: {output_json_path}")

    def process_all_data_dirs_threaded(self, max_workers=None, task=None):
        """Run a per-data-source task over ``self.data_dirs`` in a thread pool.

        Args:
            max_workers: requested number of threads. ``None`` means
                ``WORKER_CAP``; larger values are capped at ``WORKER_CAP``.
            task: callable taking a data_dir name. Defaults to
                ``self.decrease_nums_of_charts``; pass ``self.process_data_dir``
                to run the Label Studio conversion instead.

        Returns:
            List of the task's return values, in ``self.data_dirs`` order.
            An exception raised by a task is re-raised here.
        """
        # Upper bound on the number of worker threads.
        WORKER_CAP = 20
        num_workers = WORKER_CAP if max_workers is None else min(max_workers, WORKER_CAP)
        if task is None:
            task = self.decrease_nums_of_charts
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            return list(
                tqdm(
                    executor.map(task, self.data_dirs),
                    total=len(self.data_dirs),
                    desc="Processing Data Directories",
                )
            )


if __name__ == "__main__":
    # Run the threaded pipeline over every data source. For a sequential run,
    # call processor.process_data_dir(name) for each name in processor.data_dirs.
    processor = PD()
    processor.process_all_data_dirs_threaded(max_workers=10)


# NOTE: files that failed to process in an earlier run:
# ['about-1-in-4-us-teachers-say-their-school-went-into-a-gun-related-lockdown-in-the-last-school-year.json', 'income-inequality-is-greater-among-chinese-americans-than-any-other-asian-origin-group-in-the-us.json', 'key-facts-about-americans-and-guns.json', 'republicans-think-economy-will-improve-over-the-next-year-democrats-expect-it-to-get-worse.json', 'striking-findings-from-2023.json', 'striking-findings-from-2024.json']
