import random
import json, copy, os, sys
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from agent import CommonAgents
from prompts import QAPrompts, EVALPrompts

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from gpt_api.unigpt import GPT
from labelstudio.common_prompts import encode_image


class EVAL:
    """Run a QA model over annotated tasks and grade its answers with a second model.

    For every task that is marked as labelable, a question prompt (text plus
    the task's images) is sent to ``agent``. The reply is then sent, together
    with the question and the ground-truth answer, to ``eval_agent``. One JSON
    result file per task is written into ``result_dir``.
    """

    # Locations used when the constructor is not given explicit paths.
    DEFAULT_IMAGE_ROOT_DIR = "project/chartqa/data/label_studio/data_to_annotate"
    DEFAULT_DATA_INDEX = (
        "project/chartqa/data/label_studio/data_annotated/processed_annotations/all_annotations.json"
    )

    # answer_type -> (template method on QAPrompts, template method on EVALPrompts).
    # Tasks whose answer_type is not listed here are skipped.
    _PROMPT_BUILDERS = {
        "open-end sentence": ("qa_open_end", "eval_open_end"),
        "Approximate value": ("qa_approximate", "eval_approximate"),
        "open-end vocabulary": ("qa_extract", "eval_extract"),
        "bool": ("qa_bool", "eval_bool"),
        "multi choice": ("qa_multi_choices", "eval_multi_choices"),
    }

    def __init__(self, agent, result_dir, if_cot=True, image_root_dir=None, data_index=None) -> None:
        """Store the model under test and set up prompts and the grading model.

        Args:
            agent: Model under test. Must expose ``send_chat_request(messages)``
                returning a 4-tuple whose first item is stored as the answer.
            result_dir: Directory that receives one JSON file per task.
            if_cot: Forwarded to ``QAPrompts(if_cot=...)``.
            image_root_dir: Directory that image paths in the tasks are
                relative to. Defaults to ``DEFAULT_IMAGE_ROOT_DIR``.
            data_index: Path of the JSON file holding the list of tasks.
                Defaults to ``DEFAULT_DATA_INDEX``.
        """
        self.agent = agent
        self.eval_agent = GPT(model="qwen3_32b", vendor="", stream=False, temperature=0.1)
        self.image_root_dir = self.DEFAULT_IMAGE_ROOT_DIR if image_root_dir is None else image_root_dir
        self.result_dir = result_dir
        self.data_index = self.DEFAULT_DATA_INDEX if data_index is None else data_index
        self.qa_prompt = QAPrompts(if_cot=if_cot)
        self.eval_prompt = EVALPrompts()

    def _load_tasks(self):
        """Read the task list from ``self.data_index`` and close the file right away."""
        with open(self.data_index, "r", encoding="utf-8") as f:
            return json.load(f)

    def _build_user_content(self, question_prompt, image_paths, if_local):
        """Return the user message content: the prompt text followed by one entry per image.

        With ``if_local`` the image entries carry the joined file path; otherwise
        each image is embedded as a base64 ``data:image/png`` URL.
        """
        content = [{"type": "text", "text": question_prompt}]
        for image_path in image_paths:
            full_path = os.path.join(self.image_root_dir, image_path)
            if if_local:
                content.append({"type": "image", "image": full_path})
            else:
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{encode_image(full_path)}"},
                    }
                )
        return content

    def run_one_prediction(self, task, if_local=False):
        """Ask the model one task's question and have the grading model judge the reply.

        Args:
            task: Task dict. Reads ``annotations_result_questions`` (keys
                ``if_can_be_labeled``, ``question``, ``question_can_be_answered``,
                ``answer``, ``answer_type``) and ``origin_data["image_captions"]``
                (a mapping whose values have ``image`` and ``caption``).
            if_local: Pass images as file paths instead of base64 data URLs.

        Returns:
            A deep copy of ``task`` with an added ``"qa"`` dict (question,
            gt_answer, answer_type, question_prompt, pred_answer,
            evaluation_prompt, eval_result), or ``None`` when the task is not
            labelable or its answer type is not in ``_PROMPT_BUILDERS``.
        """
        annotation = task["annotations_result_questions"]
        if not annotation["if_can_be_labeled"]:
            return None

        qa_data = copy.deepcopy(task)
        qa = qa_data["qa"] = {}
        qa["question"] = annotation["question"]
        # Questions flagged as not answerable are graded against a fixed marker.
        qa["gt_answer"] = annotation["answer"] if annotation["question_can_be_answered"] else "unanswerable"
        qa["answer_type"] = annotation["answer_type"]

        image_entries = list(task["origin_data"]["image_captions"].values())
        image_paths = [entry["image"] for entry in image_entries]
        captions = [entry["caption"] for entry in image_entries]

        answer_type = qa["answer_type"]
        # The isinstance guard keeps unhashable values from raising in the lookup.
        builders = self._PROMPT_BUILDERS.get(answer_type) if isinstance(answer_type, str) else None
        if builders is None:
            return None
        qa_template_name, eval_template_name = builders

        qa["question_prompt"] = self.qa_prompt.qa_prompt(
            getattr(self.qa_prompt, qa_template_name)(), captions, qa["question"]
        )
        qa_message = [
            {"role": "system", "content": self.qa_prompt.system_prompt},
            {
                "role": "user",
                "content": self._build_user_content(qa["question_prompt"], image_paths, if_local),
            },
        ]
        qa["pred_answer"], _, _, _ = self.agent.send_chat_request(qa_message)

        qa["evaluation_prompt"] = self.eval_prompt.eval_prompt(
            getattr(self.eval_prompt, eval_template_name)(),
            qa["question"],
            qa["pred_answer"],
            qa["gt_answer"],
        )
        eval_message = [
            {"role": "system", "content": self.eval_prompt.system_prompt},
            {"role": "user", "content": qa["evaluation_prompt"]},
        ]
        qa["eval_result"], _, _, _ = self.eval_agent.send_chat_request(eval_message)
        return qa_data

    def process_task(self, task, if_local=False):
        """Run one task and save its result, skipping tasks that already have a result file.

        The result file is named after ``task["id"]["pub_id"]`` inside
        ``self.result_dir``. Exceptions from the prediction or from writing are
        not caught here and propagate to the caller.

        Returns:
            ``"success"`` when a result file exists or was written, ``None``
            when the task produced no result.
        """
        output_path = os.path.join(self.result_dir, task["id"]["pub_id"])
        if os.path.exists(output_path):
            return "success"

        qa_data = self.run_one_prediction(task=task, if_local=if_local)
        if qa_data is None:
            return None

        # Serialize before opening the file: if the data cannot be encoded, no
        # truncated file is left behind to be mistaken for a finished task.
        payload = json.dumps(qa_data, ensure_ascii=False, indent=2)
        os.makedirs(self.result_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(payload)
        return "success"

    def run_one_prediction_local(self, if_local=True):
        """Process every task from the index sequentially, with a progress bar."""
        tasks = self._load_tasks()
        for task in tqdm(tasks, desc=f"process to {self.result_dir}"):
            self.process_task(task, if_local=if_local)

    def run_all_prediction(self, max_workers=10):
        """Process every task from the index on a thread pool of ``max_workers`` threads.

        Tasks run through ``process_task`` with its default ``if_local=False``.
        """
        tasks = self._load_tasks()
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Consuming the iterator drives the progress bar; an exception raised
            # by a task is re-raised here when its result is retrieved.
            for _ in tqdm(executor.map(self.process_task, tasks), total=len(tasks), desc="processing"):
                pass

    def filter_processed(self):
        """Tally the graded results stored in ``self.result_dir``.

        Each regular file in the directory is parsed as JSON and classified by
        its ``["qa"]["eval_result"]``: ``"true"`` if the lower-cased text
        contains "true", otherwise ``"false"`` if it contains "false",
        otherwise ``"failed"``. A non-string ``eval_result`` counts as
        ``"failed"``.

        Returns:
            The tally dict with keys ``"true"``, ``"false"`` and ``"failed"``;
            it is also printed.
        """
        result_count = {"true": 0, "false": 0, "failed": 0}
        for file_path in Path(self.result_dir).iterdir():
            if not file_path.is_file():
                continue
            # Results are written as UTF-8 with ensure_ascii=False, so read them the same way.
            with open(file_path, "r", encoding="utf-8") as f:
                result = json.load(f)
            eval_result = result["qa"]["eval_result"]
            verdict = eval_result.lower() if isinstance(eval_result, str) else ""
            if "true" in verdict:
                result_count["true"] += 1
            elif "false" in verdict:
                result_count["false"] += 1
            else:
                result_count["failed"] += 1
        print(result_count)
        return result_count


if __name__ == "__main__":
    # Set the proxy environment variables for this process before any client is created.
    os.environ["http_proxy"] = "http://127.0.0.1:61110"
    os.environ["https_proxy"] = "http://127.0.0.1:61110"

    ### gemini-2.5-pro
    agent = CommonAgents(engine="gemini-2.5-pro")
    # NOTE(review): gpt_agent is never used below; kept because GPT(...) may have
    # side effects that cannot be seen from this file — confirm and remove.
    gpt_agent = GPT(model="azure-gpt-4o", vendor="azure", stream=False, temperature=0.1)
    # NOTE(review): this result path starts with "projects/" while the data paths
    # inside EVAL start with "project/" — confirm which prefix is intended.
    # Named `evaluator` rather than `eval` so the builtin is not shadowed.
    evaluator = EVAL(agent, "projects/chartqa/result/gemini2.5pro")
    evaluator.run_all_prediction()
    # evaluator.filter_processed()
