import torch
from loguru import logger
from PIL import Image
import requests, os, json, copy, sys

sys.path.append("project/chartqa/env/projects/mPLUG-DocOwl/TinyChart")
from tinychart.model.builder import load_pretrained_model
from tinychart.mm_utils import get_model_name_from_path
from tinychart.eval.run_tiny_chart import inference_model
from tinychart.eval.eval_metric import parse_model_output, evaluate_cmds


class TransformerAgent:
    """Chat-style wrapper around a locally loaded TinyChart model.

    ``messages`` follow an OpenAI-like layout: a list of dicts with a ``"role"``
    and a ``"content"`` that is either a plain string or a list of items such as
    ``{"type": "text", "text": ...}`` or ``{"type": "image", "image": <path>}``.
    """

    def __init__(self, model_path, *, device="cuda", conv_mode="phi", max_new_tokens=4096) -> None:
        """Load the TinyChart tokenizer, model and image processor.

        Args:
            model_path: Path to the pretrained TinyChart checkpoint.
            device: Device forwarded to ``load_pretrained_model`` (use ``"cpu"``
                when running without a GPU).
            conv_mode: Conversation template name forwarded to ``inference_model``.
            max_new_tokens: Generation length limit forwarded to ``inference_model``.
        """
        self.device = device
        self.conv_mode = conv_mode
        self.max_new_tokens = max_new_tokens
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path),
            device=self.device,
        )

    @staticmethod
    def _iter_content_items(message):
        """Return the content items of one message as a list of dicts.

        A string content is wrapped as a single text item; missing or ``None``
        content yields no items (``None`` would otherwise break ``str.join``).
        """
        content = message.get("content")
        if content is None:
            return []
        if isinstance(content, list):
            return content
        return [{"type": "text", "text": content}]

    def _collect_text(self, messages):
        """Join the text items of all non-system messages with single spaces."""
        parts = []
        for message in messages:
            if message.get("role") == "system":
                continue
            for item in self._iter_content_items(message):
                if item.get("type") == "text":
                    parts.append(item["text"])
        return " ".join(parts)

    def _collect_image_paths(self, messages):
        """Return the image paths/URLs referenced by ``messages``, in order.

        Local paths that do not exist are logged and skipped so they are never
        handed to ``inference_model``. Strings starting with ``http://`` or
        ``https://`` are passed through without a local existence check.
        NOTE(review): this assumes ``inference_model`` can fetch remote URLs
        itself, as LLaVA-style image loaders do; confirm.
        """
        paths = []
        for message in messages:
            for item in self._iter_content_items(message):
                if item.get("type") != "image":
                    continue
                image_path = item.get("image")
                if not (image_path and isinstance(image_path, str)):
                    continue
                is_remote = image_path.startswith(("http://", "https://"))
                if not is_remote and not os.path.exists(image_path):
                    logger.warning(f"image not found: {image_path}, skiped")
                    continue
                paths.append(image_path)
        return paths

    @staticmethod
    def _stitch_horizontally(images):
        """Return ``None``, the single image, or all images pasted left-to-right.

        The canvas is as wide as all images combined and as tall as the tallest
        one; area not covered by an image keeps ``Image.new``'s default fill.
        """
        if not images:
            return None
        if len(images) == 1:
            return images[0]
        canvas = Image.new(
            "RGB",
            (sum(img.width for img in images), max(img.height for img in images)),
        )
        x_offset = 0
        for img in images:
            canvas.paste(img, (x_offset, 0))
            x_offset += img.width
        return canvas

    def convert_openai_to_gradio(self, messages):
        """Flatten ``messages`` into a prompt string and at most one image.

        System messages are ignored. Text items are joined with single spaces.
        Image items are opened from disk and converted to RGB; files that are
        not found are logged and skipped. Several images are stitched into one.

        Returns:
            ``(prompt_text, image)`` where ``image`` is a PIL image, or ``None``
            when no image could be loaded.
        """
        pil_images = []
        for message in messages:
            if message.get("role") == "system":
                continue
            for item in self._iter_content_items(message):
                if item.get("type") != "image":
                    continue
                image_path = item["image"]
                try:
                    # Image.open is lazy and keeps the file open; convert() returns
                    # an independent loaded copy, so the handle can be closed here.
                    with Image.open(image_path) as opened:
                        pil_images.append(opened.convert("RGB"))
                except FileNotFoundError:
                    logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")
        return self._collect_text(messages), self._stitch_horizontally(pil_images)

    def send_chat_request(self, messages):
        """Run TinyChart inference on ``messages``.

        The prompt is the space-joined text of all non-system messages; image
        paths are collected from every message and passed to ``inference_model``
        directly (images are not decoded here).

        Returns:
            ``(output_text, None, None, None)`` — presumably the 4-tuple shape
            the evaluation harness expects from every agent; confirm with caller.
        """
        text = self._collect_text(messages)
        image_list = self._collect_image_paths(messages)
        output_text = inference_model(
            image_list,
            text,
            self.model,
            self.tokenizer,
            self.image_processor,
            self.context_len,
            conv_mode=self.conv_mode,
            max_new_tokens=self.max_new_tokens,
        )
        return output_text, None, None, None


if __name__ == "__main__":
    # Restrict visible GPUs before any CUDA work happens. torch is imported at
    # module level, but its CUDA state is initialised lazily, so this still applies.
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    model_path = "models/TinyChart-3B-768/"

    # miniforge3/condabin/conda run -n chartqa --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_tinychart.py
    # Add the parent of this script's directory to sys.path so `utils.eval` resolves.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath drops the trailing "/" so basename gives "TinyChart-3B-768";
    # a plain split("/")[-1] on this path returns "" and loses the model subdir.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path)
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
