import torch
import os, sys, time, base64
from loguru import logger
from PIL import Image

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from labelstudio.common_prompts import encode_image

from gradio_client import Client, handle_file


class TransformerAgent:
    """Chat agent that forwards OpenAI-style messages to a Gradio endpoint.

    Each request is flattened into one prompt string plus at most one image
    (several images are stitched side by side) and sent to the
    ``/process_understanding`` API of the Gradio app behind ``url``.
    """

    def __init__(self, url):
        """Create the Gradio client.

        Args:
            url: Address of the Gradio server, passed to ``Client``.
        """
        self.client = Client(url)
        # Directory that receives the temporary PNG uploaded with a request.
        self.tmp_dir = "project/chartqa/tmp/"

    @staticmethod
    def _stitch_horizontally(images):
        """Paste ``images`` left to right onto one RGB canvas and return it."""
        canvas_width = sum(img.width for img in images)
        canvas_height = max(img.height for img in images)
        canvas = Image.new("RGB", (canvas_width, canvas_height))

        offset_x = 0
        for img in images:
            # Images are top-aligned; shorter ones leave the canvas
            # background visible below them.
            canvas.paste(img, (offset_x, 0))
            offset_x += img.width
        return canvas

    def convert_openai_to_gradio(self, messages):
        """Flatten OpenAI-style chat messages into a prompt and one image.

        System messages are skipped. Text items are joined with single
        spaces in message order. Image items are opened from the path in
        ``item["image"]`` and converted to RGB; a missing file is logged and
        skipped. Messages whose ``content`` is ``None`` (or absent) and items
        without a recognised ``type`` contribute nothing.

        Args:
            messages: Iterable of dicts with a ``role`` and a ``content``
                that is either a plain value (treated as text) or a list of
                ``{"type": "text", "text": ...}`` /
                ``{"type": "image", "image": path}`` items.

        Returns:
            Tuple ``(prompt_text, image)`` where ``image`` is ``None`` when
            no image could be loaded, the single loaded image, or a
            horizontal stitch of all loaded images.
        """
        text_parts = []
        pil_images = []
        for message in messages:
            if message.get("role") == "system":
                continue
            content = message.get("content")
            if content is None:
                # A null content would otherwise end up as a ``None`` entry
                # in ``text_parts`` and break the join below.
                continue
            if not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                item_type = item.get("type")
                if item_type == "text":
                    text = item.get("text")
                    if text is not None:
                        text_parts.append(text)
                elif item_type == "image":
                    image_path = item["image"]
                    try:
                        # ``convert`` returns an independent copy, so the
                        # source file can be closed right away.
                        with Image.open(image_path) as source:
                            pil_images.append(source.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")

        if not pil_images:
            final_image = None
        elif len(pil_images) == 1:
            final_image = pil_images[0]
        else:
            final_image = self._stitch_horizontally(pil_images)

        return " ".join(text_parts), final_image

    def send_chat_request(self, messages):
        """Send ``messages`` to the Gradio app and return its answer.

        The image, if any, is written to a uniquely named temporary PNG under
        ``self.tmp_dir`` for upload and removed again once the call finishes,
        whether it succeeded or raised.

        Args:
            messages: OpenAI-style messages, see ``convert_openai_to_gradio``.

        Returns:
            Tuple ``(result, None, None, None)``; ``result`` is whatever
            ``client.predict`` returned. The three ``None`` placeholders are
            presumably there to match the tuple shape the caller expects.
        """
        # Function-scope import keeps this block self-contained.
        import tempfile

        prompt_text, image = self.convert_openai_to_gradio(messages)
        image_path = None
        try:
            image_to_process = None
            if image is not None:
                os.makedirs(self.tmp_dir, exist_ok=True)
                # mkstemp guarantees a fresh file name, so concurrent
                # requests cannot overwrite or delete each other's upload
                # (a millisecond timestamp could collide).
                fd, image_path = tempfile.mkstemp(suffix=".png", dir=self.tmp_dir)
                with os.fdopen(fd, "wb") as png_file:
                    image.save(png_file, format="PNG")
                image_to_process = handle_file(image_path)

            result = self.client.predict(
                image=image_to_process,
                prompt=prompt_text,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.1,
                max_new_tokens=4096,
                api_name="/process_understanding",
            )
            return result, None, None, None
        finally:
            if image_path is not None:
                try:
                    os.remove(image_path)
                except FileNotFoundError:
                    # Already gone; nothing left to clean up.
                    pass


if __name__ == "__main__":
    # Example commands kept from the original notes:
    #   TMPDIR="project/chartqa/env/projects/BAGEL/tmp" python project/chartqa/env/projects/BAGEL/app.py
    #   miniforge3/condabin/conda run -n bagel --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_bagel.py
    os.environ["CUDA_VISIBLE_DEVICES"] = "7,9"

    # Put the parent of this file's directory on the import path so that
    # `utils.eval` can be resolved.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(parent_dir)
    from utils.eval import EVAL

    url = "http://localhost:7860/"
    model_name = "bagel"
    result_dir = os.path.join("project/chartqa/result", model_name)

    agent = TransformerAgent(url)
    evaluator = EVAL(agent, result_dir)
    evaluator.run_one_prediction_local()
