import json, random
import time
import os

from tqdm import trange
from copy import deepcopy


# Ordinal words used when rendering prompts, keyed by an object's zero-based
# position within its category. Only the first three positions are defined.
od = dict(enumerate(('first', 'second', 'third')))

def generate_a_category(object_list: dict,
                        color_list: list,
                        dic: dict,
                        all_objects: list,
                        max_number: int = 3):
    """Randomly pick one category and add 1..max_number instances of it.

    The chosen category is removed from ``object_list`` so that it cannot be
    picked again; ``dic`` and ``all_objects`` are extended in place.

    Args:
        object_list: Maps a category name to the colors allowed for it.
            An empty color collection means "use ``color_list``".
        color_list: Fallback colors for categories without their own colors.
        dic: Scene description being built; receives a new entry keyed by the
            chosen category.
        all_objects: Flat list of ``(category, index)`` pairs being built.
        max_number: Upper bound (inclusive) on how many instances to create.

    Returns:
        The tuple ``(dic, all_objects, object_list, count)`` where ``count``
        is the number of instances that were created.
    """
    count = random.randint(1, max_number)
    category = random.choice(list(object_list.keys()))
    own_colors = object_list.pop(category)
    palette = own_colors if len(own_colors) > 0 else color_list

    # Register the entry first, then fill its instance list one by one.
    instances = []
    dic[category] = {'number': count, 'objects': instances}
    for index in range(count):
        color = random.choice(palette)
        instances.append({
            'id': index,
            'color': color,
            # Every direction gets its own, initially empty, list.
            'relation': {
                direction: []
                for direction in ('left', 'right', 'above', 'below')
            },
        })
        all_objects.append((category, index))

    return dic, all_objects, object_list, count


def generate_all_category(object_list, color_list, max_number, total_number):
    """Build a scene containing exactly ``total_number`` objects.

    Categories are drawn one at a time with ``generate_a_category`` until the
    object budget is used up. The caller's ``object_list`` is left untouched
    because the draw works on a deep copy of it.

    Args:
        object_list: Maps a category name to the colors allowed for it.
        color_list: Fallback colors for categories without their own colors.
        max_number: Maximum number of instances for a single category.
        total_number: Total number of objects the scene must contain.

    Returns:
        ``(scene, placed)``: the per-category scene dictionary and the flat
        list of ``(category, index)`` pairs.
    """
    pool = deepcopy(object_list)
    scene = {}
    placed = []
    budget = total_number
    while budget > 0:
        # Never ask for more instances than the budget still allows.
        per_category_cap = min(budget, max_number)
        scene, placed, pool, created = generate_a_category(
            pool, color_list, scene, placed, per_category_cap
        )
        budget -= created
    return scene, placed

def topology(dic, ind, total_number):
    """Check that a directed graph over ``range(total_number)`` has no cycle.

    Nodes are processed in topological order: a node becomes ready once all
    of its incoming edges have been consumed. The graph is acyclic exactly
    when every node eventually becomes ready.

    Args:
        dic: Adjacency mapping ``node -> list of successor nodes``. Nodes
            without outgoing edges may be absent.
        ind: Mapping ``node -> number of incoming edges``. Nodes absent from
            this mapping are treated as having no incoming edges.
            NOTE: the counts are decremented in place.
        total_number: Number of nodes in the graph.

    Returns:
        ``True`` if all ``total_number`` nodes could be ordered, else ``False``.
    """
    # Start with every node that has no recorded incoming edge.
    ready = [node for node in range(total_number) if node not in ind]

    cursor = 0
    while cursor < len(ready):
        current = ready[cursor]
        cursor += 1
        for successor in dic.get(current, ()):
            ind[successor] -= 1
            if ind[successor] == 0:
                ready.append(successor)

    return len(ready) == total_number

