import re
import os
import io
import tempfile
import json
import copy
import time
import torch
import argparse

from tqdm import tqdm
from PIL import Image
from glob import glob
from torchvision import transforms
from torchvision.transforms import ToPILImage

from PIL import Image
import os
import io
import tempfile
import base64

def encode_image(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        str: The base64 encoding of the file contents, decoded as UTF-8.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload)
    return encoded.decode("utf-8")

def convert_pil_to_path(pil_image):
    """Write *pil_image* to a new temporary ``.jpeg`` file and return its path.

    JPEG has no alpha channel or palette, so an image in any mode other than
    ``RGB`` or ``L`` (e.g. ``RGBA`` or ``P``) is converted to ``RGB`` before
    saving. Without this conversion Pillow's JPEG writer raises for such input.

    The temporary file is created with ``delete=False``. The caller owns it
    and should remove it when finished.

    Args:
        pil_image: A PIL image (anything exposing ``mode``, ``convert`` and
            ``save(fp, format=...)``).

    Returns:
        str: Filesystem path of the temporary JPEG file.
    """
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    # Encode in memory first so a failed encode leaves no stray temp file.
    buffer = io.BytesIO()
    pil_image.save(buffer, format='JPEG')
    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpeg') as temp_file:
        temp_file.write(buffer.getvalue())
    return temp_file.name

def load_llava(args, device):
    """Load a vanilla LLaVA checkpoint via ``load_pretrained_model``.

    Args:
        args: Namespace providing ``model`` (display name, only printed) and
            ``model_path`` (checkpoint location).
        device: Not used. Device placement is left to
            ``load_pretrained_model``'s own defaults.

    Returns:
        dict consumed by the vanilla-LLaVA branch of ``inference``. It holds
        the model, tokenizer, image processor, model name and image-token
        constants (``image_token_se`` is START + IMAGE + END token). It also
        holds the LLaVA helpers ``conv_templates``, ``tokenizer_image_token``
        and ``process_images``.
    """
    print(f'loading {args.model} model')
    from llava.model.builder import load_pretrained_model
    from llava.conversation import conv_templates
    from llava.mm_utils import (
        process_images,
        tokenizer_image_token,
        get_model_name_from_path,
    )
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        IMAGE_PLACEHOLDER,
    )
    # Derive the name once; it is both passed to the builder and returned.
    model_name = get_model_name_from_path(args.model_path)
    # model_base=None: the checkpoint at model_path is loaded on its own.
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        args.model_path, None, model_name)
    return {"model": model,
            "tokenizer": tokenizer,
            "image_processor": image_processor,
            "model_name": model_name,
            "image_token_se": DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN,
            "IMAGE_PLACEHOLDER": IMAGE_PLACEHOLDER,
            "DEFAULT_IMAGE_TOKEN": DEFAULT_IMAGE_TOKEN,
            "conv_templates": conv_templates,
            "tokenizer_image_token": tokenizer_image_token,
            "IMAGE_TOKEN_INDEX": IMAGE_TOKEN_INDEX,
            "process_images": process_images,
            }
    
