import os, json
from copy import deepcopy

# Ordinal words keyed by 0-based instance index, used to build phrases such
# as "the first cat" / "the second dog" in the generated prompts.
od = {
    0: 'first',
    1: 'second',
    2: 'third',
    3: 'fourth',
    4: 'fifth',
    5: 'sixth',
    6: 'seventh',
    7: 'eighth',
    8: 'ninth',
    9: 'tenth',
}

# (relation key, positive phrase, negative phrase), in the order in which the
# relations are rendered inside one instance sentence.
_RELATION_PHRASES = (
    ('left', 'on the left of', 'not on the left of'),
    ('right', 'on the right of', 'not on the right of'),
    ('above', 'above', 'not above'),
    ('below', 'below', 'not below'),
)


def _ordinal(index, ordinals):
    """Return the ordinal label for the 0-based ``index``.

    Uses ``ordinals`` when it has an entry; otherwise falls back to a numeric
    ordinal ('4th', '11th', '22nd', ...) instead of raising KeyError.
    """
    if index in ordinals:
        return ordinals[index]
    n = index + 1
    if 11 <= n % 100 <= 13:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f'{n}{suffix}'


def _instance_sentences(category, index, item, ordinals):
    """Return the (positive, negative) sentences for one instance, or None.

    None means the instance carries neither colour nor relation information,
    so it contributes nothing to either prompt.
    """
    positive, negative = [], []
    if 'color' in item:
        positive.append(item['color'])
        negative.append(f'not {item["color"]}')
    elif 'correct_color' in item:
        positive.append(item['correct_color'])
        negative.append(item['false_color'])

    relations = item.get('relation') or {}
    for key, positive_phrase, negative_phrase in _RELATION_PHRASES:
        for target in relations.get(key, ()):
            # Only (category, index) are used; entries may carry extra
            # trailing fields (e.g. a status string), which are ignored.
            target_name = f'the {_ordinal(target[1], ordinals)} {target[0]}'
            positive.append(f'{positive_phrase} {target_name}')
            negative.append(f'{negative_phrase} {target_name}')

    if not positive:
        return None
    subject = f'The {_ordinal(index, ordinals)} {category} is'
    return (
        f'{subject} {", ".join(positive)}. ',
        f'{subject} {", ".join(negative)}. ',
    )


def construct_prompts(dic, ordinals=None):
    """Build a (positive, negative) prompt pair from a condensed evaluation record.

    The positive prompt restates the requested counts, colours and spatial
    relations. The negative prompt carries the wrong counts, the false colours,
    and negated colours/relations.

    Args:
        dic: Mapping with an ``'objects'`` key that maps each object category
            (in insertion order) to a dict holding:
              * optional ``'correct_number'`` and ``'wrong_number'`` (both are
                read when ``'correct_number'`` is present);
              * ``'objects'``: list of per-instance dicts. An instance may have
                ``'color'`` (negated with "not" in the negative prompt) or
                ``'correct_color'`` + ``'false_color'``, and a ``'relation'``
                dict whose ``'left'``/``'right'``/``'above'``/``'below'``
                lists hold ``[category, index, ...]`` references. Missing
                relation data is treated as empty.
            Instances are named by their position in that list, not by any
            ``'id'`` field.
        ordinals: Optional mapping from 0-based index to an ordinal word;
            defaults to the module-level ``od`` table. Indices missing from
            the table fall back to numeric ordinals ('4th', '12th', ...).

    Returns:
        ``(positive_prompt, negative_prompt)``, or ``('.', '.')`` when no
        count, colour or relation information is present.
    """
    if ordinals is None:
        ordinals = od
    categories = dic['objects']

    positive_parts = ['This is a photo-realistic image. ']
    negative_parts = ['This is not a photo-realistic image. ']
    has_content = False

    # Count sentence, e.g. "2 cat, 1 dog. " versus "3 cat, 0 dog. ".
    counted = [
        (category, info['correct_number'], info['wrong_number'])
        for category, info in categories.items()
        if 'correct_number' in info
    ]
    if counted:
        has_content = True
        positive_parts.append(
            ', '.join(f'{right} {name}' for name, right, _ in counted) + '. ')
        negative_parts.append(
            ', '.join(f'{wrong} {name}' for name, _, wrong in counted) + '. ')

    # One sentence per instance that has colour and/or relation information.
    for category, info in categories.items():
        for index, item in enumerate(info['objects']):
            sentences = _instance_sentences(category, index, item, ordinals)
            if sentences is not None:
                has_content = True
                positive_parts.append(sentences[0])
                negative_parts.append(sentences[1])

    if not has_content:
        return '.', '.'
    return ''.join(positive_parts), ''.join(negative_parts)


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic calls).

    Returns:
        argparse.Namespace with a ``root_path`` attribute.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Construct positive and negative prompts from evaluation results')
    parser.add_argument(
        '--root_path', type=str, default='sd35-medium',
        help='Root directory of the generated results (named after the model, e.g. sd35-medium)')
    return parser.parse_args(argv)

