import os
import sys 
sys.path.append("/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl")
import re
import itertools
import json
from tqdm import tqdm
from PIL import Image
import numpy as np

from src.utils.scene import DiffusionScene
from src.utils.prompt import gen_prompt, edit_prompt, identity_prompt, gen_prompt_2d, gen_prompt_new
from src.utils.vlm import vlm_request, extract_and_parse_json

def check_overlap(mask_1, mask_2):
    """Report whether two masks share at least one set element.

    Both arguments must expose ``.bool()`` (presumably torch tensors of the
    same or broadcastable shape -- confirm against callers). The result is
    whatever ``.any()`` returns on the element-wise AND of the two masks.
    """
    shared = mask_1.bool() & mask_2.bool()
    return shared.any()

def check_out_of_bounds(mask):
    """Return True when the mask touches any border of its 2-D grid.

    Args:
        mask: a 2-D mask exposing ``.cpu()`` (presumably a torch tensor --
            confirm against callers). Non-zero entries count as set.

    Returns:
        True if a set entry lies in the first/last row or first/last column,
        False otherwise. An empty mask touches nothing and yields False.
    """
    grid = np.asarray(mask.cpu())
    rows, cols = np.where(grid)

    # Nothing is set: min()/max() would raise on empty arrays.
    if rows.size == 0:
        return False

    # Border indices come from the mask itself instead of assuming 512x512.
    last_row = grid.shape[0] - 1
    last_col = grid.shape[1] - 1
    return bool(
        rows.min() == 0
        or rows.max() == last_row
        or cols.min() == 0
        or cols.max() == last_col
    )

# Spatial relation phrases, grouped by kind. The script splices a phrase
# verbatim into captions of the form "a <obj1> <relation> a <obj2>", and the
# caption is also used as the JSON file name, so editing a phrase changes both.
RELATIONS = {
    # Planar relations; not referenced by the code visible in this file.
    "2d": [
        "on the side of", 
        "next to", 
        "near", 
        "on the left of", 
        "on the right of", 
        "on the bottom of",
        "on the top of"
    ],
    # Depth-aware relations; this is the group read by
    # generate_scene_combinations().
    "3d": [
        'in front of', 
        'at the back left of', 
        'at the front left of', 
        'behind of', 
        'at the back right of', 
        'at the front right of', 
        'hidden by'
    ]
}

# Each entry pairs a scene phrase with the OBJECTS_CATEGORIES keys whose
# objects are allowed to appear in that scene.
SCENES = [
    ("on the desert", ["animals", "outdoor", "person"]),
    ("in the room", ["indoor", "person"]),
    ("on the street", ["outdoor", "person"]),
    ("in the jungle", ["animals", "person"]),
    ("on the road", ["animals", "outdoor", "person"]),
    ("in the studio", ["indoor", "person"]),
    ("on the beach", ["animals", "person"]),
    ("on a snowy landscape", ["outdoor", "person"]),
    ("in the apartment", ["indoor", "person"]),
    ("in the library", ["indoor", "person"]),
]

# The scene phrases on their own, in the same order as SCENES.
SCENES_PROMPT = [scene_phrase for scene_phrase, _ in SCENES]

# Object vocabulary grouped by category. The keys are the category names that
# SCENES entries use to declare which objects may appear in a scene; the object
# names end up verbatim in generated captions.
OBJECTS_CATEGORIES = {
    "animals": ['dog', 'mouse', 'sheep', 'cat', 'cow', 'chicken', 'turtle', 'giraffe', 'pig', 'butterfly', 'horse', 'bird', 'rabbit', 'frog', 'fish'],
    "indoor": ['bed', 'desk', 'key', 'chair', 'vase', 'candle', 'cup', 'phone', 'computer', 'bowl', 'sofa', 'balloon', 'plate', 'refrigerator', 'wallet', 'bag', 'painting', 'suitcase', 'table', 'couch', 'clock', 'book', 'lamp', 'television'],
    "outdoor": ["car", "motorcycle", "backpack", "bench", 'train', 'airplane', 'bicycle'],
    "person": ['woman', 'man', 'boy', 'girl'],
}