def load_llava_RLHF(args, device):
    """Load LLaVA-RLHF, an SFT LLaVA-LLaMA base with a LoRA adapter on top.

    Base weights and tokenizer come from ``<args.model_path>/sft_model``. The
    PEFT adapter comes from ``<args.model_path>/rlhf_lora_adapter_model``.

    Args:
        args: Namespace providing ``model`` (display name) and ``model_path``.
        device: Passed as ``device_map`` when loading the base model.

    Returns:
        dict consumed by the RLHF/OPA branch of ``inference``. It holds the
        model, tokenizer and image processor, the ``mm_use_im_*`` flags and
        image-token constants, ``conv_templates``, ``IMAGE_TOKEN_INDEX`` and
        ``tokenizer_image_token``.
    """
    print(f'loading {args.model} model')
    from llava.model import LlavaLlamaForCausalLM
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IMAGE_PATCH_TOKEN,
    )
    from llava.conversation import conv_templates
    from llava.mm_utils import tokenizer_image_token
    from peft import PeftModel
    from transformers import AutoTokenizer, BitsAndBytesConfig

    sft_path = os.path.join(args.model_path, 'sft_model')
    bits = 16
    dtype = torch.bfloat16
    compute_dtype = torch.bfloat16
    # NOTE(review): with bits == 16 both load_in_4bit and load_in_8bit are
    # False. Confirm that the installed transformers version treats this
    # non-quantizing BitsAndBytesConfig as "no quantization".
    model = LlavaLlamaForCausalLM.from_pretrained(
        sft_path,
        device_map=device,
        torch_dtype=dtype,
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=(bits == 4),
            load_in_8bit=(bits == 8),
            llm_int8_threshold=6.0,
            llm_int8_skip_modules=["mm_projector", "lm_head"],
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ),
    )
    model = PeftModel.from_pretrained(
        model,
        os.path.join(args.model_path, 'rlhf_lora_adapter_model')
    )
    tokenizer = AutoTokenizer.from_pretrained(sft_path, use_fast=False)

    # Register the image tokens the checkpoint was trained with, then grow
    # the embedding matrix to match the enlarged vocabulary.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    # NOTE(review): hard-coded to "cuda" rather than `device`; inference()
    # also moves inputs with .cuda(), so the two agree only on the default GPU.
    vision_tower.to(device="cuda", dtype=compute_dtype)
    image_processor = vision_tower.image_processor
    return {"model": model,
            "tokenizer": tokenizer,
            "image_processor": image_processor,
            "mm_use_im_start_end": mm_use_im_start_end,
            "mm_use_im_patch_token": mm_use_im_patch_token,
            "DEFAULT_IM_START_TOKEN": DEFAULT_IM_START_TOKEN,
            "DEFAULT_IM_END_TOKEN": DEFAULT_IM_END_TOKEN,
            "DEFAULT_IMAGE_TOKEN": DEFAULT_IMAGE_TOKEN,
            "conv_templates": conv_templates,
            "IMAGE_TOKEN_INDEX": IMAGE_TOKEN_INDEX,
            "tokenizer_image_token": tokenizer_image_token
            }
    
def load_llava_RLAIF(args, device):
    """Wrap an RLAIF-V checkpoint in its ``RLAIFVChat`` helper.

    ``device`` is unused. The helper is built from ``args.model_path`` alone.

    Returns:
        dict with a single ``"model"`` entry. ``inference`` calls that
        object's ``chat`` method for this model family.
    """
    from llava.rlaifc_chat import RLAIFVChat
    chat_wrapper = RLAIFVChat(args.model_path)
    return dict(model=chat_wrapper)

def load_llava_HALVA(args, device):
    """Load a HALVA checkpoint with the ``llava_halva`` builder.

    Args:
        args: Namespace providing ``model_path`` and ``model_base_path``
            (forwarded to ``load_pretrained_model`` as the base model).
        device: Not used.

    Returns:
        dict consumed by the HALVA branch of ``inference``.
    """
    from llava_halva.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava_halva.conversation import conv_templates, SeparatorStyle
    from llava_halva.utils import disable_torch_init
    from llava_halva.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
    from llava_halva.model.builder import load_pretrained_model

    checkpoint_name = get_model_name_from_path(args.model_path)
    loaded = load_pretrained_model(args.model_path, args.model_base_path, checkpoint_name)
    tok, net, img_proc, _ctx_len = loaded
    return dict(
        model=net,
        tokenizer=tok,
        image_processor=img_proc,
        model_name=checkpoint_name,
        DEFAULT_IM_START_TOKEN=DEFAULT_IM_START_TOKEN,
        DEFAULT_IMAGE_TOKEN=DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_END_TOKEN=DEFAULT_IM_END_TOKEN,
        IMAGE_TOKEN_INDEX=IMAGE_TOKEN_INDEX,
        tokenizer_image_token=tokenizer_image_token,
        process_images=process_images,
        conv_templates=conv_templates,
    )
    
