import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, PaliGemmaForConditionalGeneration
import base64
import os, sys, copy
from loguru import logger
from PIL import Image
import requests

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from labelstudio.common_prompts import encode_image


class TransformerAgent:
    """Local inference wrapper around a PaliGemma-family checkpoint (e.g. ChartGemma).

    Accepts OpenAI-style chat ``messages`` (a list of dicts with ``role`` and
    ``content``), turns them into a text prompt plus images, and runs
    ``generate`` on the loaded model.
    """

    def __init__(self, model_path, num_beams=4, max_new_tokens=4096) -> None:
        """Load model and processor from ``model_path``; use CUDA when available, else CPU.

        Args:
            model_path: Directory or id accepted by ``from_pretrained``.
            num_beams: Beam width used by ``send_chat_request`` (default keeps the
                previously hard-coded value).
            max_new_tokens: Generation length cap used by ``send_chat_request``
                (default keeps the previously hard-coded value).
        """
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(model_path)
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)  # type: ignore
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens
        logger.info(f"Model loaded on device: {self.model.device}")

    @staticmethod
    def _stitch_horizontally(images):
        """Return None for no images, the image itself for one, else a left-to-right RGB stitch."""
        if not images:
            return None
        if len(images) == 1:
            return images[0]
        canvas = Image.new(
            "RGB",
            (sum(img.width for img in images), max(img.height for img in images)),
        )
        x_offset = 0
        for img in images:
            canvas.paste(img, (x_offset, 0))
            x_offset += img.width
        return canvas

    def convert_openai_to_paligemma(self, messages):
        """Flatten OpenAI-style messages into one prompt string and at most one image.

        System messages are skipped. Text items are joined with single spaces;
        non-string text payloads (e.g. ``content=None``) are ignored. Image items
        are loaded from the path in ``item["image"]``; missing files are logged
        and skipped. Several images are stitched left-to-right onto one RGB
        canvas whose height is that of the tallest image.

        Returns:
            ``(prompt_text, image)`` where ``image`` is a PIL image, or ``None``
            when no image could be loaded.
        """
        text_parts = []
        pil_images = []
        for message in messages:
            if message["role"] == "system":
                continue
            content = message.get("content", [])
            if not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                item_type = item.get("type")
                if item_type == "text":
                    text = item.get("text")
                    # Guard against None/non-str payloads, which would break " ".join below.
                    if isinstance(text, str):
                        text_parts.append(text)
                elif item_type == "image":
                    image_path = item["image"]
                    try:
                        # convert() returns a new image, so the source file can be closed right away.
                        with Image.open(image_path) as raw:
                            pil_images.append(raw.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")
        return " ".join(text_parts), self._stitch_horizontally(pil_images)

    @staticmethod
    def _collect_prompt_and_images(messages):
        """Gather prompt text and RGB images from OpenAI-style messages (all roles).

        Unlike ``convert_openai_to_paligemma``, system messages also contribute
        text, and images are returned as a list instead of being stitched.

        - A string ``content`` is used when it is non-blank; any other non-list
          ``content`` (e.g. ``None``) is ignored.
        - List items of type ``"text"`` contribute only non-blank strings.
        - List items of type ``"image"`` are loaded from the path string in
          ``item["image"]``; missing files are logged and skipped.

        Returns:
            ``(prompt, images)``: the text parts joined by single spaces (so
            separate segments are not fused together) and the list of images.
        """
        text_parts = []
        images = []
        for message in messages:
            content = message.get("content")
            if not isinstance(content, list):
                if isinstance(content, str) and content.strip():
                    text_parts.append(content)
                continue
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    image_path = item.get("image")
                    if image_path and isinstance(image_path, str):
                        try:
                            with Image.open(image_path) as raw:
                                images.append(raw.convert("RGB"))
                        except FileNotFoundError:
                            logger.warning(f"image not found: {image_path}, skiped")
                elif item_type == "text":
                    text = item.get("text")
                    if isinstance(text, str) and text.strip():
                        text_parts.append(text)
        return " ".join(text_parts), images

    def send_chat_request(self, messages):
        """Generate a completion for one OpenAI-style conversation.

        Args:
            messages: list of ``{"role": ..., "content": ...}`` dicts; see
                ``_collect_prompt_and_images`` for how content is interpreted.
                The input is not modified.

        Returns:
            ``(output_text, None, None, None)``. Only the decoded completion is
            produced here; the three ``None`` slots are presumably placeholders
            required by the evaluation harness's agent interface (TODO confirm).

        Raises:
            ValueError: if no image could be loaded. Otherwise the processor would
                receive an empty per-sample image list, which it does not support.
        """
        prompt_input, image_input = self._collect_prompt_and_images(messages)
        if not image_input:
            raise ValueError(
                "send_chat_request: no image could be loaded from messages; "
                "the PaliGemma processor needs at least one image"
            )

        # The nested list assigns all images to the single text sample in this batch.
        inputs = self.processor(text=prompt_input, images=[image_input], return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[1]
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        generate_ids = self.model.generate(
            **inputs, num_beams=self.num_beams, max_new_tokens=self.max_new_tokens
        )

        # Generated sequences start with the prompt tokens; decode only the new ones.
        output_text = self.processor.batch_decode(
            generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text, None, None, None


if __name__ == "__main__":
    # Entry point: evaluate a local ChartGemma checkpoint with the project's EVAL harness.
    # Example invocation:
    #   miniforge3/condabin/conda run -n chartqa --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_chartgemma.py
    # os.environ["CUDA_VISIBLE_DEVICES"] = "3,9"
    model_path = "models/chartgemma"  # NOTE: point this at the checkpoint to evaluate

    # Put the parent of this file's directory on sys.path so `utils.eval` is importable.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath drops a trailing separator, so basename does not come back empty.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path=model_path)
    # Results go under a path relative to the current working directory.
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result/cot", model_name))
    evaluator.run_one_prediction_local()
