import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from loguru import logger
import requests, os, json, copy, sys
from PIL import Image
from vllm import LLM, SamplingParams


class TransformerAgent:
    """Chat agent that runs a vision-language model through a vLLM engine.

    The main entry point is ``send_chat_request(messages)``, which returns a
    4-tuple ``(answer, None, thinking, None)``. The ``None`` slots are
    presumably placeholders expected by the evaluation harness — TODO confirm
    against ``utils.eval.EVAL``.
    """

    def __init__(self, model_path) -> None:
        """Load the processor and a vLLM engine from ``model_path``.

        Args:
            model_path: Model directory or hub id. Both the processor and the
                vLLM engine are loaded with ``trust_remote_code=True``.
        """
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.model = LLM(
            model_path,
            trust_remote_code=True,
            max_num_seqs=8,
            max_model_len=131072,
            limit_mm_per_prompt={"image": 256},  # allow up to 256 images in a single prompt
            # gpu_memory_utilization=0.75,
            # tensor_parallel_size=2,
        )
        self.sampling_params = SamplingParams(max_tokens=32768, temperature=0.8)

    def extract_thinking_and_summary(
        self, text: str, bot: str = "◁think▷", eot: str = "◁/think▷"
    ) -> tuple[str, str]:
        """Split a model completion into ``(thinking, summary)``.

        ``bot`` / ``eot`` are the opening / closing markers of the reasoning
        section. A pair is always returned so callers can unpack it safely.

        * ``eot`` present: ``thinking`` is the text before the first ``eot``
          (starting after the first ``bot`` if one precedes it); ``summary`` is
          the text after that ``eot``. Both are stripped.
        * only ``bot`` present (presumably generation stopped, e.g. at
          ``max_tokens``, before the reasoning was closed): ``thinking`` is the
          stripped text after ``bot`` and ``summary`` is ``""``.
        * neither present: ``("", text)`` with ``text`` returned unchanged.

        Args:
            text: Raw generated text.
            bot: Begin-of-thinking marker.
            eot: End-of-thinking marker.

        Returns:
            A ``(thinking, summary)`` tuple of strings.
        """
        if eot in text:
            thinking, _, summary = text.partition(eot)
            # The opening marker may be missing if the chat template already
            # emitted it as part of the prompt; only strip it when present.
            if bot in thinking:
                thinking = thinking.partition(bot)[2]
            return thinking.strip(), summary.strip()
        if bot in text:
            # Unterminated reasoning: keep the partial trace, there is no answer.
            return text.partition(bot)[2].strip(), ""
        return "", text

    @staticmethod
    def _collect_images(messages) -> list:
        """Load every local image referenced by the messages, in order.

        Only messages whose ``content`` is a list are inspected. Within them,
        items with ``type == "image"`` and a non-empty string ``image`` path are
        opened with PIL. Missing files are logged and skipped.

        Args:
            messages: Chat messages (dicts with an optional ``content`` key).

        Returns:
            A list of fully loaded ``PIL.Image.Image`` objects.
        """
        images = []
        for message in messages:
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for item in content:
                if item.get("type") != "image":
                    continue
                image_path = item.get("image")
                if not (image_path and isinstance(image_path, str)):
                    continue
                try:
                    # Image.open is lazy and keeps the file handle open; copy()
                    # loads the pixel data so the handle can be closed right away.
                    with Image.open(image_path) as img:
                        images.append(img.copy())
                except FileNotFoundError:
                    logger.warning(f"image not found: {image_path}, skiped")
        return images

    def send_chat_request_transformer(self, messages):
        """Generate a reply through a Hugging Face ``generate()``-style model.

        NOTE(review): ``__init__`` stores a vLLM ``LLM`` in ``self.model``. That
        object does not appear to expose ``.device`` or accept HF-style
        ``generate(**inputs, max_new_tokens=...)`` keyword arguments, so this
        path looks unusable unless ``self.model`` is replaced by a transformers
        model (e.g. one loaded via ``AutoModelForCausalLM``). Prefer
        ``send_chat_request``.

        Args:
            messages: Chat messages in the processor's chat-template format.

        Returns:
            ``(output_text, None, None, None)`` with the raw decoded text
            (no thinking/summary split).
        """
        images = self._collect_images(messages)
        text = self.processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
        inputs = self.processor(images=images, text=text, return_tensors="pt", padding=True, truncation=True).to(
            self.model.device
        )
        generated_ids = self.model.generate(**inputs, max_new_tokens=32768, temperature=0.8)
        # Drop the prompt tokens so only the newly generated ones are decoded.
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text, None, None, None

    def send_chat_request(self, messages):
        """Run one chat request through vLLM and separate reasoning from answer.

        Args:
            messages: Chat messages in the processor's chat-template format;
                image items are loaded from their local file paths.

        Returns:
            ``(summary, None, thinking, None)``, where ``summary`` is the text
            after the reasoning block and ``thinking`` is the reasoning text
            (see ``extract_thinking_and_summary``).
        """
        images = self._collect_images(messages)
        text = self.processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

        inputs = [{"prompt": text, "multi_modal_data": {"image": images}}]
        outputs = self.model.generate(inputs, sampling_params=self.sampling_params)
        generated_text = outputs[0].outputs[0].text
        thinking, summary = self.extract_thinking_and_summary(generated_text)
        return summary, None, thinking, None


if __name__ == "__main__":
    # Default to GPU 3, but respect a device selection already made in the
    # environment (e.g. `CUDA_VISIBLE_DEVICES=0 python ...`).
    # NOTE(review): torch and vllm are imported at the top of the file before
    # this line runs; if either initialises CUDA at import time this setting is
    # ignored — consider exporting it before those imports instead.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
    model_path = "models/Kimi-VL-A3B-Thinking-2506"  # NOTE: local checkpoint directory; adjust per machine

    # Example invocation:
    # miniforge3/condabin/conda run -n kimi --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_kimi.py
    # Make the parent of this script's directory importable so `utils.eval` resolves.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    model_name = model_path.split("/")[-1]
    agent = TransformerAgent(model_path)
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
