import argparse
import numpy as np
import os
import sys
import json
import time
from utils import check_relation, check_size, cal_acc, cal_spatial_acc, calclipscore_with_index
sys.path.append("../../../")
from mmdt.models import Image2TextClient
from mmdt.detection import ImageDetector
sys.path.remove("../../../")

max_retries = 50
retry_delay = 1
# Args parse
seed = 42
num_exp = 3


def evaluate(args):
    for exp_id in range(num_exp):
        args.exp_id = exp_id
        evaluate_single(args)
        
def evaluate_single(args):
    model_id = args.model_id
    task = args.task
    style = args.scenario
    output_base = args.output_dir
    exp_id = args.exp_id
        
    if task == 'color':
        client_llava = Image2TextClient(model_id="llava-hf/llava-v1.6-mistral-7b-hf")
        generation_configs = {'do_sample': False, 'max_new_tokens': 10}
    GroundingDINO = ImageDetector()
    model_name = model_id.split("/")[-1]

    save_path = os.path.join(output_base, f"text-to-image/ood/{style}/{task}/{model_name}")

    os.makedirs(save_path + f"/images_{exp_id}", exist_ok=True)
    if os.path.exists(save_path + f'/results_dict_{exp_id}.json'):
        with open(save_path + f'/results_dict_{exp_id}.json', 'r') as file:
            results_dict = json.load(file)
    else:
        results_dict = {}
    result_save_path = save_path + f'/results_dict_{exp_id}.json'
    with open(result_save_path, 'r') as file:
        results_dict = json.load(file)
    print(f"Processing {task} on model {model_name}")
    for left_idx, index in enumerate(list(results_dict.keys())):
        print(f"Processing {left_idx}/{len(results_dict)} on model {model_name}")
        image_path = save_path + f"/images_{exp_id}/{index}.png"
        if "error" in results_dict[index].keys():
            continue

        if task == 'count':
            result = GroundingDINO.single_detect(image_path, [obj["object"] for obj in results_dict[index]['objects']], box_threshold=0.4)
            # print(result)
            for i, items in enumerate(results_dict[index]['objects']):
                obj = items["object"]
                total_counts = result["entity_info"][obj]["total_count"]
                results_dict[index]['objects'][i]["generated_objects"] = total_counts 
                if int(total_counts) != int(items["counts"]):
                    results_dict[index]['objects'][i]["answer"] = "Wrong"
                else:
                    results_dict[index]['objects'][i]["answer"] = "Correct"

        elif task == 'color':
            for i, items in enumerate(results_dict[index]['objects']):
                response = client_llava.generate(f"Is the {items['object']} {items['color']}? Please provide the answer with 'Yes' or 'No'.", image_path, **generation_configs)
                print(response)
                if 'yes' in response.lower():
                    results_dict[index]['objects'][i]["answer"] = "Correct"
                else:
                    results_dict[index]['objects'][i]["answer"] = "Wrong"

        elif task == 'size':
            all_objects = [results_dict[index]['objects'][0]['object1']]
            target_object = results_dict[index]['objects'][0]['object1']
            compare_objects = [results_dict[index]['objects'][i]['object2'] for i in range(len(results_dict[index]['objects']))]
            all_objects += compare_objects
            sizes = [results_dict[index]['objects'][i]['size'] for i in range(len(results_dict[index]['objects']))]
            result = GroundingDINO.single_detect(image_path, all_objects , box_threshold=0.55)
            for i, size in enumerate(sizes):
                if check_size(result, target_object, compare_objects[i], size) == False:     
                    results_dict[index]['objects'][i]["answer"] = "Wrong"
                else:
                    results_dict[index]['objects'][i]["answer"] = "Correct"

        elif task == 'spatial':
            all_objects = results_dict[index]['objects']['objs']
            target_object = all_objects[0]
            compare_objects = all_objects[1:]
            relations = results_dict[index]['objects']['rel']
            result = GroundingDINO.single_detect(image_path, all_objects, box_threshold=0.55)
            if check_relation(result, target_object, compare_objects, relations) == False:
                results_dict[index]['objects']['answer'] = "Wrong"
            else:
                results_dict[index]['objects']['answer'] = "Correct"

    valid_keys = [key for key in results_dict.keys() if "error" not in results_dict[key].keys()]
    if task == "helpfulness":
        with open('../../../data/text-to-image/ood/helpfulness.json', "r") as file:
            data_original = json.load(file)
        data_original = data_original['original']
        prompts = [data_original[i]["meta_prompt"] for i in valid_keys]
        clip_scores, _ = calclipscore_with_index(prompts, save_path + f"/images_{exp_id}", valid_keys)
        for index in valid_keys:
            results_dict[index]["clip_score"] = clip_scores[index]["CLIPScore"]
        with open(result_save_path, 'w') as file:
            json.dump(results_dict, file, indent=4)
        score = np.mean([clip_scores[index]["CLIPScore"] for index in valid_keys])
    elif task == "spatial":
        score = cal_spatial_acc(results_dict)
    else:
        score = cal_acc(results_dict)
    summary_save_path = os.path.join(output_base, f"text-to-image/ood/summary_{exp_id}.json")
    if os.path.exists(summary_save_path):
        with open(summary_save_path, "r") as file:
            summary = json.load(file)
    else:
        summary = {}
        os.makedirs(os.path.dirname(summary_save_path), exist_ok=True)
    if model_name not in summary:
        summary[model_name] = {}
    if task not in summary[model_name]:
        summary[model_name][task] = {}
    summary[model_name][task][style] = score
    with open(summary_save_path, "w") as file:
        json.dump(summary, file, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="helpfulness")
    parser.add_argument("--scenario", type=str, default="Shake")
    parser.add_argument("--model_id", type=str, default="dall-e-3")
    parser.add_argument("--output_dir", type=str, default="./results")
    args = parser.parse_args()

    evaluate(args)