def _summarize_category(category_result):
    """Condense one category's evaluation record into construct_prompts' input form.

    Reads ``'number'``, ``'number_bias'`` and the per-instance ``'objects'``
    list of ``category_result`` and returns a dict with:

    * ``'correct_number'`` / ``'wrong_number'`` when ``number_bias != 0``;
      ``wrong_number`` is ``number + number_bias`` after the adjustment below.
    * ``'objects'``: one entry per instance. A found instance keeps its
      ``'id'``, gains ``'correct_color'``/``'false_color'`` when its colour
      was judged wrong, and keeps only the relations not marked
      ``'relation correct!'`` (trimmed to their first two fields). An
      instance that was not found is copied verbatim.
    """
    summary = {}
    number = category_result['number']
    number_bias = category_result['number_bias']
    if number_bias != 0:
        summary['correct_number'] = number

    summary['objects'] = []
    for item in category_result['objects']:
        if not item['object_found']:
            # A missing instance forces the bias non-positive, so the reported
            # wrong count never exceeds the requested one.
            number_bias = -abs(number_bias)
            summary['objects'].append(deepcopy(item))
            continue

        item_dic = {'id': item['id']}
        if not item['color_is_correct']:
            item_dic['correct_color'] = item['color']
            # NOTE(review): raises IndexError if detected_colors is empty --
            # confirm the evaluator always reports at least one colour here.
            item_dic['false_color'] = item['detected_colors'][0]
        item_dic['relation'] = {
            rel: [obj[:2] for obj in targets if obj[-1] != 'relation correct!']
            for rel, targets in item['relation'].items()
        }
        summary['objects'].append(item_dic)

    if number_bias != 0:
        summary['wrong_number'] = number + number_bias
    return summary


def main():
    """Build prompt pairs for every evaluated sample under ``--root_path``.

    Reads ``<root>/multi_evaluation/best_eval_results.json`` and iterates its
    records in order. The i-th record is paired with ``<root>/<i>/metadata.json``.
    The results are written to ``<root>/multi_evaluation/constructed_prompts.json``.
    """
    args = parse_args()
    with open(os.path.join(args.root_path, 'multi_evaluation/best_eval_results.json'),
              'r', encoding='utf-8') as f:
        data = json.load(f)

    all_dic = []
    for sample_idx, sample_result in enumerate(data):
        with open(os.path.join(args.root_path, f'{sample_idx}/metadata.json'),
                  'r', encoding='utf-8') as f:
            meta = json.load(f)
        full_rep_dic = {
            'meta_dict': meta,
            'prompt': meta['prompt'],
            # Only dict-valued entries are treated as object categories;
            # other fields of the evaluation record are skipped.
            'objects': {
                category: _summarize_category(result)
                for category, result in sample_result.items()
                if isinstance(result, dict)
            },
        }
        # Pass a copy so the record saved below cannot be altered by the callee.
        full_rep_dic['positive_prompt'], full_rep_dic['negative_prompt'] = \
            construct_prompts(deepcopy(full_rep_dic))
        all_dic.append(full_rep_dic)

    with open(os.path.join(args.root_path, 'multi_evaluation/constructed_prompts.json'),
              'w', encoding='utf-8') as f:
        json.dump(all_dic, f, indent=4)


if __name__ == '__main__':
    main()
