import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import base64
import os, sys
from loguru import logger
import requests
import math
import re
import json
import random
import uuid, copy
from PIL import Image, ImageDraw, ImageColor


class TransformerAgentInternVL:
    """Chat agent that runs an InternVL checkpoint locally through ``transformers``.

    ``send_chat_request`` accepts chat messages whose ``content`` is either a plain
    value or a list of ``{"type": "text", "text": ...}`` /
    ``{"type": "image", "image": <str>}`` items. It rewrites them into the shape
    expected by the processor's chat template, runs generation, and returns the
    decoded reply.
    """

    # Class-level fallback so instances created without ``__init__`` still work.
    max_new_tokens: int = 4096

    def __init__(self, model_path, max_new_tokens: int = 4096) -> None:
        """Load the processor and model from ``model_path``.

        Args:
            model_path: Model directory or hub id passed to ``from_pretrained``.
            max_new_tokens: Upper bound on tokens generated per request
                (defaults to 4096, the previously hard-coded value).
        """
        # Used as ``device_map`` below, not as a single torch device.
        self.device = "auto"
        self.max_new_tokens = max_new_tokens
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path, device_map=self.device, torch_dtype=torch.bfloat16, trust_remote_code=True
        )

    @staticmethod
    def _to_processor_messages(messages):
        """Return a deep copy of ``messages`` rewritten for ``apply_chat_template``.

        - Non-list ``content``: a string becomes a single text item. Anything
          else (including a missing key) becomes a single empty text item.
        - List ``content``: image items with a non-empty string ``image`` become
          ``{"type": "image", "url": <value>}``. Text items are kept only when
          their ``text`` is a non-blank string. All other items are dropped.

        The caller's ``messages`` are never modified.
        """
        converted = copy.deepcopy(messages)
        for message in converted:
            content = message.get("content")
            if not isinstance(content, list):
                text = content if isinstance(content, str) else ""
                message["content"] = [{"type": "text", "text": text}]
                continue
            new_content = []
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    image_path = item.get("image")
                    # Path/URL is handed to the processor as-is; a bad path surfaces
                    # as an error from the processor, not here.
                    if image_path and isinstance(image_path, str):
                        new_content.append({"type": "image", "url": image_path})
                elif item_type == "text":
                    text_content = item.get("text")
                    if isinstance(text_content, str) and text_content.strip():
                        new_content.append(item)
            message["content"] = new_content
        return converted

    def send_chat_request(self, messages):
        """Generate a reply for ``messages`` and return it.

        Args:
            messages: List of chat messages (see ``_to_processor_messages`` for
                the accepted content formats). Not modified.

        Returns:
            A 4-tuple ``(reply_text, None, None, None)``. Only the first element
            is populated.
        """
        messages_for_processor = self._to_processor_messages(messages)
        inputs = self.processor.apply_chat_template(
            messages_for_processor,
            padding=True,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=torch.bfloat16)

        output = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)

        # generate() returns prompt + completion. Decode only the new tokens
        # instead of splitting on "assistant\n", which truncated replies that
        # themselves contained that text.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded_outputs = self.processor.batch_decode(output[:, prompt_len:], skip_special_tokens=True)
        output_text = decoded_outputs[0]
        return output_text, None, None, None


if __name__ == "__main__":
    # Checkpoint to evaluate; swap the commented line to run the 38B variant.
    model_path = "models/InternVL3-14B-hf"  # NOTE
    # model_path = "models/InternVL3-38B-hf"  # NOTE
    # Put the parent of this file's directory on sys.path so the
    # ``utils.eval`` import below can be resolved when run as a script.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    model_name = model_path.split("/")[-1]
    agent = TransformerAgentInternVL(model_path)
    # NOTE(review): second argument looks like the results directory
    # (project/chartqa/result/cot/<model_name>_new) — confirm against utils.eval.EVAL.
    # Named ``evaluator`` to avoid shadowing the ``eval`` builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result/cot", model_name + "_new"))
    evaluator.run_one_prediction_local()
