import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, Glm4vForConditionalGeneration
import base64
import os, sys
from loguru import logger
import requests
import math
import re
import json
import random
import uuid, copy
from PIL import Image, ImageDraw, ImageColor


class TransformerAgentGLM4V:
    """Local chat agent for GLM-4V checkpoints loaded through ``transformers``.

    OpenAI-style messages are flattened into one text prompt plus at most one
    image (several images are stitched side by side), because only a single
    ``image`` entry is handed to the tokenizer's chat template.
    """

    def __init__(self, model_path) -> None:
        """Load processor, tokenizer and model from ``model_path`` onto CUDA.

        Args:
            model_path: Local directory or hub id of the checkpoint. It is loaded
                with ``trust_remote_code=True``, so the checkpoint's own Python
                code is executed.
        """
        self.device = torch.device("cuda")
        # NOTE(review): ``processor`` is not used by the methods of this class.
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
            )
            .to(self.device)
            .eval()
        )

    @staticmethod
    def _stitch_images_horizontally(images):
        """Paste ``images`` left to right onto one new RGB canvas.

        The canvas is as wide as all images together and as tall as the tallest
        one; area not covered by a shorter image keeps ``Image.new``'s default fill.
        """
        total_width = sum(img.width for img in images)
        max_height = max(img.height for img in images)
        canvas = Image.new("RGB", (total_width, max_height))
        x_offset = 0
        for img in images:
            canvas.paste(img, (x_offset, 0))
            x_offset += img.width
        return canvas

    def convert_openai_to_prompt_and_concated_image_and_systemmessage(self, messages):
        """Flatten OpenAI-style ``messages`` into ``(prompt, image, system_message)``.

        Args:
            messages: List of dicts with a ``"role"`` key. ``"content"`` is either a
                plain value (treated as one text part) or a list of parts shaped
                ``{"type": "text", "text": ...}`` or ``{"type": "image", "image": <path>}``.

        Returns:
            A tuple ``(prompt_text, final_image, system_message)``:
            - ``prompt_text``: all non-system text parts joined by single spaces
              (roles and turn boundaries are discarded);
            - ``final_image``: ``None`` when no image loaded, the image itself when
              exactly one loaded, otherwise all images stitched horizontally;
            - ``system_message``: content of the last system message, ``""`` if none.

        Image files that do not exist are logged and skipped.
        """
        text_parts = []
        pil_images = []
        system_message = ""
        for message in messages:
            if message["role"] == "system":
                # A later system message overwrites an earlier one.
                system_message = message.get("content", "")
                continue
            content = message.get("content", [])
            if not isinstance(content, list):
                content = [{"type": "text", "text": content}]
            for item in content:
                if item["type"] == "text":
                    text_parts.append(item["text"])
                elif item["type"] == "image":
                    image_path = item["image"]
                    try:
                        # ``with`` releases the file handle; ``convert`` returns an
                        # independent, fully loaded image.
                        with Image.open(image_path) as opened:
                            pil_images.append(opened.convert("RGB"))
                    except FileNotFoundError:
                        logger.warning(f"Warning: Image file not found at {image_path}. Skipping.")

        if not pil_images:
            final_image = None
        elif len(pil_images) == 1:
            final_image = pil_images[0]
        else:
            final_image = self._stitch_images_horizontally(pil_images)

        prompt_text = " ".join(text_parts)
        return prompt_text, final_image, system_message

    def send_chat_request(self, messages):
        """Generate a reply for OpenAI-style ``messages``.

        The messages are flattened via
        ``convert_openai_to_prompt_and_concated_image_and_systemmessage`` into one
        system turn and one user turn carrying the (possibly stitched) image.

        Returns:
            ``(output_text, None, None, None)`` where ``output_text`` is the newly
            generated text with special tokens removed. The trailing ``None``s
            presumably match the tuple shape the calling harness unpacks — confirm.
        """
        prompt_text, final_image, system_message = self.convert_openai_to_prompt_and_concated_image_and_systemmessage(
            messages
        )
        chat = [
            {"role": "system", "content": system_message},
            {"role": "user", "image": final_image, "content": prompt_text},
        ]

        inputs = self.tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)
        # top_k=1 limits sampling to the single most likely token, so decoding is
        # effectively greedy. max_length bounds prompt + generated tokens together.
        gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
        # Keep only the tokens produced after the prompt.
        new_tokens = outputs[:, inputs["input_ids"].shape[1] :]
        # Strip special tokens (e.g. end-of-sequence markers) so the answer text is
        # clean and consistent with TransformerAgentGLM4_1V_Thinking's output.
        output_text = self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
        return output_text, None, None, None