def generate_scene_combinations(scenes=None, object_categories=None, relations=None):
    """
    Generates every unique (object1, relation, object2) triple whose two
    objects are allowed together in at least one scene.

    For each scene, the objects of all its allowed categories are pooled and
    every ordered pair of distinct objects is combined with every relation.
    The scene itself is NOT part of the returned tuples; triples that are
    reachable through several scenes are reported once.

    Args:
        scenes: iterable of (scene_name, allowed_categories) pairs.
            Defaults to the module-level SCENES.
        object_categories: mapping of category name -> list of object names.
            Defaults to the module-level OBJECTS_CATEGORIES.
        relations: iterable of relation phrases. Defaults to
            RELATIONS["3d"] (an empty list if that key is missing).

    Returns:
        A list of (obj1, relation, obj2) tuples without duplicates, in
        first-seen order (scene order, then category/object order, then
        relation order), so the result is the same on every run.
    """
    if scenes is None:
        scenes = SCENES
    if object_categories is None:
        object_categories = OBJECTS_CATEGORIES
    if relations is None:
        relations = RELATIONS.get("3d", [])
    relations = list(relations)  # iterated once per object pair

    # A dict serves as an insertion-ordered set: O(1) duplicate detection
    # while keeping a reproducible order.
    unique_combinations = {}

    for _scene_name, allowed_categories in scenes:
        # Pool the scene's objects, dropping repeats but keeping their order.
        allowed_objects = {}
        for category in allowed_categories:
            if category in object_categories:
                allowed_objects.update(dict.fromkeys(object_categories[category]))
            else:
                print(f"  Warning: Category '{category}' not found in OBJECTS_CATEGORIES.")

        # Permutations, not combinations: "A in front of B" differs from
        # "B in front of A".
        for obj1, obj2 in itertools.permutations(allowed_objects, 2):
            for relation in relations:
                unique_combinations[(obj1, relation, obj2)] = None

    return list(unique_combinations)

num = 0

# Reverse lookup from an object name to its category. Categories are visited
# in OBJECTS_CATEGORIES order, so an object listed under several categories
# would end up mapped to the last one.
OBJECT_TO_CATEGORY = {
    member: group
    for group, members in OBJECTS_CATEGORIES.items()
    for member in members
}

# Every object name that can be recognised inside a prompt.
all_known_objects = set(OBJECT_TO_CATEGORY)

def find_applicable_scenes(prompt):
    """
    Lists the scene phrases that are compatible with the objects named in a prompt.

    The prompt is lower-cased and searched for every known object as a whole
    word. The categories of the objects found are collected, and a scene
    qualifies only when it allows all of those categories.

    Args:
        prompt: The input text prompt (e.g., "a dog next to a car").

    Returns:
        A list of scene phrases taken from SCENES, in SCENES order. The list
        is empty (and a warning is printed) when the prompt mentions no known
        object or no category can be resolved; it is also empty when no scene
        allows every detected category.
    """
    lowered = prompt.lower()

    # Word boundaries keep e.g. "car" from matching inside "carpet".
    mentioned = {
        name
        for name in all_known_objects
        if re.search(r'\b' + re.escape(name) + r'\b', lowered)
    }
    if not mentioned:
        print(f"Warning: No known objects found in prompt: '{prompt}'")
        return []

    looked_up = (OBJECT_TO_CATEGORY.get(name) for name in mentioned)
    needed_categories = {cat for cat in looked_up if cat}
    if not needed_categories:
        # Not expected to happen while all_known_objects mirrors OBJECT_TO_CATEGORY.
        print(f"Warning: Could not determine categories for found objects: {mentioned}")
        return []

    # Keep the scenes whose allowed categories cover everything the prompt needs.
    return [
        scene_phrase
        for scene_phrase, allowed in SCENES
        if needed_categories.issubset(set(allowed))
    ]

def json_generation(caption, entities, json_path, data_path):
    """Ask the VLM for a layout of `caption`, store the reply as JSON, then render it.

    Nothing happens when `{json_path}/{caption}.json` already exists.
    Otherwise the `<caption>` and `<entities>` placeholders of `gen_prompt`
    are filled in, the prompt is sent through `vlm_request`, and the text
    after the last `</think>` marker is handed to `extract_and_parse_json`.
    The caption, entities, parsed layout and raw reply are written to the
    JSON file, after which `scene_generate` is called with the parsed layout.

    Args:
        caption: caption text; also used as the JSON file name.
        entities: JSON-serialisable list of entity names from the caption.
        json_path: directory that receives `{caption}.json`.
        data_path: render output root, forwarded to `scene_generate`.
    """
    out_file = f'{json_path}/{caption}.json'
    if os.path.exists(out_file):
        return

    filled_prompt = (
        gen_prompt
        .replace('<caption>', caption)
        .replace('<entities>', json.dumps(entities))
    )
    text_part = {"type": "text", "text": filled_prompt}
    messages = [{"role": "user", "content": [text_part]}]

    content = vlm_request(messages)
    # Only the text after the final '</think>' marker is parsed.
    answer = content.split('</think>')[-1]
    ans_json = extract_and_parse_json(answer)

    record = dict(
        caption=caption,
        entities=entities,
        ans_json=ans_json,
        content=content,
    )
    with open(out_file, 'w') as f:
        json.dump(record, f, indent=4)

    scene_generate(ans_json, data_path, caption)