def load_llava_OPA(args, device):
    """Load LLaVA-OPA, a LLaVA-LLaMA base model with a LoRA adapter on top.

    Base weights and tokenizer come from ``args.model_base_path``. The PEFT
    adapter comes from ``args.model_path``.

    Args:
        args: Namespace providing ``model`` (display name), ``model_path``
            (adapter) and ``model_base_path`` (base checkpoint).
        device: Passed as ``device_map`` when loading the base model.

    Returns:
        dict consumed by the RLHF/OPA branch of ``inference``. It holds the
        model, tokenizer and image processor, the ``mm_use_im_*`` flags and
        image-token constants, ``conv_templates``, ``IMAGE_TOKEN_INDEX`` and
        ``tokenizer_image_token``.
    """
    print(f'loading {args.model} model')
    from llava.model import LlavaLlamaForCausalLM
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IMAGE_PATCH_TOKEN,
    )
    from llava.conversation import conv_templates
    from llava.mm_utils import tokenizer_image_token
    from peft import PeftModel
    from transformers import AutoTokenizer, BitsAndBytesConfig

    bits = 16
    dtype = torch.bfloat16
    compute_dtype = torch.bfloat16
    # NOTE(review): with bits == 16 both load_in_4bit and load_in_8bit are
    # False. Confirm that the installed transformers version treats this
    # non-quantizing BitsAndBytesConfig as "no quantization".
    model = LlavaLlamaForCausalLM.from_pretrained(
        args.model_base_path,
        device_map=device,
        torch_dtype=dtype,
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=(bits == 4),
            load_in_8bit=(bits == 8),
            llm_int8_threshold=6.0,
            llm_int8_skip_modules=["mm_projector", "lm_head"],
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ),
    )
    model = PeftModel.from_pretrained(
        model,
        args.model_path
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_base_path, use_fast=False)

    # Register the image tokens the checkpoint was trained with, then grow
    # the embedding matrix to match the enlarged vocabulary.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    # NOTE(review): hard-coded to "cuda" rather than `device`; inference()
    # also moves inputs with .cuda(), so the two agree only on the default GPU.
    vision_tower.to(device="cuda", dtype=compute_dtype)
    image_processor = vision_tower.image_processor
    return {"model": model,
            "tokenizer": tokenizer,
            "image_processor": image_processor,
            "mm_use_im_start_end": mm_use_im_start_end,
            "mm_use_im_patch_token": mm_use_im_patch_token,
            "DEFAULT_IM_START_TOKEN": DEFAULT_IM_START_TOKEN,
            "DEFAULT_IM_END_TOKEN": DEFAULT_IM_END_TOKEN,
            "DEFAULT_IMAGE_TOKEN": DEFAULT_IMAGE_TOKEN,
            "conv_templates": conv_templates,
            "IMAGE_TOKEN_INDEX": IMAGE_TOKEN_INDEX,
            "tokenizer_image_token": tokenizer_image_token
            }
    
def load_qwen_vl2(args, device):
    """Load a Qwen2-VL checkpoint together with its processor.

    Note: the ``"qwen"`` key of ``MODEL_LOADERS`` points at
    ``load_qwen_vl2_5``, so this loader is only reached if called directly.

    Args:
        args: Namespace providing ``model`` (display name) and ``model_path``.
        device: Forwarded as ``device_map``.

    Returns:
        dict with ``model``, ``processor`` and the ``process_vision_info``
        helper used by the Qwen branch of ``inference``.
    """
    print(f'loading {args.model} model')
    from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
    from qwen_vl_utils import process_vision_info
    # attn_implementation="flash_attention_2" (and torch_dtype=torch.bfloat16)
    # can be added here to speed up and slim down multi-image/video inference.
    load_kwargs = dict(torch_dtype="auto", device_map=device)
    vlm = Qwen2VLForConditionalGeneration.from_pretrained(args.model_path, **load_kwargs)
    vlm_processor = AutoProcessor.from_pretrained(args.model_path)
    return dict(model=vlm,
                processor=vlm_processor,
                process_vision_info=process_vision_info)
    
def load_qwen_vl2_5(args, device):
    """Load a Qwen2.5-VL checkpoint together with its processor.

    ``device`` is ignored here: the model is loaded with
    ``device_map="auto"``.

    Args:
        args: Namespace providing ``model`` (display name) and ``model_path``.
        device: Unused.

    Returns:
        dict with ``model``, ``processor`` and the ``process_vision_info``
        helper used by the Qwen branch of ``inference``.
    """
    print(f'loading {args.model} model')
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
    from qwen_vl_utils import process_vision_info
    # attn_implementation="flash_attention_2" can be added here to speed up
    # and slim down multi-image/video inference.
    load_kwargs = dict(torch_dtype="auto", device_map="auto")
    vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.model_path, **load_kwargs)
    vlm_processor = AutoProcessor.from_pretrained(args.model_path)
    return dict(model=vlm,
                processor=vlm_processor,
                process_vision_info=process_vision_info)

def load_intern(args, device):
    """Create an lmdeploy TurboMind pipeline for an InternVL checkpoint.

    Args:
        args: Namespace providing ``model`` (display name) and ``model_path``.
        device: Unused; lmdeploy manages placement itself.

    Returns:
        dict with the ``pipeline`` callable and lmdeploy's ``load_image``.
    """
    print(f'loading {args.model} model')
    from lmdeploy import pipeline, TurbomindEngineConfig
    from lmdeploy.vl import load_image
    engine_config = TurbomindEngineConfig(session_len=8192)
    vlm_pipe = pipeline(args.model_path, backend_config=engine_config)
    return dict(pipeline=vlm_pipe, load_image=load_image)
    
