from pathlib import Path
import numpy as np

# "generation" part of evaluation set 1: root folder plus one sub-folder for
# rendered images and one for JSON layout descriptions.
gen_layouts_path = Path("build_a_scene/eval/evaluation_set_1/generation")
gen_renders_path = gen_layouts_path.joinpath("renders")
gen_jsons_path = gen_layouts_path.joinpath("json")

# "consistency" part of evaluation set 1, laid out the same way.
cons_layouts_path = Path("build_a_scene/eval/evaluation_set_1/consistency")
cons_renders_path = cons_layouts_path.joinpath("renders")
cons_jsons_path = cons_layouts_path.joinpath("json")


def get_relations(relation):
    """Translate a relation code into the positions of the two boxes.

    Args:
        relation: One of ``"r"``, ``"l"`` or ``"a"``.

    Returns:
        tuple[str, str]: ``(b1_pos, b2_pos)``, the position of the first and
        of the second box:
        ``"r"`` -> ``("left", "right")``,
        ``"l"`` -> ``("right", "left")``,
        ``"a"`` -> ``("bottom", "top")``.

    Raises:
        ValueError: If ``relation`` is not one of the supported codes.
    """
    positions = {
        "r": ("left", "right"),
        "l": ("right", "left"),
        "a": ("bottom", "top"),
    }
    if relation not in positions:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            f"Unknown relation {relation!r}; expected one of {sorted(positions)}"
        )
    return positions[relation]


def check_overlap(mask_1, mask_2):
    """Tell whether two masks share at least one non-zero position.

    Both masks are converted with ``.bool()`` before being intersected, so any
    non-zero entry counts as covered.

    Args:
        mask_1: First mask; must provide ``.bool()`` (e.g. a torch tensor).
        mask_2: Second mask, compatible in shape with ``mask_1``.

    Returns:
        The result of ``.any()`` on the element-wise intersection, which is
        truthy exactly when the masks overlap somewhere.
    """
    covered_by_first = mask_1.bool()
    covered_by_second = mask_2.bool()
    intersection = covered_by_first & covered_by_second
    return intersection.any()


def check_out_of_bounds(mask):
    """Report whether a 2-D mask is empty or touches the border of its image.

    The border is taken from the mask's own shape, so the check is valid for
    any image size (for a 512x512 mask the last valid index is 511).

    Args:
        mask: 2-D mask providing ``.cpu()`` (e.g. a torch tensor); non-zero
            entries mark the pixels covered by the object.

    Returns:
        bool: True when the mask has no non-zero entry, or when a non-zero
        entry lies in the first/last row or the first/last column; False
        otherwise.
    """
    mask_array = np.asarray(mask.cpu())
    rows, cols = np.where(mask_array)
    if rows.size == 0:
        # Nothing of the object is visible; treat it as out of bounds rather
        # than letting min()/max() fail on an empty index array.
        return True
    last_row = mask_array.shape[0] - 1
    last_col = mask_array.shape[1] - 1
    return bool(
        rows.min() == 0
        or rows.max() == last_row
        or cols.min() == 0
        or cols.max() == last_col
    )


# Object prompts grouped by category name. A scene lists the categories it
# accepts, and the candidate objects are the concatenation of those lists.
OBJECTS_CATEGORIES = {
    "animals": [
        "a cat",
        "a dog",
        "a horse",
        "an elephant",
        "a grizzly bear",
    ],
    "indoor": [
        "a teddy bear",
        "a microwave",
        "a backpack",
        "an lcd tv",
        "a sofa",
        "a chair",
        "a table",
        "a bed",
    ],
    "outdoor": [
        "a car",
        "a motorcycle",
        "a backpack",
        "a bench",
        "a sofa",
    ],
}

# Relative box dimensions for each object prompt, ordered (width, depth, height).
# The script under `__main__` multiplies a triple by `scene_size * 2` to obtain
# the size of the box placed in the scene.
ASPECT_RATIOS = {
    "a cat": (0.2, 0.2, 0.4),
    "a dog": (0.3, 0.2, 0.5),
    "a horse": (0.7, 0.3, 0.5),
    "an elephant": (0.9, 0.3, 0.6),
    "a grizzly bear": (0.8, 0.3, 0.5),
    "a teddy bear": (0.3, 0.2, 0.4),
    "a microwave": (0.4, 0.2, 0.25),
    "a backpack": (0.3, 0.1, 0.4),
    "a car": (0.6, 1.2, 0.4),
    "a motorcycle": (0.8, 0.2, 0.35),
    "an lcd tv": (0.5, 0.05, 0.3),
    "a sofa": (0.9, 0.3, 0.4),
    "a chair": (0.25, 0.25, 0.5),
    "a table": (0.5, 0.5, 0.4),
    "a bench": (0.7, 0.3, 0.3),
    "a bed": (0.6, 0.7, 0.25),
}

