import torch
from transformers import AutoProcessor, AutoModelForSeq2SeqLM, LlavaForConditionalGeneration
import base64
import os, sys, copy
from loguru import logger
from PIL import Image
import requests

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from labelstudio.common_prompts import encode_image


class TransformerAgent:
    """Local agent wrapping a LLaVA-style ``transformers`` checkpoint.

    Chat messages arrive as a list of ``{"role": ..., "content": ...}`` dicts.
    They are flattened into a single prompt string plus at most one image before
    generation.
    """

    def __init__(self, model_path, num_beams=4, max_new_tokens=4096) -> None:
        """Load the model and its processor from ``model_path``.

        Args:
            model_path: Path or hub id accepted by ``from_pretrained``.
            num_beams: Beam width passed to ``generate`` (default 4, as before).
            max_new_tokens: Generation length cap passed to ``generate``
                (default 4096, as before).
        """
        # NOTE(review): not referenced anywhere in this class. System messages are
        # dropped in convert_openai_to_llava. Kept as a public attribute in case
        # external code reads it.
        self.SYSTEM_PROMPT = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        # device_map="auto" delegates weight placement, so there is no manual .to(device).
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_path)
        logger.info(f"Model loaded on device: {self.model.device}")

    def convert_openai_to_llava(self, messages):
        """Flatten chat messages into one prompt string and at most one image.

        Rules:
        - System messages are skipped, as are messages whose ``content`` is
          missing or ``None``.
        - ``content`` may be a plain string or a list of parts.
        - ``{"type": "text", "text": ...}`` parts are joined with single spaces.
        - ``{"type": "image", "image": <path>}`` parts are loaded as RGB.
        - Missing image files are logged and skipped.
        - Several images are stitched left-to-right into one canvas.

        Returns:
            ``(prompt_text, image)``, where ``image`` is a PIL image or ``None``.
        """
        text_parts = []
        pil_images = []
        for message in messages:
            if message["role"] == "system":
                continue
            content = message.get("content")
            if content is None:
                # Nothing to contribute; wrapping None as text would break the join below.
                continue
            if not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                if item["type"] == "text":
                    text_parts.append(item["text"])
                elif item["type"] == "image":
                    image_path = item["image"]
                    try:
                        # The context manager closes the file handle. convert()
                        # returns a loaded copy that stays valid after the close.
                        with Image.open(image_path) as img:
                            pil_images.append(img.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")
        return " ".join(text_parts), self._stitch_horizontally(pil_images)

    @staticmethod
    def _stitch_horizontally(images):
        """Combine a list of images into at most one image.

        Returns ``None`` for an empty list and the sole image for a one-item list.
        Otherwise returns a side-by-side RGB canvas: images are pasted left to
        right and top-aligned, and the area below shorter images keeps the
        canvas fill.
        """
        if not images:
            return None
        if len(images) == 1:
            return images[0]
        canvas = Image.new("RGB", (sum(im.width for im in images), max(im.height for im in images)))
        x_offset = 0
        for im in images:
            canvas.paste(im, (x_offset, 0))
            x_offset += im.width
        return canvas

    @staticmethod
    def _build_prompt(prompt_text, has_image):
        """Prefix the ``<image>`` placeholder only when an image is actually supplied."""
        return f"<image>\n{prompt_text}" if has_image else prompt_text

    def send_chat_request(self, messages):
        """Generate a completion for one chat and return the decoded text.

        Returns:
            ``(output_text, None, None, None)``. The trailing ``None`` values
            presumably fill slots of the interface the evaluation harness
            expects; TODO confirm.
        """
        prompt_text, image = self.convert_openai_to_llava(messages)
        has_image = image is not None
        # Only emit the image placeholder when pixel values accompany it.
        # Previously "<image>\n" was prepended to text-only prompts as well.
        prompt_input = self._build_prompt(prompt_text, has_image)

        processor_kwargs = {"text": prompt_input, "return_tensors": "pt"}
        if has_image:
            processor_kwargs["images"] = image
        # Moves the batch to the model's device. Per transformers' BatchFeature.to,
        # only floating-point tensors are cast to float16, so token ids keep
        # their integer dtype.
        inputs = self.processor(**processor_kwargs).to(self.model.device, torch.float16)

        # Slice off the prompt tokens so only newly generated text is decoded.
        prompt_length = inputs["input_ids"].shape[1]
        generate_ids = self.model.generate(
            **inputs, num_beams=self.num_beams, max_new_tokens=self.max_new_tokens
        )
        output_text = self.processor.batch_decode(
            generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text, None, None, None


if __name__ == "__main__":
    # Local evaluation entry point. The model and result paths below are
    # relative, so they resolve against the current working directory.
    # Example invocation:
    # miniforge3/condabin/conda run -n chartinstruct --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_chartinstruct.py
    # os.environ["CUDA_VISIBLE_DEVICES"] = "2,9"
    # model_path = "models/ChartInstruct-FlanT5-XL"  # NOTE
    model_path = "models/ChartInstruct-LLama2"  # NOTE

    # Make the parent of this file's directory importable
    # (presumably where the `utils` package lives).
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath strips a trailing slash so basename still yields the directory name.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path=model_path)
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result/cot", model_name))
    evaluator.run_one_prediction_local()