def load_deepseek(args, device):
    """Load DeepSeek-VL2 (processor + bfloat16 model on the default GPU).

    Args:
        args: Namespace providing ``model`` (display name) and ``model_path``.
        device: Unused; the model is moved with ``.cuda()``.

    Returns:
        dict with ``vl_gpt`` (model in eval mode), ``tokenizer``,
        ``vl_chat_processor`` and the ``load_pil_images`` helper.
    """
    print(f'loading {args.model} model')
    from transformers import AutoModelForCausalLM
    from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
    from deepseek_vl2.utils.io import load_pil_images
    chat_processor = DeepseekVLV2Processor.from_pretrained(args.model_path)
    text_tokenizer = chat_processor.tokenizer
    # trust_remote_code: the checkpoint ships its own modelling code.
    network = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
    network = network.to(torch.bfloat16).cuda().eval()
    return dict(vl_gpt=network,
                tokenizer=text_tokenizer,
                vl_chat_processor=chat_processor,
                load_pil_images=load_pil_images)
    
def load_gpt_4v(args, device):
    """Return the OpenAI settings used by the GPT branch of ``inference``.

    The API key is read from the ``OPENAI_API_KEY`` environment variable so
    no secret has to live in source. When the variable is unset, the old
    placeholder string is returned; it is not a real key.

    Args:
        args: Namespace providing ``model`` (display name, only printed).
        device: Unused.

    Returns:
        dict with ``api_key`` and ``gpt_model`` (``"gpt-4o"``).
    """
    print(f'loading {args.model} model')
    return {"api_key": os.environ.get("OPENAI_API_KEY", "XXXXXXXXXXXXXXX"),
            "gpt_model": "gpt-4o"}

# Substring -> loader table used by ``load_model``.
# - Matching is first-hit, in insertion order, against ``args.model.lower()``.
#   The specific "llava-*" variants must therefore stay ahead of the generic
#   "llava" key.
# - "qwen" maps to the Qwen2.5-VL loader; ``load_qwen_vl2`` is not registered.
# - "gpt" is a catch-all placed after "gpt_4". It lets names such as "gpt-4o",
#   which the GPT branch of ``inference`` already accepts, load as well.
MODEL_LOADERS = {
    "llava-opa": load_llava_OPA,
    "llava-rlhf": load_llava_RLHF,
    "llava-rlaif": load_llava_RLAIF,
    "llava-halva": load_llava_HALVA,
    "llava": load_llava,
    "qwen": load_qwen_vl2_5,
    "deepseek": load_deepseek,
    "intern": load_intern,
    "gpt_4": load_gpt_4v,
    "gpt": load_gpt_4v,
}

def load_model(args, device):
    """Pick a loader from ``MODEL_LOADERS`` by substring match and run it.

    The first key (in dict order) that occurs in ``args.model.lower()`` wins.

    Args:
        args: Namespace whose ``model`` attribute names the model.
        device: Forwarded unchanged to the chosen loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If no key of ``MODEL_LOADERS`` matches.
    """
    lowered = args.model.lower()
    loader = next(
        (fn for pattern, fn in MODEL_LOADERS.items() if pattern in lowered),
        None,
    )
    if loader is None:
        raise ValueError(f"Unsupported model: {args.model}")
    return loader(args, device)

