import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from loguru import logger
import requests, os, json, copy, sys
from PIL import Image


class TransformerAgent:
    """Chat-style wrapper around a local Hugging Face vision-to-sequence model.

    The model and its processor are loaded once in ``__init__``; each call to
    ``send_chat_request`` renders a chat prompt, runs generation and returns
    the decoded reply.
    """

    def __init__(self, model_path) -> None:
        """Load the model and processor found at ``model_path``.

        The model is loaded in bfloat16 with the ``flash_attention_2``
        attention implementation and moved to the ``"cuda"`` device.

        Args:
            model_path: Path or identifier handed to ``from_pretrained``.
        """
        self.device = "cuda"
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            _attn_implementation="flash_attention_2",
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    @staticmethod
    def _prepare_messages(messages):
        """Return a normalised deep copy of ``messages`` plus the loaded images.

        Every message's ``content`` becomes a list of typed items:

        * non-list content is wrapped as a single text item (an empty string
          when the content is not a ``str``);
        * image items with a non-empty string ``"image"`` value are loaded and
          re-emitted as ``{"type": "image", "url": <path>}``;
        * text items are kept only when their text is a non-blank string;
        * any other item is dropped.

        The caller's ``messages`` object is never mutated.

        Returns:
            ``(prepared_messages, images)`` where ``images`` holds one loaded
            image per image item kept in ``prepared_messages``, in order.
        """
        prepared = copy.deepcopy(messages)
        images = []
        for message in prepared:
            content = message.get("content")
            if not isinstance(content, list):
                text = content if isinstance(content, str) else ""
                message["content"] = [{"type": "text", "text": text}]
                continue

            kept = []
            for item in content:
                if not isinstance(item, dict):
                    # Malformed entry: there is no "type" to dispatch on.
                    continue
                item_type = item.get("type")
                if item_type == "image":
                    image_path = item.get("image")
                    if not (image_path and isinstance(image_path, str)):
                        continue
                    # Load first so that a failed load leaves neither a
                    # placeholder in the prompt nor a hole in ``images``.
                    try:
                        image = load_image(image_path)
                    except (OSError, ValueError):
                        # OSError covers FileNotFoundError and unreadable
                        # files. NOTE(review): load_image appears to report a
                        # missing local path as ValueError -- confirm against
                        # the installed transformers version.
                        logger.warning(f"image not found: {image_path}, skiped")
                        continue
                    kept.append({"type": "image", "url": image_path})
                    images.append(image)
                elif item_type == "text":
                    text = item.get("text")
                    if isinstance(text, str) and text.strip():
                        kept.append(item)
            message["content"] = kept
        return prepared, images

    def send_chat_request(self, messages):
        """Generate a reply for a chat conversation.

        Args:
            messages: List of message dicts. ``content`` may be a plain string
                or a list of ``{"type": "text", "text": ...}`` /
                ``{"type": "image", "image": <path or URL>}`` items.

        Returns:
            A 4-tuple ``(reply_text, None, None, None)``. ``reply_text`` is the
            part of the decoded output after the last ``"Assistant:"`` marker.
        """
        prepared, images = self._prepare_messages(messages)

        # Render the prompt from the *normalised* messages so that the number
        # of image placeholders matches the images actually loaded.
        prompt = self.processor.apply_chat_template(prepared, add_generation_prompt=True)
        # Pass None rather than an empty list for text-only conversations.
        inputs = self.processor(text=prompt, images=images or None, return_tensors="pt")
        inputs = inputs.to(self.device)

        # Generate outputs
        generated_ids = self.model.generate(**inputs, max_new_tokens=4096)
        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]
        # The decoded sequence still contains the prompt; keep only the text
        # that follows the final assistant marker.
        return output_text.split("Assistant:")[-1], None, None, None


if __name__ == "__main__":
    # Expose only GPU index 4 to this process; set before the model is loaded.
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    model_path = "models/SmolVLM-Instruct"

    # Example invocation:
    # miniforge3/condabin/conda run -n qwen --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_smolvlm.py
    # Put the grandparent directory of this file on sys.path so that the
    # project-local `utils.eval` module can be imported when run as a script.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath drops a trailing slash, so "models/Foo/" still yields "Foo"
    # instead of an empty model name.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path)
    # Named `evaluator` to avoid shadowing the builtin `eval`.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
