import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from loguru import logger
import requests, os, json, copy, sys
from PIL import Image
from qwen_vl_utils import process_vision_info


class TransformerAgent:
    """Chat agent backed by a locally loaded Qwen2.5-VL model.

    The model and its processor are loaded once in ``__init__``; each call to
    :meth:`send_chat_request` runs a single generation for one conversation.
    """

    def __init__(self, model_path) -> None:
        """Load the model and processor from ``model_path``.

        Args:
            model_path: Path or identifier handed to ``from_pretrained`` for
                both the model and the processor.
        """
        # Forwarded to ``from_pretrained`` as ``device_map``.
        self.device = "auto"
        # NOTE: the attention backend is left at the library default; the
        # original author noted "flash_attention_2" and "sdpa" as alternatives.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype="auto", device_map=self.device
        )
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    def send_chat_request(self, messages, *, max_new_tokens=4096):
        """Generate the model's reply to a conversation.

        Args:
            messages: Conversation passed to ``processor.apply_chat_template``
                and ``process_vision_info``. It is not modified by this
                method itself.
            max_new_tokens: Upper bound on the number of tokens generated.
                Defaults to 4096, the value that used to be hard-coded.

        Returns:
            A 4-tuple ``(text, None, None, None)`` where ``text`` is the
            decoded reply for the single prompt in the batch.
            NOTE(review): the three ``None`` slots presumably keep the return
            shape aligned with other agents used by the caller -- confirm.
        """
        prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move the batch to wherever the model's (first) parameters live.
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Each generated sequence starts with its prompt tokens; slice them
        # off so only the newly generated tokens are decoded.
        completion_ids = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            completion_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0], None, None, None


if __name__ == "__main__":
    # To restrict the GPUs used, set CUDA_VISIBLE_DEVICES before the model is
    # loaded, e.g.:
    #   os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5"
    model_path = "models/Qwen2.5-VL-32B-Instruct"
    # Alternative checkpoint:
    #   model_path = "models/Qwen2.5-VL-7B-Instruct"

    # Example invocations:
    #   miniforge3/condabin/conda run -n qwen --live-stream CUDA_VISIBLE_DEVICES=2,3,4,5 python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_qwen.py
    #   miniforge3/condabin/conda run -n chartqa --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_internvl.py && miniforge3/condabin/conda run -n qwen --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_qwen.py

    # Put the parent of this file's directory on sys.path so the ``utils``
    # package can be imported when the file is run as a script.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # Last path component names the result directory; normpath first so a
    # trailing slash on model_path does not yield an empty name.
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path)
    # Named ``evaluator`` rather than ``eval`` to avoid shadowing the builtin.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result/cot", model_name))
    evaluator.run_one_prediction_local()