def inference(model_data, args, logger, image, prompt=None, device=None):
    """Run one image + text query through the loaded model.

    Dispatches on substrings of ``args.model.lower()``, checked in this order:

    1. ``llava``. Within it: ``rlhf``/``opa`` (LoRA variants), then
       ``rlaif``, then ``halva``; anything else is vanilla LLaVA.
    2. ``gpt``
    3. ``qwen``
    4. ``deepseek``
    5. ``intern``

    Args:
        model_data: Dict returned by the matching ``load_*`` function.
        args: Namespace; ``args.model`` selects the backend.
        logger: Logger that receives per-image failure messages.
        image: The input image. The LLaVA paths read ``image.size`` and the
            GPT/DeepSeek paths pass it to ``convert_pil_to_path``, so a PIL
            image is expected.
        prompt: Text prompt. ``None`` is treated as an empty string.
        device: Device the Qwen inputs are moved to; other backends ignore it.

    Returns:
        The decoded response string. Returns ``None`` if the generation step
        raised; that error is printed and logged, not re-raised. Errors during
        input preparation (tokenization, preprocessing) propagate.

    Raises:
        ValueError: If ``args.model`` matches no supported backend.
    """
    if prompt is None:
        prompt = ""
    name = args.model.lower()
    if "llava" in name:
        if "rlhf" in name or "opa" in name:
            return _infer_llava_lora(model_data, logger, image, prompt)
        if "rlaif" in name:
            return _infer_llava_rlaif(model_data, logger, image, prompt)
        if "halva" in name:
            return _infer_llava_halva(model_data, logger, image, prompt)
        return _infer_llava(model_data, logger, image, prompt)
    if "gpt" in name:
        return _infer_gpt(model_data, logger, image, prompt)
    if "qwen" in name:
        return _infer_qwen(model_data, logger, image, prompt, device)
    if "deepseek" in name:
        return _infer_deepseek(model_data, logger, image, prompt)
    if "intern" in name:
        return _infer_intern(model_data, logger, image, prompt)
    raise ValueError(f"Unsupported model: {args.model}")


def _report_inference_error(logger, image, err):
    """Print and log a per-image failure, then release cached CUDA memory."""
    print(f"Error processing image {image}: {err}")
    torch.cuda.empty_cache()
    logger.info(f"Error processing image {image}: {err}")


def _remove_temp_file(path):
    """Best-effort delete of a temp file from ``convert_pil_to_path``."""
    try:
        os.remove(path)
    except OSError:
        # Cleanup must never mask the real result or error.
        pass


def _infer_llava_lora(model_data, logger, image, prompt):
    """LLaVA-RLHF / LLaVA-OPA: ``llava_v0`` template, bfloat16 pixels."""
    model = model_data["model"]
    tokenizer = model_data["tokenizer"]
    image_token = model_data["DEFAULT_IMAGE_TOKEN"]
    # Same flag the loader used when registering tokens with the tokenizer.
    if model_data["mm_use_im_start_end"]:
        image_token = (model_data["DEFAULT_IM_START_TOKEN"]
                       + image_token
                       + model_data["DEFAULT_IM_END_TOKEN"])
    conv = model_data["conv_templates"]["llava_v0"].copy()
    conv.append_message(conv.roles[0], image_token + "\n" + prompt)
    conv.append_message(conv.roles[1], None)
    input_ids = (
        model_data["tokenizer_image_token"](
            conv.get_prompt(), tokenizer, model_data["IMAGE_TOKEN_INDEX"], return_tensors="pt"
        )
        .unsqueeze(0)
        .cuda()
    )
    image_tensor = model_data["image_processor"].preprocess(image, return_tensors="pt")[
        "pixel_values"
    ][0]
    try:
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).to(dtype=torch.bfloat16).cuda(),
                image_sizes=[image.size],
                do_sample=False,
                temperature=0,
                top_p=None,
                num_beams=1,
                max_new_tokens=512,
                use_cache=True,
            )
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None


def _infer_llava_rlaif(model_data, logger, image, prompt):
    """RLAIF-V: delegate to the ``RLAIFVChat`` wrapper's ``chat`` method."""
    try:
        return model_data["model"].chat({"image": image, "question": prompt})
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None


def _infer_llava_halva(model_data, logger, image, prompt):
    """HALVA: ``v1`` template (``v1_mmtag`` for plain models), float16 pixels."""
    model_name = model_data["model_name"]
    model = model_data["model"]
    tokenizer = model_data["tokenizer"]

    conv_mode = "v1"
    if 'plain' in model_name and 'finetune' not in model_name.lower():
        # Pretrain-only ("plain") checkpoints use the mmtag prompt format.
        # NOTE(review): assumes llava_halva's conv_templates defines
        # "v1_mmtag" as upstream LLaVA does - confirm.
        conv_mode = "v1_mmtag"
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {conv_mode}.')

    image_token = model_data["DEFAULT_IMAGE_TOKEN"]
    if model.config.mm_use_im_start_end:
        image_token = (model_data["DEFAULT_IM_START_TOKEN"]
                       + image_token
                       + model_data["DEFAULT_IM_END_TOKEN"])
    conv = model_data["conv_templates"][conv_mode].copy()
    conv.append_message(conv.roles[0], image_token + '\n' + prompt)
    conv.append_message(conv.roles[1], None)
    input_ids = model_data["tokenizer_image_token"](
        conv.get_prompt(), tokenizer, model_data["IMAGE_TOKEN_INDEX"], return_tensors='pt'
    ).unsqueeze(0).to(device='cuda', non_blocking=True)
    image_tensor = model_data["process_images"](
        [image], model_data["image_processor"], model.config
    )[0].unsqueeze(0)
    try:
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
                do_sample=False,
                temperature=0,
                top_p=None,
                num_beams=1,
                max_new_tokens=512,
                use_cache=True
            )
        # The generated sequence should begin with the prompt tokens; strip them.
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        return outputs.strip()
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None


