import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, PaliGemmaForConditionalGeneration
import base64
import os, sys
from loguru import logger
from PIL import Image
import requests
from openai import OpenAI

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from labelstudio.common_prompts import encode_image
import os
import sys
import math
import re
import json
import random
import uuid, copy
from PIL import Image, ImageDraw, ImageColor

import torch
from transformers import AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration


class TransformerAgent:
    """Local Qwen2-VL inference wrapper exposing a chat-style ``send_chat_request``.

    The processor and model are loaded from ``model_path`` with
    ``trust_remote_code=True``; weights are loaded in bfloat16 and placed with
    ``device_map="auto"``.
    """

    def __init__(self, model_path) -> None:
        """Load the processor and the Qwen2-VL model.

        Args:
            model_path: Directory or model id accepted by ``from_pretrained``.
        """
        # Not used inside this class; weight placement is decided by
        # ``device_map="auto"`` below. Kept in case external code reads it.
        self.device = torch.device("cuda")
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    def _prepare_messages(self, messages):
        """Return a normalized deep copy of ``messages`` for ``apply_chat_template``.

        Rules applied to each message's ``content``:

        - Not a list: a string becomes one text item; any other value
          (e.g. ``None``) becomes one empty text item.
        - A list: ``image`` items whose ``image`` is a non-empty string path are
          rewritten to ``{"type": "image", "url": path}`` when PIL can open the
          file (missing files are logged and skipped); ``text`` items are kept
          only when their text is a non-blank string; any other item is dropped.

        The caller's ``messages`` are not mutated.
        """
        prepared = copy.deepcopy(messages)
        for message in prepared:
            content = message.get("content")
            if not isinstance(content, list):
                text = content if isinstance(content, str) else ""
                message["content"] = [{"type": "text", "text": text}]
                continue

            new_content = []
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    image_path = item.get("image")
                    if image_path and isinstance(image_path, str):
                        try:
                            # Open only to validate the file; the ``with`` block
                            # closes the handle immediately so no file stays open
                            # while the model generates.
                            with Image.open(image_path):
                                pass
                        except FileNotFoundError:
                            logger.warning(f"image not found: {image_path}, skiped")
                            continue
                        new_content.append({"type": "image", "url": image_path})
                elif item_type == "text":
                    text_content = item.get("text")
                    if isinstance(text_content, str) and text_content.strip():
                        new_content.append(item)
            message["content"] = new_content
        return prepared

    def send_chat_request(self, messages, *, max_new_tokens=4096):
        """Run one chat completion on the local model.

        Args:
            messages: Chat messages; each is a dict whose ``content`` is either a
                string or a list of ``{"type": "text"|"image", ...}`` items.
            max_new_tokens: Maximum number of tokens to generate (default 4096).

        Returns:
            ``(text, None, None, None)`` where ``text`` is the decoded reply for
            the first sequence. NOTE(review): the three ``None`` slots look like
            placeholders matching another agent's return shape — confirm with
            the evaluator.
        """
        prepared = self._prepare_messages(messages)
        inputs = self.processor.apply_chat_template(
            prepared, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ).to(self.model.device)
        output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Each returned row starts with its prompt tokens; keep only the continuation.
        generated_ids = [
            full_row[len(prompt_row) :] for prompt_row, full_row in zip(inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return output_text[0], None, None, None


if __name__ == "__main__":
    # Default GPU selection for the 72B model; a value already exported by the
    # caller is respected. Set before any CUDA work (the model is loaded below).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2")
    model_path = "models/ChartSketcher-72B"  # NOTE

    # miniforge3/condabin/conda run -n chartqa --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_chartsketcher.py
    # Add the directory two levels above this file to sys.path so that
    # ``utils.eval`` can be imported.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath strips a trailing separator, so "models/X/" still yields "X"
    # rather than an empty name (which would write into the result root).
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path)
    # Named ``evaluator`` to avoid shadowing the builtin ``eval``.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