# Relation codes that may be sampled when a prompt is used as the second box.
# Every prompt allows "l" and "r"; only the prompts flagged True also allow "a"
# (second box on top of the first one, see `get_relations`).
RELATIONS = {
    prompt: ("l", "r", "a") if allows_above else ("l", "r")
    for prompt, allows_above in (
        ("a cat", True),
        ("a dog", True),
        ("a horse", False),
        ("an elephant", False),
        ("a grizzly bear", False),
        ("a teddy bear", True),
        ("a microwave", True),
        ("a backpack", True),
        ("a car", False),
        ("a motorcycle", False),
        ("an lcd tv", True),
        ("a sofa", False),
        ("a chair", False),
        ("a table", False),
        ("a bench", False),
        ("a bed", False),
    )
}

# Background scenes as (scene prompt, category names) pairs. The category names
# are keys of OBJECTS_CATEGORIES and select which object prompts may be sampled
# for that scene.
SCENES = [
    ("An empty desert with cloudy sky", ["animals", "outdoor"]),
    ("An empty room with windows and curtains", ["indoor"]),
    ("An empty street", ["outdoor"]),
    ("An empty jungle", ["animals"]),
    ("An empty road", ["animals", "outdoor"]),
    ("An empty studio", ["indoor"]),
    ("An empty beach", ["animals"]),
    ("A snowy landscape", ["outdoor"]),
    ("An empty apartment", ["indoor"]),
]

# Integer label associated with each object prompt.
# NOTE(review): presumably the class indices reported by a YOLOv8 detector
# (COCO ordering) for these objects; confirm against the evaluation code that
# consumes this table.
YOLO8_LABELS = {
    "a cat": 15,
    "a dog": 16,
    "a horse": 17,
    "an elephant": 20,
    "a grizzly bear": 21,
    "a teddy bear": 77,
    "a microwave": 68,
    "a backpack": 24,
    "a car": 2,
    "a motorcycle": 3,
    "an lcd tv": 62,
    "a sofa": 57,
    "a chair": 56,
    "a table": 60,
    "a bench": 13,
    "a bed": 59,
}

