import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"

prompt_1 = "USER: <image>\nWhat does this image show?\nASSISTANT:"
prompt_2 = "USER: <image> <image> \nWhat is the difference between these two images?\nASSISTANT:"
image_file_1 = "image1.png"
image_file_2 = "image2.png"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_flash_attention_2=True).to(0)
processor = AutoProcessor.from_pretrained(model_id)
raw_image_1 = Image.open(image_file_1)
raw_image_2 = Image.open(image_file_2)
inputs = processor([prompt_1, prompt_2], [raw_image_1, raw_image_1, raw_image_2], padding=True, return_tensors="pt").to(0, torch.float16)
import pdb

pdb.set_trace()
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.batch_decode(output, skip_special_tokens=True))