def scene_generate(ans_json, data_path, caption):
    """Build a box scene from a parsed layout and save its depth renders as PNGs.

    Args:
        ans_json: parsed layout. The keys read here are
            `scene_parameters.scene_size`, `scene_parameters.camera_pitch_angle`
            and `entity_layout`, a list of entries with `size`, `position`
            and `entity_name`.
        data_path: root directory for render output.
        caption: caption text; names the sub-directory that receives the images.

    Writes `render_depth_{j}.png` and `bas_depth_{j}.png` files into
    `{data_path}/{caption}/`, creating that directory when it is missing.
    """
    # Define scene parameters
    scene_params = ans_json['scene_parameters']
    scene_size = scene_params['scene_size'] / 2
    cam_pitch_angle = 90 - scene_params['camera_pitch_angle']
    floor_offset = - scene_size / 5
    floor_scale_x = 1
    floor_scale_y = 1

    # Make sure the output directory exists instead of relying on the caller;
    # Image.save does not create missing directories.
    out_dir = f'{data_path}/{caption}'
    os.makedirs(out_dir, exist_ok=True)

    # Build the scene
    scene = DiffusionScene(scene_size=scene_size)
    # Original author's note: rotation_axis is (x, z, y) and translation is
    # (x, z, y) -- not verifiable from this file.
    scene.move_camera(rotation_angle=cam_pitch_angle, rotation_axis=[1, 0, 0], translation=[0, 0, 0])
    scene.build_floor(scale_x=floor_scale_x, scale_y=floor_scale_y, floor_offset=floor_offset)

    # One box per entity, labelled with the entity name.
    for i, entity in enumerate(ans_json['entity_layout']):
        scene.add_box(id=f"box_{i}", size=entity['size'], origin=entity['position'], prompt=entity['entity_name'])

    depth_all = scene.render(single=True, floor=True, depth_max=4*scene_size)
    for j, depth in enumerate(depth_all):
        Image.fromarray(depth).save(f'{out_dir}/render_depth_{j}.png')

    depth_all = scene.render_bas()
    for j, depth in enumerate(depth_all):
        Image.fromarray(depth).save(f'{out_dir}/bas_depth_{j}.png')

if __name__ == "__main__":
    # Script entry point: generate a layout JSON and depth renders for every
    # enabled caption that does not have a JSON file yet.
    data_path = 'data/render'
    json_path = 'data/json'
    os.makedirs(f"{data_path}", exist_ok=True)
    os.makedirs(f"{json_path}", exist_ok=True)

    # Relations handled by this run; every other relation is skipped.
    enabled_relations = {'at the back left of', 'at the front left of', 'hidden by', 'in front of'}

    new_all_combinations = []
    for combination in generate_scene_combinations():
        obj1, relation, obj2 = combination
        if relation not in enabled_relations:
            continue

        # Skip captions whose JSON file already exists (already processed).
        caption = f"a {obj1} {relation} a {obj2}"
        if not os.path.exists(f'{json_path}/{caption}.json'):
            new_all_combinations.append(combination)

    # Sorted so the processing order is the same on every run.
    all_combinations = sorted(new_all_combinations)
    print(f"Generated {len(all_combinations)} total combinations.")

    for obj1, relation, obj2 in tqdm(all_combinations):
        caption = f"a {obj1} {relation} a {obj2}"
        entities = [obj1, obj2]
        # json_generation requires data_path; the earlier three-argument call
        # raised TypeError before anything was generated and has been removed.
        os.makedirs(f"{data_path}/{caption}", exist_ok=True)
        json_generation(caption, entities, json_path, data_path)

        