from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
import torch, os, json, time


# Annotation file whose top-level keys are used as the image base names, and
# the dataset root under which the "masks_color" image directory is read.
json_path = './configs/annotations_v3/data/valid_group_all_nor.json'
root_path = './data/merge_all'


# JSON is UTF-8 by spec; pin the encoding instead of relying on the locale.
with open(json_path, "r", encoding="utf-8") as f:
    data_json = json.load(f)

# Only the keys of the annotation mapping are needed; iterating a dict yields
# its keys in insertion order.
data_keys = list(data_json)

# The colored mask for each key is stored as "<key>_mask_color.png".
images_list = [name + "_mask_color.png" for name in data_keys]

# Output file for the {image base name: caption} mapping built below.
json_file_path = 'image_captions.json'
image_captions = {}

# Load the BLIP-2 (OPT-2.7b) processor and fp16 model once, on the GPU; both
# are reused for every image in the loop.
print("load blip2 for image caption...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", cache_dir='/data/../cache')
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, cache_dir='/data/../cache').to("cuda")

for image in images_list:
    # Strip only the trailing "_mask_color..." suffix so that base names which
    # themselves contain "_mask_color" are not truncated.
    base = image.rsplit('_mask_color', 1)[0]
    # The context manager closes the file handle; convert() returns an
    # independent in-memory copy that outlives the `with` block.
    with Image.open(os.path.join(root_path, "masks_color", image)) as img:
        image_pil = img.convert("RGB")

    st = time.perf_counter()
    inputs = processor(image_pil, return_tensors="pt").to("cuda", torch.float16)
    out = blip_model.generate(**inputs)
    caption = processor.batch_decode(out, skip_special_tokens=True)[0].strip()
    caption = caption.replace("there is ", "")
    caption = caption.replace("close up", "photo")
    et = time.perf_counter()
    # Per-image captioning latency in seconds.
    print(et - st)

    # Normalize background phrases to "ground"; str.replace is a no-op when
    # the phrase is absent, so no membership test is needed.
    for d in ("black background", "white background"):
        caption = caption.replace(d, "ground")

    # Log and record the final caption exactly once per image.
    print("Caption: ", caption)
    image_captions[base] = caption

with open(json_file_path, 'w', encoding='utf-8') as json_file:
    json.dump(image_captions, json_file, indent=4)