def _infer_llava(model_data, logger, image, prompt):
    """Vanilla LLaVA: template chosen from the model name, float16 pixels."""
    model_name = model_data["model_name"].lower()
    model = model_data["model"]
    tokenizer = model_data["tokenizer"]
    placeholder = model_data["IMAGE_PLACEHOLDER"]

    image_token = (model_data["image_token_se"] if model.config.mm_use_im_start_end
                   else model_data["DEFAULT_IMAGE_TOKEN"])
    if placeholder in prompt:
        prompt = re.sub(placeholder, image_token, prompt)
    else:
        prompt = image_token + "\n" + prompt

    # First matching name fragment decides the template; order matters
    # ("v1.6-34b" must be tested before the broader "v1").
    conv_by_fragment = (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    )
    conv_mode = next(
        (mode for fragment, mode in conv_by_fragment if fragment in model_name),
        "llava_v0",
    )
    conv = model_data["conv_templates"][conv_mode].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    input_ids = (
        model_data["tokenizer_image_token"](
            conv.get_prompt(), tokenizer, model_data["IMAGE_TOKEN_INDEX"], return_tensors="pt"
        )
        .unsqueeze(0)
        .cuda()
    )
    image_tensor = model_data["process_images"]([image], model_data["image_processor"], model.config)[0]
    try:
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                image_sizes=[image.size],
                do_sample=False,
                temperature=0,
                top_p=None,
                num_beams=1,
                max_new_tokens=512,
                use_cache=True,
            )
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None


def _infer_gpt(model_data, logger, image, prompt):
    """OpenAI Responses API with the image inlined as a base64 JPEG data URL."""
    api_key = model_data["api_key"]
    gpt_model = model_data["gpt_model"]
    from openai import OpenAI
    try:
        client = OpenAI(api_key=api_key)
        image_path = convert_pil_to_path(image)
        try:
            base64_image = encode_image(image_path)
        finally:
            # The temp JPEG is only needed long enough to read its bytes.
            _remove_temp_file(image_path)
        response = client.responses.create(
            model=gpt_model,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:image/jpeg;base64,{base64_image}",
                        },
                    ],
                }
            ],
        )
        return response.output_text
    except Exception as e:
        # NOTE(review): no retry happens here; None is returned and any retry
        # is presumably the caller's job - confirm.
        print(e)
        print('retrying...')
        logger.info(f"Error processing image {image}: {e}")
        return None


def _infer_qwen(model_data, logger, image, prompt, device):
    """Qwen-VL chat template; returns only the newly generated text."""
    model = model_data["model"]
    processor = model_data["processor"]
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = model_data["process_vision_info"](messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    try:
        generated_ids = model.generate(**inputs, max_new_tokens=512)
        # Drop the echoed prompt tokens from each generated sequence.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None


def _infer_deepseek(model_data, logger, image, prompt):
    """DeepSeek-VL2: embed image + prompt, then generate with the language head."""
    vl_gpt = model_data["vl_gpt"]
    tokenizer = model_data["tokenizer"]
    vl_chat_processor = model_data["vl_chat_processor"]
    load_pil_images = model_data["load_pil_images"]
    image_path = convert_pil_to_path(image)
    conversation = [
        {
            "role": "<|User|>",
            "content": "This is image:<image>\n" + prompt,
            "images": [image_path],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    try:
        pil_images = load_pil_images(conversation)
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True,
            system_prompt=""
        ).to(vl_gpt.device)
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
        outputs = vl_gpt.language.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=False,
            use_cache=True
        )
        return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None
    finally:
        # Delete the temp JPEG only after generation, whatever the outcome.
        _remove_temp_file(image_path)


def _infer_intern(model_data, logger, image, prompt):
    """InternVL via the lmdeploy pipeline: one ``(prompt, image)`` request."""
    pipe = model_data["pipeline"]
    try:
        return pipe([(prompt, image)])[0].text
    except Exception as e:
        _report_inference_error(logger, image, e)
        return None