import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from loguru import logger
import requests, os, json, copy, sys
from PIL import Image

sys.path.append("project/chartqa/env/projects/Ming")
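# The entry above is a relative path, so it is resolved against the current
# working directory; modeling_bailingmm is imported from there.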
from modeling_bailingmm import BailingMMNativeForConditionalGeneration


class TransformerAgent:
    """Chat agent backed by a locally loaded BailingMM multimodal model.

    The model and processor are loaded once in ``__init__``;
    ``send_chat_request`` then turns a list of chat messages into a single
    generated text reply.
    """

    def __init__(self, model_path, device="auto") -> None:
        """Load the model weights and the processor.

        Args:
            model_path: Checkpoint identifier or directory forwarded to
                ``from_pretrained``.
            device: Target device string. ``"auto"`` (the default) selects
                ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise.
        """
        # "auto" is not a device string torch understands, so passing it
        # straight to ``.to()`` raises; resolve it to a concrete device first.
        self.device = self._resolve_device(device)
        self.model = BailingMMNativeForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,  # Use bfloat16 for memory efficiency
            attn_implementation="flash_attention_2",
            load_image_gen=True,
            low_cpu_mem_usage=True,  # Minimize CPU memory during loading
        ).to(self.device)
        # NOTE(review): the processor is read from the current working
        # directory, not from model_path -- confirm the script is always
        # launched from a directory that contains the processor files.
        self.processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)

    @staticmethod
    def _resolve_device(device):
        """Map the ``"auto"`` placeholder to a concrete torch device string."""
        if device != "auto":
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _prepare_messages(messages):
        """Return a deep copy of *messages* reshaped for the processor.

        Content that is not already a list is wrapped as a single text part
        (anything that is not a string becomes an empty text part) and the
        role is left untouched. Messages whose content is already a list get
        the role ``"HUMAN"``.
        """
        prepared = copy.deepcopy(messages)  # never mutate the caller's messages
        for message in prepared:
            content = message.get("content")
            if isinstance(content, list):
                # NOTE(review): every list-content message becomes "HUMAN",
                # whatever its original role, while wrapped messages keep
                # theirs -- confirm this asymmetry is intended.
                message["role"] = "HUMAN"
                continue
            text = content if isinstance(content, str) else ""
            message["content"] = [{"type": "text", "text": text}]
        return prepared

    def send_chat_request(self, messages):
        """Generate one reply for a chat conversation.

        Args:
            messages: List of message dicts with ``"role"`` and ``"content"``
                entries. The list is not modified.

        Returns:
            A 4-tuple ``(output_text, None, None, None)``; only the first
            element carries data.
        """
        messages_for_processor = self._prepare_messages(messages)
        text = self.processor.apply_chat_template(messages_for_processor, add_generation_prompt=True)
        image_inputs, video_inputs, audio_inputs = self.processor.process_vision_info(messages_for_processor)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            audios=audio_inputs,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        # The model is loaded in bfloat16, so cast the media features that are
        # present to the same dtype.
        for key in ("pixel_values", "pixel_values_videos", "audio_feats"):
            if key in inputs:
                inputs[key] = inputs[key].to(dtype=torch.bfloat16)

        generation_config = GenerationConfig.from_dict({"no_repeat_ngram_size": 10})
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=4096,
            use_cache=True,
            eos_token_id=self.processor.gen_terminator,
            generation_config=generation_config,
        )
        # Drop the leading len(input_ids) tokens of each sequence so that only
        # the continuation is decoded.
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text, None, None, None


if __name__ == "__main__":
    # Limit the visible GPUs before the model is loaded below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,9"
    # Local-checkpoint alternative: model_path = " models/Ming-Lite-Omni/"
    model_path = "inclusionAI/Ming-Lite-Omni"

    # NOTE(review): the launch command below names the `kimi` env and
    # eval_kimi.py, presumably copied from a sibling script -- confirm/update.
    # miniforge3/condabin/conda run -n kimi --live-stream python project/chartqa/src/evaluation/chartqa/src/eval_open/eval_kimi.py
    # Put the parent of this script's directory on sys.path so `utils` resolves.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from utils.eval import EVAL

    # basename(normpath(...)) still yields the last path component when the
    # path ends with a slash, where split("/")[-1] would give "".
    model_name = os.path.basename(os.path.normpath(model_path))
    agent = TransformerAgent(model_path)
    # Named `evaluator` so the builtin `eval` is not shadowed.
    evaluator = EVAL(agent, os.path.join("project/chartqa/result", model_name))
    evaluator.run_one_prediction_local()
