import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import base64
import os, sys
from loguru import logger

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
)
from labelstudio.common_prompts import encode_image

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.eval import EVAL


class TransformerAgent:
    """Local Hugging Face inference wrapper around a Qwen2.5-VL checkpoint.

    Loads the model and its processor once in ``__init__`` and then answers
    chat-formatted (optionally multimodal) requests via ``send_chat_request``.
    """

    def __init__(
        self,
        model_path,
        torch_dtype=torch.float32,
        max_pixels=5120 * 28 * 28,
        max_new_tokens=4096,
    ) -> None:
        """Load model weights and processor from ``model_path``.

        Args:
            model_path: Local directory or hub id passed to ``from_pretrained``.
            torch_dtype: Dtype the weights are loaded in. Defaults to
                ``torch.float32`` (the previously hard-coded value);
                ``torch.bfloat16`` halves the weight memory.
            max_pixels: Upper bound on image pixels forwarded to the processor.
                The default ``5120 * 28 * 28`` is the previous hard-coded value
                (presumably ~5120 visual tokens at 28x28 px per token — confirm
                against the Qwen2.5-VL processor docs).
            max_new_tokens: Generation budget used by ``send_chat_request``.
        """
        # Load the Qwen2.5-VL model; device_map="auto" lets transformers/accelerate
        # place the weights on the available (visible) devices.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True
        )
        logger.info(f"Model loaded on device: {self.model.device}")
        self.max_pixels = max_pixels
        self.max_new_tokens = max_new_tokens
        self.processor = AutoProcessor.from_pretrained(model_path, max_pixels=self.max_pixels)

    def send_chat_request(self, messages):
        """Run one generation pass for a single chat conversation.

        Args:
            messages: Chat messages in the format accepted by the processor's
                ``apply_chat_template`` and by ``qwen_vl_utils.process_vision_info``
                (presumably a list of role/content dicts — confirm against callers).

        Returns:
            A 4-tuple ``(text, None, None, None)`` where ``text`` is the decoded
            completion with the prompt tokens removed. NOTE(review): the three
            trailing ``None`` values look like placeholders for the return shape
            the evaluation harness expects — confirm.
        """
        # Render the prompt text and collect any images/videos the messages reference.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's own device instead of a hard-coded "cuda", so this
        # also works on CPU-only hosts and with device_map="auto" placement.
        inputs = inputs.to(self.model.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        # generate() returns prompt + completion; keep only the newly generated tokens.
        generated_ids_trimmed = self._strip_prompt_tokens(inputs.input_ids, generated_ids)
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0], None, None, None

    @staticmethod
    def _strip_prompt_tokens(input_ids, generated_ids):
        """Return each generated sequence with its leading prompt tokens removed."""
        return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


if __name__ == "__main__":
    # Restrict the GPUs this process may use. NOTE: this only takes effect if it
    # runs before CUDA is first initialised, so keep it ahead of model loading.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,9"
    model_path = "models/Awaker2.5-R1"

    # miniforge3/condabin/conda run -n chartqa --live-stream pip install anthropic volcengine-python-sdk[ark] qwen-vl-utils[decord]
    # miniforge3/condabin/conda run -n chartqa --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_awaker.py
    # Results go in a directory named after the checkpoint folder; normpath drops a
    # trailing separator so "models/Awaker2.5-R1/" still yields "Awaker2.5-R1".
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path=model_path)
    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
