import os
import time
import concurrent.futures
import google.generativeai as genai
from IPython.display import Image
import json
from tqdm import tqdm
import argparse

# Safety configuration handed to genai.GenerativeModel: every listed harm
# category gets the "BLOCK_NONE" threshold.
# NOTE(review): both HARM_CATEGORY_DANGEROUS and HARM_CATEGORY_DANGEROUS_CONTENT
# are listed -- confirm that the API accepts both names for this model.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]


# Instruction text sent to the model together with each image.
prompt_cap = (
    "Given the image below, provide a detailed caption that accurately "
    "describes the scene, including key elements and any notable actions "
    "taking place. "
    "Focus on capturing the essence of the image in a single, coherent "
    "sentence. "
    "If there are any specific details such as emotions, colors, or unique "
    "objects, please include them in the description."
)

def generate_caption(item, api_key, max_retries=3, max_rate_limit_retries=10,
                     retry_delay=30, rate_limit_delay=60):
    """Ask the Gemini vision model for a caption of one image.

    Args:
        item: Mapping describing the image; ``item['save_path']`` is the path
            handed to ``Image``.
        api_key: Key passed to ``genai.configure``.
        max_retries: Number of failed attempts (exceptions or empty captions)
            tolerated before giving up.
        max_rate_limit_retries: Upper bound on the extra waits granted to
            errors whose text contains '429' or '400'. These waits do not
            count against ``max_retries``.
        retry_delay: Seconds to sleep between ordinary failed attempts.
        rate_limit_delay: Seconds to sleep after a '429'/'400' error.

    Returns:
        The caption text, or the string ``"error"`` when no non-empty caption
        could be obtained.
    """
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-pro-vision', safety_settings=safety_settings)
    img = Image(item['save_path'])

    failures = 0
    throttled = 0
    while failures < max_retries:
        try:
            response = model.generate_content([img, prompt_cap])
            caption = response.text
        except Exception as e:
            print(e)
            message = str(e)
            # Rate-limit style errors get a longer wait and their own budget.
            # The budget is bounded so that a permanent failure (e.g. a bad
            # request that can never succeed) cannot loop forever.
            if ('429' in message or '400' in message) and throttled < max_rate_limit_retries:
                throttled += 1
                print('429')
                time.sleep(rate_limit_delay)
                continue
            failures += 1
            if failures >= max_retries:
                print('caption error:', e)
                break
            time.sleep(retry_delay)
        else:
            if caption:
                return caption
            # An empty caption counts as a failed attempt; retrying it
            # immediately and without limit would hammer the API.
            failures += 1
            if failures < max_retries:
                time.sleep(retry_delay)
    return "error"

def ensure_entry(data, data_original, index):
    """Make sure ``data[index]`` exists and return it.

    When ``data`` is shorter than ``index + 1`` the missing tail is copied
    from ``data_original`` so that positions in both lists stay aligned, even
    if the gap is larger than one element.
    """
    if index >= len(data):
        data.extend(data_original[len(data):index + 1])
    return data[index]


def write_json_atomic(path, payload):
    """Serialize ``payload`` as JSON to ``path`` without exposing partial files.

    The data is written to a sibling temporary file which then replaces
    ``path``; an interruption during the write leaves the previous
    checkpoint intact.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f)
    os.replace(tmp_path, path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--api_key', type=str, help='gemini key')
    parser.add_argument('--json_path', type=str, required=True, help='meta json path')
    parser.add_argument('--output_path', type=str, required=True, help='output meta json path')
    parser.add_argument('--begin', type=int, help='begin generate index', default=50)
    parser.add_argument('--interval', type=float, default=200,
                        help='seconds to wait after each generated caption')
    args = parser.parse_args()

    with open(args.json_path, 'r', encoding='utf-8') as f:
        data_original = json.load(f)

    # Resume from an existing output file when there is one; otherwise work
    # directly on the loaded input list.
    if os.path.exists(args.output_path):
        with open(args.output_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        data = data_original

    last_index = len(data_original) - 1
    with tqdm(total=len(data_original)) as pbar:
        for i in range(len(data_original)):
            # Indices up to and including --begin are skipped.
            if i <= args.begin:
                pbar.update(1)
                continue

            entry = ensure_entry(data, data_original, i)

            # Only items without a caption, or whose earlier attempt was
            # recorded as 'error', are sent to the model.
            if entry.get('gemini_caption', 'error') != 'error':
                pbar.update(1)
                continue

            entry['gemini_caption'] = generate_caption(entry, args.api_key)
            pbar.update(1)

            # Checkpoint after every generated caption so progress survives
            # an interruption.
            write_json_atomic(args.output_path, data)

            # Pause between API calls; no need to wait after the last item.
            if i < last_index:
                time.sleep(args.interval)

    write_json_atomic(args.output_path, data)

