import requests
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
import os 


# Directory holding the images to caption. os.path.join below makes the
# trailing slash optional.
img_dir = 'trigger_image/'
# Keep only regular, non-hidden files: a sub-directory or a stray entry such
# as '.DS_Store' would otherwise make Image.open() raise partway through the
# run. Sorting by file name gives the same order as sorting the full paths
# (they share the img_dir prefix), so the caption indices stay stable.
image_path_list = [
    os.path.join(img_dir, name)
    for name in sorted(os.listdir(img_dir))
    if not name.startswith('.') and os.path.isfile(os.path.join(img_dir, name))
]


# BLIP-2 checkpoint used for both the processor and the model.
model_name = "Salesforce/blip2-opt-2.7b"
# Device to run on. Defaults to the original 'cuda:2'; set the BLIP2_DEVICE
# environment variable (e.g. 'cuda:0') to run on a machine with a different
# GPU layout instead of editing the script.
device = torch.device(os.environ.get('BLIP2_DEVICE', 'cuda:2'))
processor = AutoProcessor.from_pretrained(model_name)
# Weights are loaded in float16. NOTE(review): float16 inference is intended
# for GPU devices; running this on 'cpu' may be slow or unsupported.
model = Blip2ForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=torch.float16
).to(device)

# Caption every image in image_path_list and write one "<index>.<caption>"
# line per image to blip2.txt. The encoding is explicit so captions with
# non-ASCII characters cannot fail on platforms whose default encoding is
# not UTF-8. The `with` statement closes the file; no explicit close needed.
with open('blip2.txt', 'w', encoding='utf-8') as f:
    for i, image_path in enumerate(image_path_list):
        # Close the underlying file handle promptly; convert('RGB') returns a
        # new image that stays usable after the source image is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # Inputs are moved to the model's device and cast to float16 to match
        # the weights' dtype.
        inputs = processor(image, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**inputs, max_new_tokens=15)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        print(generated_text)
        f.write(f'{i}.{generated_text}\n')
        # Flush after each image so results written so far are on disk even
        # if the run is interrupted.
        f.flush()