import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import requests

# Local checkpoint directory of Llama-3.2-11B-Vision-Instruct.
model_name_or_path = '/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-3.2-11B-Vision-Instruct'

# Default to the CPU and switch to the GPU only when CUDA is usable.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# The processor and the model are both loaded from the same checkpoint;
# the model is then moved onto the selected device.
processor = AutoProcessor.from_pretrained(model_name_or_path)
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path)
model = model.to(device)

url = "https://aerospaceamerica.aiaa.org/wp-content/uploads/2023/06/0723_Aero_Starship-1200x675.jpg"

# Download the demo picture.
#  * timeout: without one, requests may block forever on a stalled server.
#  * raise_for_status: fail loudly on 4xx/5xx instead of handing an error
#    page to PIL.
#  * decode_content: the raw urllib3 stream is otherwise NOT gzip/deflate
#    decoded, so a compressed response would reach PIL as garbage bytes.
#  * image.load(): Image.open is lazy; force the pixel data to be read while
#    the connection is still open, then let the `with` block release it.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    response.raw.decode_content = True
    image = Image.open(response.raw)
    image.load()
input_text = 'Please introduce this image.'

# A batch of three identical single-turn conversations. Each conversation
# holds one user message made of an image placeholder followed by the text
# prompt. The comprehension builds a fresh, independent structure per entry.
messages_list = [
    [
        {
            "content": [
                {"text": None, "type": "image", "index": 0},  # single image
                {"text": input_text, "type": "text", "index": None},
            ],
            "role": "user",
        }
    ]
    for _ in range(3)
]

# One image list per conversation; every conversation refers to the same
# single picture.
images = [[image]] * len(messages_list)

# Render each conversation to a prompt string (not token ids) that ends with
# the generation prompt.
texts = [
    processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    for messages in messages_list
]

# Tokenize the prompts and preprocess the images into one padded batch.
encoded = processor(text=texts, images=images, return_tensors="pt", padding=True)

# Keep only the tensor entries and move them to the model's device.
batch = {}
for key, value in encoded.items():
    if isinstance(value, torch.Tensor):
        batch[key] = value.to(device)

# Sampling is disabled; temperature/top_p are explicitly cleared —
# presumably to avoid warnings about unused sampling settings (TODO confirm).
outputs = model.generate(
    **batch,
    max_new_tokens=200,
    do_sample=False,
    temperature=None,
    top_p=None,
)

# The generated sequences start with the prompt tokens; slice them off so
# only the newly generated continuation is decoded.
prompt_length = batch['input_ids'].shape[-1]
decoded_results = processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

# NOTE: a leftover debugging `breakpoint()` used to sit here. It dropped every
# run into pdb (and aborted or hung non-interactive jobs) before any result
# was printed, so it has been removed.
for decoded_result in decoded_results:
    print(decoded_result)
    print('='*100)
