import os
import sys
import tempfile
import time

# Point Hugging Face at a local cache directory and a mirror endpoint.
# These assignments deliberately come before the `transformers` import below;
# NOTE(review): presumably the HF libraries read them at import time — confirm
# before reordering the imports.
os.environ["HF_HOME"] = "cache/hf"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import torch
from transformers import AutoProcessor, AutoModelForSeq2SeqLM
import base64
from loguru import logger
from PIL import Image
import requests

# Append the ancestor reached by five `dirname` steps from this file's absolute
# path (i.e. four levels above the directory containing this file) to sys.path.
# Must run before the `labelstudio` import below — presumably that package
# lives under this ancestor (TODO confirm).
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from labelstudio.common_prompts import encode_image
from chartmoe import ChartMoE_Robot


# Presumably the location of the ChartMoE checkpoint. Nothing in the visible
# part of this file reads it; NOTE(review): confirm whether `chartmoe` picks
# the path up from here or from its own configuration.
ChartMoE_HF_PATH = "models/chartmoe"


class TransformerAgent:
    """Adapter that feeds OpenAI-style chat messages to a ``ChartMoE_Robot``.

    The robot is called with a single prompt string and at most one image
    path, so structured messages are flattened first: text parts are joined
    with spaces and multiple images are stitched side by side into one image.
    """

    def __init__(self) -> None:
        # Directory that receives the short-lived PNG handed to the robot.
        self.tmp_dir = "project/chartqa/tmp/"

        self.robot = ChartMoE_Robot()

    @staticmethod
    def _stitch_horizontally(images):
        """Return ``images`` pasted left to right, top-aligned, on one RGB canvas."""
        canvas_size = (sum(im.width for im in images), max(im.height for im in images))
        canvas = Image.new("RGB", canvas_size)
        offset_x = 0
        for im in images:
            canvas.paste(im, (offset_x, 0))
            offset_x += im.width
        return canvas

    def convert_openai_to_prompt_and_concated_image(self, messages):
        """Flatten chat messages into one prompt string and one optional image.

        Args:
            messages: Iterable of dicts with a ``"role"`` key and an optional
                ``"content"`` key. Content is either a plain value (treated as
                a single text part) or a list of parts shaped like
                ``{"type": "text", "text": ...}`` or
                ``{"type": "image", "image": <path>}``. Parts of any other
                type are ignored.

        Returns:
            ``(prompt_text, image)``. ``prompt_text`` is every text part, in
            order, joined by single spaces. ``image`` is ``None`` when no image
            could be loaded, the image itself when there is exactly one, and a
            horizontal concatenation when there are several.

        System messages are skipped entirely. Messages whose content is
        missing or ``None`` contribute nothing. Image paths that do not exist
        are logged and skipped; other image-loading errors propagate.
        """
        text_parts = []
        pil_images = []
        for message in messages:
            if message["role"] == "system":
                continue
            content = message.get("content")
            if content is None:
                # Treat a null content exactly like an absent one instead of
                # wrapping None as a text part, which would break the join.
                continue
            if not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                if item["type"] == "text":
                    if item["text"] is not None:
                        text_parts.append(item["text"])
                elif item["type"] == "image":
                    image_path = item["image"]
                    try:
                        # convert() returns an independent, fully loaded image,
                        # so the underlying file can be closed right away.
                        with Image.open(image_path) as source:
                            pil_images.append(source.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")

        if not pil_images:
            final_image = None
        elif len(pil_images) == 1:
            final_image = pil_images[0]
        else:
            final_image = self._stitch_horizontally(pil_images)

        return " ".join(text_parts), final_image

    def send_chat_request(self, messages):
        """Run one chat turn against the robot.

        Args:
            messages: OpenAI-style messages, as accepted by
                ``convert_openai_to_prompt_and_concated_image``.

        Returns:
            ``(result, history, None, None)`` where ``result`` and ``history``
            are what ``self.robot.chat`` returns; the last two slots are
            always ``None`` here.

        When the messages carry an image it is written to a uniquely named
        PNG inside ``self.tmp_dir`` for the duration of the call and removed
        afterwards, even if the robot raises.
        """
        prompt_text, image = self.convert_openai_to_prompt_and_concated_image(messages)
        image_path = None
        try:
            if image is not None:
                os.makedirs(self.tmp_dir, exist_ok=True)
                # mkstemp guarantees a fresh name, so concurrent runs sharing
                # tmp_dir cannot overwrite or delete each other's image.
                fd, image_path = tempfile.mkstemp(suffix=".png", dir=self.tmp_dir)
                os.close(fd)
                image.save(image_path)

            with torch.autocast(device_type="cuda"):
                result, history = self.robot.chat(image_path=image_path, question=prompt_text)
            return result, history, None, None
        finally:
            if image_path is not None:
                try:
                    os.remove(image_path)
                except FileNotFoundError:
                    pass


if __name__ == "__main__":
    # Script entry point: build the agent and run the local evaluation.
    # Example invocation:
    # miniforge3/condabin/conda run -n chartmoe --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_chartmoe.py
    # To pin specific GPUs, export CUDA_VISIBLE_DEVICES before launching.
    model_path = "models/chartmoe"  # NOTE: only its last component is used, to name the result directory

    # Put the parent of this file's directory on sys.path before importing
    # `utils` — presumably that is where the package lives (TODO confirm).
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath drops a trailing separator so the name is never empty.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent()
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result/cot", model_name))
    evaluator.run_one_prediction_local()