if __name__ == "__main__":
    # Script entry point: sample random two-object layouts, keep the ones whose
    # boxes neither overlap nor touch the image border, and store a rendered
    # image plus a JSON description for every accepted layout.
    import sys
    # NOTE(review): machine-specific absolute paths; they must exist for the
    # `src.utils.scene` and `layouts_utils` imports below to resolve.
    sys.path.append("/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/build_a_scene")
    sys.path.append("/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/build_a_scene/eval")
    from src.utils.scene import seed_everything
    from itertools import chain
    # Presumably seeds the random generators used below; confirm in src.utils.scene.
    seed_everything(45)
    import json
    import torch
    import random
    import numpy as np
    import matplotlib.pyplot as plt
    from src.utils.scene import DiffusionScene
    from layouts_utils import (
        gen_layouts_path,
        gen_renders_path,
        gen_jsons_path,
        OBJECTS_CATEGORIES,
        SCENES,
        ASPECT_RATIOS,
        RELATIONS,
        check_out_of_bounds,
        check_overlap,
    )
    from PIL import Image

    num_infer_steps = 20
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.float16

    # Make directories to save the layouts, renders for debug and jsons
    for path in [gen_layouts_path, gen_renders_path, gen_jsons_path]:
        path.mkdir(parents=True, exist_ok=True)

    num_samples = 100  # Number of layouts to generate
    num_seeds = 5  # Number of seeds to generate for each layout

    # `counter` only advances when a layout is accepted, so the loop keeps
    # sampling until `num_samples` valid layouts have been written.
    counter = 0
    while counter < num_samples:
        layout_dict = {}

        # Build the scene
        scene_size = 5
        print("==> Scene size: ", scene_size)
        scene = DiffusionScene(scene_size=scene_size)
        camera_angle = np.random.randint(70, 80)
        print("==> Camera Angle: ", camera_angle)
        scene.move_camera(rotation_angle=camera_angle, rotation_axis=[1, 0, 0], translation=[0, 0, 0])
        scene.build_floor(scale_x=2, scale_y=4, floor_offset=-scene_size)
        scene.set_pipe(None, "", num_infer_steps, device, dtype)
        scene_choice = random.choices(SCENES)[0]
        scene_prompt = scene_choice[0]
        scene_categories = scene_choice[1]
        # Candidate objects: every prompt of every category the scene accepts.
        objects = list(chain(*[OBJECTS_CATEGORIES[c] for c in scene_categories]))

        layout_dict["scene_size"] = scene_size
        layout_dict["camera_angle"] = camera_angle
        layout_dict["scene"] = scene_prompt
        layout_dict["seeds"] = np.random.randint(0, 10000, num_seeds).tolist()
        print("")

        # Define the first box
        prompt_b1 = random.choice(objects)
        print(f"==> First object is `{prompt_b1}`" )
        aspect_b1 = np.array(ASPECT_RATIOS[prompt_b1])
        print("==> First object aspect ratio is", aspect_b1)
        size_b1 = scene_size * aspect_b1 * 2
        print("==> First object size is", size_b1)
        z_b1 = scene_size * np.random.uniform(1.2, 2)
        origin_b1 = [0, z_b1, 0]
        print("==> First object origin is", origin_b1)

        print("")

        # Define the second box; it starts at the same origin as the first one
        # and is moved apart inside the retry loop below.
        prompt_b2 = random.choice(objects)
        print(f"==> Second object is `{prompt_b2}`" )
        aspect_b2 = np.array(ASPECT_RATIOS[prompt_b2])
        print("==> Second object aspect ratio is", aspect_b2)
        size_b2 = scene_size * aspect_b2 * 2
        print("==> Second object size is", size_b2)
        origin_b2 = [0, z_b1, 0]
        print("==> Second object origin is", origin_b2)

        # The relation is drawn from the ones allowed for the second object.
        relation = random.choice(RELATIONS[prompt_b2])

        print("relation is :", relation)
        success = False
        # NOTE(review): the "l"/"r" offsets accumulate over attempts, so the
        # boxes drift further apart on every retry. For "a" the coordinates are
        # identical on every attempt, so retries repeat the same placement.
        for itr in range(20):  # Try for 20 times to find a valid layout
            print(itr)
            step = 1 + random.random()
            if relation == "l":  # Left
                origin_b1[0] += step
                origin_b2[0] -= step
            elif relation == "r":  # Right
                origin_b1[0] -= step
                origin_b2[0] += step
            elif relation == "a":
                origin_b2[2] = -scene_size + size_b1[2] + size_b2[2] / 2
                print( origin_b2)

            # NOTE(review): both boxes are re-added under the same ids on every
            # attempt; presumably add_box replaces an existing box, so confirm.
            try:
                scene.add_box(id="box_1", size=size_b1, origin=origin_b1, prompt=prompt_b1)
                mask_b1, latent_mask_b1, p_image_b1 = scene.get_box_masks(box_id="box_1")

                scene.add_box(id="box_2", size=size_b2, origin=origin_b2, prompt=prompt_b2)
                mask_b2, latent_mask_b2, p_image_b2 = scene.get_box_masks(box_id="box_2")
            except Exception as err:
                # Catch Exception rather than everything so that Ctrl-C and
                # SystemExit still stop the script; report why the attempt failed.
                print("==> Could not place the boxes:", err)
                continue

            # Check if the boxes overlap or out of bounds
            if check_overlap(mask_b1, mask_b2) or check_out_of_bounds(mask_b1) or check_out_of_bounds(mask_b2):
                continue
            else:
                success = True
                break

        if not success:
            # No valid placement found; sample a brand-new layout instead.
            continue

        rendered_scene = scene.render()
        img_fname = gen_renders_path / f"{counter:04d}.png"
        # plt.imsave(img_fname, rendered_scene)
        Image.fromarray(rendered_scene).save(img_fname)

        layout_dict["box_1"] = {"prompt": prompt_b1, "aspect_ratio": aspect_b1.tolist(), "size": size_b1.tolist(), "origin": origin_b1, "relation": None}
        layout_dict["box_2"] = {"prompt": prompt_b2, "aspect_ratio": aspect_b2.tolist(), "size": size_b2.tolist(), "origin": origin_b2, "relation": relation}

        # The context manager closes the file even if serialization fails.
        with open(gen_jsons_path / f"{counter:04d}.json", "w") as file:
            json.dump(layout_dict, file, indent=2)

        counter += 1
