import os
from io import BytesIO

import requests
import torch
from PIL import Image as PIL_Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# Checkpoint to load. Defaults to the original absolute path; set the
# LLAMA_GUARD_MODEL_ID environment variable (a local directory or a Hub repo
# id) to run the script elsewhere without editing it.
model_id = os.environ.get(
    "LLAMA_GUARD_MODEL_ID",
    "/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-Guard-3-11B-Vision",
)

processor = AutoProcessor.from_pretrained(model_id)
# bfloat16 weights use half the memory of float32; device_map="auto" lets the
# library decide how to place the weights on the available devices.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

def _load_image(source, timeout=30):
    """Return the image at *source* (an HTTP(S) URL or a local path) as RGB.

    Args:
        source: An ``http://``/``https://`` URL, or a filesystem path.
        timeout: Seconds to wait for the HTTP server before giving up; only
            used when *source* is a URL.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    if source.startswith(("http://", "https://")):
        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(source, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # ``content`` is decoded from any Content-Encoding (e.g. gzip) and the
        # body is fully read; ``response.raw`` is neither decoded nor closed.
        fileobj = BytesIO(response.content)
    else:
        fileobj = source
    # Convert every input to RGB so palette, RGBA, CMYK or greyscale files
    # all reach the processor with the same channel layout.
    with PIL_Image.open(fileobj) as image:
        return image.convert("RGB")


# To use a local file instead, pass its path: _load_image("<path/to/image>")
url1 = "https://images.nubilefilms.com/videos/whats_yours_is_mine_with_chanel_camryn_freya_parker/samples/cover960.jpg"
image1 = _load_image(url1)

# Second image of the batch.
url2 = "https://aerospaceamerica.aiaa.org/wp-content/uploads/2023/06/0723_Aero_Starship-1200x675.jpg"
image2 = _load_image(url2)

# A single user turn with a text part and an image placeholder; the
# placeholder is paired with the matching entry passed via ``images=`` below.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the image."},
            {"type": "image"},
        ],
    }
]

# Images for batch inference. The batch size is derived from this list, so
# the conversations and the per-sample image lists cannot drift apart.
batch_images = [image1, image2]

# Untokenized chat template: returns one prompt string per conversation.
# Tensors are produced by the processor call below, so no return_tensors here.
input_prompt = processor.apply_chat_template([conversation] * len(batch_images))

# Batched generation needs left padding so every prompt ends right where its
# generated tokens begin; padding=True lets prompts of different lengths share
# one tensor (a no-op while all prompts are identical, as they are here).
processor.tokenizer.padding_side = "left"
inputs = processor(
    text=input_prompt,
    images=[[image] for image in batch_images],  # one image list per sample
    padding=True,
    return_tensors="pt",
).to(model.device)

# All rows of input_ids share one (padded) length, so each row's generated
# tokens start at the same column of the output.
prompt_len = inputs["input_ids"].shape[1]
output = model.generate(
    **inputs,
    max_new_tokens=20,
    # NOTE(review): hard-coded pad id 0 is used to fill rows that finish
    # early; confirm it should not be processor.tokenizer.pad_token_id.
    pad_token_id=0,
)

# Keep only the newly generated tokens of each row.
generated_tokens = output[:, prompt_len:]

print(input_prompt[0])
# Decode every row rather than hard-coded indices, so this follows the batch size.
for row in generated_tokens:
    print(processor.decode(row))