def generate_direction(object_list, ori_obj_dic):
    """Randomly assign spatial relations between pairs of objects.

    Every pair ``(i, j)`` with ``i < j`` gets one random draw: with 5%
    probability each, object ``i`` is placed left of, right of, above or
    below object ``j``; otherwise the pair stays unrelated. The relation is
    recorded on object ``i`` in a deep copy of ``ori_obj_dic``.

    Horizontal and vertical relations are tracked as two separate directed
    graphs, and the result is accepted only if ``topology`` reports both of
    them as cycle-free.

    Args:
        object_list: Flat list of ``(category, index)`` pairs.
        ori_obj_dic: Scene dictionary; it is never modified.

    Returns:
        ``(True, annotated_copy)`` when the relations are consistent,
        otherwise ``(False, ori_obj_dic)``.
    """
    obj_dic = deepcopy(ori_obj_dic)
    # Edge lists and in-degree counters, one pair per axis.
    horizontal_edges, horizontal_indegree = {}, {}
    vertical_edges, vertical_indegree = {}, {}

    total = len(object_list)
    for i, (obj_cat, obj_id) in enumerate(object_list):
        for j in range(i + 1, total):
            n_obj_cat, n_obj_id = object_list[j]
            roll = random.random()
            if roll < 0.05:
                keyword, src, dst = 'left', i, j
                edges, indegree = horizontal_edges, horizontal_indegree
            elif roll < 0.1:
                keyword, src, dst = 'right', j, i
                edges, indegree = horizontal_edges, horizontal_indegree
            elif roll < 0.15:
                keyword, src, dst = 'above', i, j
                edges, indegree = vertical_edges, vertical_indegree
            elif roll < 0.2:
                keyword, src, dst = 'below', j, i
                edges, indegree = vertical_edges, vertical_indegree
            else:
                # No relation for this pair.
                continue

            edges.setdefault(src, []).append(dst)
            indegree[dst] = indegree.get(dst, 0) + 1
            relation = obj_dic[obj_cat]['objects'][obj_id]['relation']
            relation[keyword].append((n_obj_cat, n_obj_id))

    if (topology(horizontal_edges, horizontal_indegree, total)
            and topology(vertical_edges, vertical_indegree, total)):
        return True, obj_dic
    return False, ori_obj_dic

def generate_with_dic(dic):
    """Render a scene dictionary as a natural-language prompt.

    The prompt starts with a summary sentence listing ``<number> <category>``
    for every category, followed by one sentence per object that states its
    color and all of its recorded spatial relations (left, right, above,
    below, in that order).

    Args:
        dic: Scene dictionary mapping a category name to a dict with the keys
            ``'number'`` and ``'objects'``; each object provides ``'color'``
            and ``'relation'``.

    Returns:
        The complete prompt as a single string.

    Raises:
        KeyError: If an object position has no entry in the module-level
            ordinal table ``od``.
    """
    # (relation key, wording used in the prompt), in output order.
    relation_phrases = (
        ('left', 'on the left of'),
        ('right', 'on the right of'),
        ('above', 'above'),
        ('below', 'below'),
    )

    summary = ', '.join(
        f"{info['number']} {category}" for category, info in dic.items()
    )
    sentences = [f'A photo realistic image of {summary}. ']

    for category, info in dic.items():
        for index, item in enumerate(info['objects']):
            # Double-quoted f-string: reusing the enclosing quote character
            # inside a replacement field is a SyntaxError before Python 3.12.
            clauses = [f"The {od[index]} {category} is {item['color']}"]
            relations = item['relation']
            for key, phrase in relation_phrases:
                clauses.extend(
                    f', {phrase} the {od[other_index]} {other_category}'
                    for other_category, other_index in relations[key]
                )
            sentences.append(''.join(clauses) + '. ')

    return ''.join(sentences)

def parse_args():
    """Parse the command-line options of the generator script.

    Returns:
        An ``argparse.Namespace`` with the attributes ``output_path``
        (str, default ``'generated_data/'``), ``seed`` (int, default 42)
        and ``size`` (int, default 10000).
    """
    import argparse

    # (flag, type, default, help text) for every supported option.
    option_specs = (
        ('--output_path', str, 'generated_data/', 'Output directory for generated data'),
        ('--seed', int, 42, 'Random seed for reproducibility'),
        ('--size', int, 10000, 'Number of samples to generate'),
    )

    parser = argparse.ArgumentParser(description='Generate benchmark data')
    # NOTE: the following input-path option is currently disabled.
    #parser.add_argument('--data_path', type=str, default='object_names.json', help='Path to the object names JSON file')
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()

if __name__ == '__main__':
    # Script entry point: generate ``args.size`` random scenes and write each
    # one, together with its text prompt, to ``<output_path>/<index>.json``.
    args = parse_args()
    os.makedirs(args.output_path, exist_ok=True)
    with open('m3t2ibench_code/object_names.json', 'r', encoding='utf-8') as f:
        objects = json.load(f)
    random.seed(args.seed)
    colors = ['green', 'red', 'yellow', 'brown', 'black', 'white', 'blue']
    # Maximum number of instances of a single category (loop-invariant).
    max_number = 3

    for i in trange(args.size):
        total_object_number = random.randint(2, 5)

        # Re-draw the scene until its spatial relations are accepted.
        ret = False
        while not ret:
            obj_dic, all_objects = generate_all_category(objects, colors, max_number, total_object_number)
            ret, obj_dic = generate_direction(all_objects, obj_dic)

        # Render and write only the accepted scene; rejected attempts are
        # discarded instead of being written and then overwritten.
        dic = {
            'total_number': total_object_number,
            'obj_dict': obj_dic,
            'obj_list': all_objects,
        }
        dic['prompt'] = generate_with_dic(dic['obj_dict'])
        with open(os.path.join(args.output_path, f'{i}.json'), 'w', encoding='utf-8') as f:
            json.dump(dic, f, indent=4)
                