import os
import time
import google.generativeai as genai
from IPython.display import Image
import json
from tqdm import tqdm
import argparse

# Safety settings handed to the Gemini model: one entry per harm category,
# each with its blocking threshold set to "BLOCK_NONE".
# NOTE(review): both "HARM_CATEGORY_DANGEROUS" and
# "HARM_CATEGORY_DANGEROUS_CONTENT" are listed; confirm that the installed
# SDK version accepts both names.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]



def generate_answer(item, api_key, choose_query, max_retries=3, max_throttle_waits=30):
    """Send one image plus a text prompt to 'gemini-pro-vision' and return the answer.

    Args:
        item: Mapping describing one sample. It must contain ``'save_path'``
            (path of the image to send) and the key named by ``choose_query``.
        api_key: API key passed to ``genai.configure``.
        choose_query: Key of ``item`` whose value is used as the prompt.
        max_retries: Number of failed attempts (exceptions that are not
            handled as throttling, or empty answers) tolerated before
            giving up.
        max_throttle_waits: Maximum number of 60-second waits spent on errors
            whose message contains '429' or '400'. Once exhausted, such
            errors count as ordinary failed attempts, so a permanent error
            can no longer block the caller forever.

    Returns:
        The non-empty response text, or the string ``"error"`` when no
        answer could be obtained within the retry budget.
    """
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-pro-vision', safety_settings=safety_settings)
    question = item[choose_query]
    img = Image(item['save_path'])

    retries = 0
    throttle_waits = 0
    while retries < max_retries:
        try:
            response = model.generate_content([img, question])
            answer = response.text
        except Exception as e:
            print(e)
            # The error is classified by its message text, as before.
            message = str(e)
            throttled = '429' in message or '400' in message
            if throttled and throttle_waits < max_throttle_waits:
                throttle_waits += 1
                print('429/400 error, waiting before retrying '
                      f'({throttle_waits}/{max_throttle_waits})')
                # Re-create the client before the long wait, as the original did.
                genai.configure(api_key=api_key)
                model = genai.GenerativeModel('gemini-pro-vision', safety_settings=safety_settings)
                time.sleep(60)
                continue
            retries += 1
            time.sleep(10)
            if retries >= max_retries:
                print('answer error:', e)
        else:
            if answer != "":
                return answer
            # An empty answer counts as a failed attempt; previously it made
            # the loop re-issue requests forever without any pause.
            retries += 1
            if retries < max_retries:
                time.sleep(10)
    return "error"

def _write_json_atomic(path, payload):
    """Dump ``payload`` as JSON to ``path`` via a temporary file and a rename.

    Writing to ``path + '.tmp'`` first and then calling ``os.replace`` means an
    interruption during the dump leaves the previous output file intact.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f)
    os.replace(tmp_path, path)


def main():
    """Generate a 'gemini_answer' for every record of the input JSON list.

    The input file (``--json_path``) is loaded as a list of records. If the
    output file (``--output_path``) already exists it is loaded and used as
    the working copy, so answers from an earlier run are kept. Records
    without a 'gemini_answer', or whose answer is 'error', are sent to
    ``generate_answer``; the working copy is saved after every new answer
    and once more at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--api_key', type=str, help='gemini key')
    parser.add_argument('--json_path', type=str, required=True, help='meta json path')
    parser.add_argument('--output_path', type=str, required=True, help='output meta json path')
    parser.add_argument('--choose_query', type=str, required=True,
                        help='key of each record whose value is sent as the query')
    parser.add_argument('--begin', type=int, help='begin generate index', default=50)

    args = parser.parse_args()

    with open(args.json_path, 'r', encoding='utf-8') as f:
        data_original = json.load(f)
    if os.path.exists(args.output_path):
        with open(args.output_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        data = data_original

    total = len(data_original)
    # NOTE(review): indices up to and including --begin are skipped (the
    # original test was ``i <= args.begin``), so generation starts at
    # begin + 1. Confirm this is the intended meaning of the option.
    first = max(args.begin + 1, 0)

    with tqdm(total=total) as pbar:
        pbar.update(min(first, total))
        for i in range(first, total):
            if i >= len(data):
                # Pad the resumed output up to index i. Appending a single
                # record would leave data[i] out of range whenever the saved
                # output is shorter than the current index.
                data.extend(data_original[len(data):i + 1])
            entry = data[i]
            if 'gemini_answer' not in entry or entry['gemini_answer'] == 'error':
                entry['gemini_answer'] = generate_answer(entry, args.api_key, args.choose_query)
                # Persist before the pause so an interrupt during the sleep
                # does not lose the answer that was just obtained.
                _write_json_atomic(args.output_path, data)
                time.sleep(60)
            pbar.update(1)
    _write_json_atomic(args.output_path, data)


if __name__ == "__main__":
    main()

