import json
import argparse
from tqdm import tqdm
from PIL import Image
import os
import time
from src.api import *
from src.utils import *
from src.vqa_score import *


def gemini_pred(input):
    """Extract the text of a Gemini response.

    Args:
        input: raw response object returned by ``gemini`` (from src.api).

    Returns:
        ``input.text``, or the placeholder string ``'safety problem'`` when
        reading ``.text`` raises — presumably because the response was
        blocked and carries no text part (TODO confirm against the SDK in use).
    """
    try:
        return input.text
    except Exception:
        # Catch only ordinary errors: a bare ``except`` would also swallow
        # KeyboardInterrupt/SystemExit and make the run impossible to stop.
        return 'safety problem'

def api_func(args):
    """Select the API call and response parser for ``args.model``.

    Args:
        args: namespace whose ``model`` attribute is one of
            'gpt4v', 'gpt4o', 'claude3', 'gemini', 'opus'.

    Returns:
        ``(model_run, get_pred)``: ``model_run`` sends one request dict to the
        model and returns its raw response; ``get_pred`` extracts the
        predicted text from such a response.

    Raises:
        ValueError: if ``args.model`` is not a supported model.
    """
    def chat_completion_pred(x):
        # Dict-shaped response: x['choices'][0]['message']['content'].
        return x['choices'][0]['message']['content']

    def content_block_pred(x):
        # Object-shaped response: x.content[0].text.
        return x.content[0].text

    # Each backend is wrapped in a lambda so its name is only resolved when
    # the returned callable is actually invoked.
    dispatch = {
        'gpt4v': (lambda x: gpt4v(x), chat_completion_pred),
        'gpt4o': (lambda x: gpt4o(x), chat_completion_pred),
        'claude3': (lambda x: claude3(x), content_block_pred),
        'gemini': (lambda x: gemini(x), lambda x: gemini_pred(x)),
        'opus': (lambda x: opus(x), content_block_pred),
    }
    try:
        return dispatch[args.model]
    except KeyError:
        # Previously this surfaced as an opaque UnboundLocalError.
        raise ValueError(
            f"unsupported model {args.model!r}; expected one of {sorted(dispatch)}"
        ) from None
 
def _save_predictions(json_path, data):
    """Overwrite ``json_path`` with ``data`` as indented JSON, atomically.

    The data is first written to ``<json_path>.tmp`` and then moved over the
    target with ``os.replace``. An interruption mid-write therefore cannot
    leave a truncated checkpoint that would lose earlier predictions.
    """
    tmp_path = f"{json_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, json_path)


class _TooManyErrors(Exception):
    """Raised when a run exceeds its budget of failed model calls."""


def main(args, max_errors=100):
    """Run ``args.model`` over every item in ``args.json_path``, resuming where it left off.

    For each item, predictions are stored back on the item dict:

    * ``prediction`` for ``question`` (image ``<image_dir>/full/<id>.png``);
    * ``h_perception_prediction`` for ``h_perception_question`` (same image);
    * ``l_perception_prediction_tuple`` for every entry of
      ``l_perception_question_tuple`` (image ``<image_dir>/low_level/<id>.png``).
      This is stored only once every low-level question has been answered.

    Keys already present are skipped, so re-running resumes an interrupted
    run. The JSON file is rewritten after each newly stored prediction.

    Args:
        args: namespace with ``model``, ``json_path`` and ``image_dir``.
        max_errors: the run stops once more than this many model calls have
            failed (raised an exception or returned an error response).
    """
    model_run, get_pred = api_func(args)

    with open(args.json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    failed = object()  # sentinel: distinguishes "call failed" from any real prediction
    error = 0

    def ask(image_path, question):
        """Query the model once; return the parsed prediction or ``failed``."""
        nonlocal error
        request = {'image_path': [image_path], 'question': f'<image>{question}'}
        try:
            answer = model_run(request)
        except Exception as exc:  # API/network failure: count it, skip this question
            problem = exc
        else:
            # NOTE(review): assumes the src.api wrappers signal failure with an
            # 'error' entry in the returned response — confirm against src/api.
            if 'error' not in answer:
                return get_pred(answer)
            problem = answer
        error += 1
        print(problem)
        if error > max_errors:
            raise _TooManyErrors(f"stopping after {error} failed model calls")
        return failed

    total = sum(2 + len(d.get('l_perception_question_tuple', ())) for d in data)
    with tqdm(total=total, ncols=100, desc=f'inferencing {args.model}') as pbar:
        try:
            for d in data:
                full_image = f"{args.image_dir}/full/{d['id']}.png"
                for pred_key, question_key in (
                    ('prediction', 'question'),
                    ('h_perception_prediction', 'h_perception_question'),
                ):
                    if pred_key in d:
                        pbar.update(1)
                        continue
                    pred = ask(full_image, d[question_key])
                    if pred is not failed:
                        d[pred_key] = pred
                        _save_predictions(args.json_path, data)
                        pbar.update(1)

                if 'l_perception_prediction_tuple' in d:
                    pbar.update(len(d['l_perception_prediction_tuple']))
                    continue
                low_image = f"{args.image_dir}/low_level/{d['id']}.png"
                answers = []
                for question in d['l_perception_question_tuple']:
                    answer = ask(low_image, question)
                    if answer is failed:
                        # Partial tuples are never stored, so don't spend calls
                        # on the remaining questions; the item is retried next run.
                        break
                    answers.append(answer)
                else:
                    d['l_perception_prediction_tuple'] = answers
                    _save_predictions(args.json_path, data)
                    pbar.update(len(answers))
        except _TooManyErrors as exc:
            print(exc)

if __name__ == '__main__':
    # Command-line entry point: choose the API-backed model and the data root.
    supported_models = ["gpt4v", "gpt4o", "claude3", "gemini", "opus"]

    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, required=True, choices=supported_models)
    cli.add_argument("--data_dir", type=str, default="./data")
    args = cli.parse_args()

    # Predictions are read from and written back to one JSON file per model;
    # images are looked up under <data_dir>/images.
    data_root = args.data_dir
    args.json_path = f"{data_root}/prediction/{args.model}.json"
    args.image_dir = f"{data_root}/images"

    main(args)