class TransformerAgentGLM4_1V_Thinking:
    """Local chat agent for GLM-4.1V-Thinking via ``Glm4vForConditionalGeneration``.

    Unlike ``TransformerAgentGLM4V``, the conversation structure is preserved:
    every message is normalized into the list-of-parts form consumed by the
    processor's chat template, and images are passed by path/URL for the
    processor to load.
    """

    def __init__(self, model_path) -> None:
        """Load processor, tokenizer and model from ``model_path``.

        The model is loaded in bfloat16 with ``device_map="auto"`` (weights are
        placed automatically) and SDPA attention.
        """
        # NOTE(review): ``device`` and ``tokenizer`` are not used by send_chat_request.
        self.device = torch.device("cuda")
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = Glm4vForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="sdpa"
        )

    @staticmethod
    def _is_loadable_image_ref(image_ref):
        """Return True if ``image_ref`` is an http(s)/data URI or an existing local file.

        Any other string is treated as a local path. That includes raw base64
        without a ``data:`` prefix, which is therefore rejected.
        """
        if image_ref.startswith(("http://", "https://", "data:")):
            return True
        return os.path.isfile(image_ref)

    @classmethod
    def _normalize_content(cls, content):
        """Convert one message's ``content`` into a list of processor-ready parts.

        - A string becomes a single text part; any other non-list value (e.g.
          ``None``) becomes a single empty text part.
        - In a list, ``{"type": "image", "image": <str>}`` becomes
          ``{"type": "image", "url": <str>}``. Missing local files are logged and
          skipped; non-string or empty image values are dropped silently.
        - Text parts whose text is not a non-blank string are dropped.
        - Parts of any other type are dropped.
        """
        if not isinstance(content, list):
            text = content if isinstance(content, str) else ""
            return [{"type": "text", "text": text}]

        parts = []
        for item in content:
            item_type = item.get("type")
            if item_type == "image":
                image_path = item.get("image")
                if not (image_path and isinstance(image_path, str)):
                    continue
                if cls._is_loadable_image_ref(image_path):
                    parts.append({"type": "image", "url": image_path})
                else:
                    logger.warning(f"image not found: {image_path}, skiped")
            elif item_type == "text":
                text_content = item.get("text")
                if isinstance(text_content, str) and text_content.strip():
                    parts.append(item)
        return parts

    def send_chat_request(self, messages):
        """Generate a reply for OpenAI-style ``messages``.

        ``messages`` is deep-copied before normalization, so the caller's list is
        not mutated. Generation is capped at 4096 new tokens.

        Returns:
            ``(output_text, None, None, None)`` where ``output_text`` is the decoded
            continuation with special tokens skipped.
        """
        messages_for_processor = copy.deepcopy(messages)
        for message in messages_for_processor:
            message["content"] = self._normalize_content(message.get("content"))

        # Move inputs to the model's (first) device instead of a hard-coded "cuda",
        # since device_map="auto" decides where the weights live.
        inputs = self.processor.apply_chat_template(
            messages_for_processor, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ).to(self.model.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=4096)
        # Drop the prompt tokens from each returned sequence.
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0], None, None, None


if __name__ == "__main__":
    # os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # uncomment to pin the run to one GPU
    model_path = "models/GLM-4.1V-9B-Thinking"  # NOTE: checkpoint to evaluate

    # Example invocation:
    # miniforge3/condabin/conda run -n glm --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_glm.py
    # Make the parent of this file's directory importable so ``utils.eval`` resolves.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # normpath drops a trailing separator, so the basename is never empty.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgentGLM4_1V_Thinking(model_path)
    # Named ``evaluator`` so the ``eval`` builtin is not shadowed.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result/cot", model_name))
    evaluator.run_one_prediction_local()
