import os
import torch
from PIL import Image
import numpy as np
import json
import math
from tqdm import tqdm
import random
from typing import List



from lavis.models import load_model_and_preprocess
# from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Compute device used for both the model and the preprocessed image tensors:
# the default CUDA device when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def main(
    ann_filename: str = 'text/source.json',
    output_filename: str = 'text/target.txt',
    img_dir: str = '/data/dataset/Flickr30k',
    max_batch_size: int = 1,
):
    """Rewrite every caption in an annotation file with BLIP-2 and save the result.

    For each annotation, the referenced image is loaded from ``img_dir`` and
    passed to the LAVIS BLIP-2 FlanT5-XXL model together with a prompt that
    quotes the original caption and asks for a more detailed description.
    The generated text replaces the caption, and the rewritten annotations
    are written to ``output_filename`` as a JSON list.

    Args:
        ann_filename: JSON file containing a list of dicts, each with at least
            the keys ``'image'`` (path relative to ``img_dir``),
            ``'caption'`` and ``'image_id'``.
        output_filename: Destination JSON file for the rewritten annotations,
            a list of ``{"image", "image_id", "caption"}`` dicts. Its parent
            directory is created if it does not exist.
        img_dir: Root directory that the ``'image'`` paths are joined onto.
        max_batch_size: Currently unused; annotations are processed one at a
            time. Kept so existing callers remain compatible.
    """
    # Load the pre-trained BLIP-2 model and its image preprocessors.
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device
    )
    model.to(device)

    with open(ann_filename, 'r', encoding='utf-8') as f:
        anns = json.load(f)

    new_anns = []
    # Number of annotations for which the model returned an empty caption.
    empty_prompt_count = 0

    for ann in tqdm(anns):
        original_prompt = ann['caption']
        image_id = ann['image_id']
        img_file = ann['image']
        img_path = os.path.join(img_dir, img_file)

        question = f"The original description for the image is: '{original_prompt}'. \n" + \
            "Please rewrite the description for the image in a more detailed and descriptive manner. >>"

        # Close the file handle promptly; convert() returns a new, fully
        # loaded image that stays valid after the source is closed.
        with Image.open(img_path) as img:
            raw_image = img.convert('RGB')

        # Preprocess to a single-image batch on the model's device.
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
        generated_text = model.generate({"image": image, "prompt": question})[0]

        if not generated_text.strip():
            empty_prompt_count += 1

        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(f"Original caption: {original_prompt}")
        tqdm.write(f"New caption: {generated_text}")

        new_anns.append({
            "image": img_file,
            "image_id": image_id,
            "caption": generated_text,
        })

    # Persist the rewritten annotations (creating the output directory if
    # needed); without this the whole generation run would be discarded.
    out_dir = os.path.dirname(output_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(new_anns, f, ensure_ascii=False)

    print("DONE")
    print(f"Empty prompt count: {empty_prompt_count}")


if __name__ == "__main__":
    # Script entry point: rewrite the Flickr30k training-split captions.
    run_kwargs = dict(
        ann_filename='/data/dataset/dataset_json/data/flickr30k_train.json',
        output_filename='/data/dataset/dataset_json/data_rewrite/flickr30k_train_blip2_rewrite.json',
    )
    main(**run_kwargs)