import os
import sys 
sys.path.append("/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl")
import numpy as np
import itertools
import json
from tqdm import tqdm
from PIL import Image

from src.utils.scene import DiffusionScene
from src.utils.prompt import gen_prompt, edit_prompt, identity_prompt
from src.utils.vlm import vlm_request, extract_and_parse_json

def check_overlap(mask_1, mask_2):
    """Return whether two masks are both nonzero at any common position.

    Both arguments must support ``.bool()`` (e.g. torch tensors). The result is
    whatever ``.any()`` returns on their element-wise AND.
    """
    shared = mask_1.bool() & mask_2.bool()
    return shared.any()

def check_out_of_bounds(mask):
    """Return True if the mask's nonzero region touches any border of the mask.

    Args:
        mask: 2-D mask; either an object exposing ``.cpu()`` (e.g. a torch
            tensor, as before) or a NumPy array.

    Returns:
        True when any nonzero element lies in the first/last row or the
        first/last column; False otherwise. An all-zero mask returns False
        (the previous implementation raised ValueError on it).
    """
    arr = mask.cpu() if hasattr(mask, "cpu") else mask
    rows, cols = np.where(arr)
    if rows.size == 0:
        return False
    # Derive the last valid indices from the mask itself instead of the
    # previously hard-coded 511, which was only correct for 512x512 masks.
    last_row = arr.shape[0] - 1
    last_col = arr.shape[1] - 1
    return bool(
        rows.min() == 0
        or rows.max() == last_row
        or cols.min() == 0
        or cols.max() == last_col
    )

# Spatial relation phrases, grouped under the keys "2d" and "3d".
RELATIONS = {
    "2d": [
        "on the side of",
        "next to",
        "near",
        "on the left of",
        "on the right of",
        "on the bottom of",
        "on the top of",
    ],
    "3d": [
        'in front of',
        'at the back left of',
        'at the front left of',
        # Currently disabled: 'behind of', 'at the back right of',
        # 'at the front right of'.
        'hidden by',
    ],
}

# (scene phrase, object categories allowed in that scene). Every category
# name is a key of OBJECTS_CATEGORIES.
SCENES = [
    ("on the desert", ["animals", "outdoor", "person"]),
    ("in the room", ["indoor", "person"]),
    ("on the street", ["outdoor", "person"]),
    ("in the jungle", ["animals", "person"]),
    ("on the road", ["animals", "outdoor", "person"]),
    ("in the studio", ["indoor", "person"]),
    ("on the beach", ["animals", "person"]),
    ("on a snowy landscape", ["outdoor", "person"]),
    ("in the apartment", ["indoor", "person"]),
    ("in the library", ["indoor", "person"]),
]

# Scene phrases alone, in the same order as SCENES. Derived from SCENES so
# the two lists cannot drift apart.
SCENES_PROMPT = [scene for scene, _ in SCENES]

# Candidate object nouns per category. Entries within a list are unique so
# that sampling from a list does not weight any object twice ('sofa' used to
# appear twice in "indoor").
OBJECTS_CATEGORIES = {
    "animals": ['dog', 'mouse', 'sheep', 'cat', 'cow', 'chicken', 'turtle', 'giraffe', 'pig', 'butterfly', 'horse', 'bird', 'rabbit', 'frog', 'fish'],
    "indoor": ['sofa', 'desk', 'key', 'chair', 'vase', 'candle', 'cup', 'phone', 'computer', 'bowl', 'balloon', 'plate', 'refrigerator', 'wallet', 'bag', 'painting', 'suitcase', 'table', 'couch', 'clock', 'book', 'lamp', 'television'],
    "outdoor": ["car", "motorcycle", "backpack", "bench", 'train', 'airplane', 'bicycle'],
    "person": ['woman', 'man', 'boy', 'girl'],
}

def find_nonzero_bounding_box(vector):
    """Locate the bounding box of the nonzero entries of a 2-D NumPy array.

    Args:
        vector: a 2-D ``np.ndarray``.

    Returns:
        ``(x_min, y_min, x_max, y_max)``, where x is the column index and y is
        the row index, or ``None`` when the array has no nonzero element.

    Raises:
        TypeError: if ``vector`` is not a NumPy array.
        ValueError: if ``vector`` is not two-dimensional.
    """
    if not isinstance(vector, np.ndarray):
        raise TypeError("输入必须是 NumPy 数组")
    if vector.ndim != 2:
        raise ValueError("输入数组必须是二维的")

    # For a 2-D array, np.nonzero yields (row indices, column indices).
    rows, cols = np.nonzero(vector)
    if rows.size == 0:
        return None

    return (np.min(cols), np.min(rows), np.max(cols), np.max(rows))

def entity_center(x_min, y_min, x_max, y_max, shape, step=0.1, c_max=0.15, c_min=0.1):
    """Choose the next camera step from where a bounding box sits in the frame.

    Args:
        x_min, y_min, x_max, y_max: bounding box in pixels, with x a column
            index and y a row index (the convention of
            ``find_nonzero_bounding_box``).
        shape: ``(height, width)`` of the frame, i.e. a 2-D array's ``.shape``.
        step: magnitude of the positive step.
        c_max: inner margin as a fraction of the frame size.
        c_min: outer margin as a fraction of the frame size.

    Returns:
        ``step`` when the box lies strictly inside the ``c_max`` margin on all
        sides; ``-step * 5`` when any edge reaches into the ``c_min`` margin;
        ``0`` otherwise (the box sits between the two margins).
    """
    height, width = shape
    # x values are column coordinates, so they are bounded by the width, and
    # y values are row coordinates, bounded by the height. The previous code
    # compared them the other way round, which was wrong for non-square frames.
    if (
        x_min > width * c_max
        and y_min > height * c_max
        and x_max < width * (1 - c_max)
        and y_max < height * (1 - c_max)
    ):
        return step
    if (
        x_min < width * c_min
        or y_min < height * c_min
        or x_max > width * (1 - c_min)
        or y_max > height * (1 - c_min)
    ):
        return -step * 5
    return 0

def _axis_extent(entities, pos_axis, size_axis):
    """Return (low, high) bounds of position[pos_axis] +/- size[size_axis] / 2 over all entities."""
    lows = [e['position'][pos_axis] - e['size'][size_axis] / 2 for e in entities]
    highs = [e['position'][pos_axis] + e['size'][size_axis] / 2 for e in entities]
    return min(lows), max(highs)


def _render_entity_bbox(scene, depth_max):
    """Render without the floor and return (bbox, shape) of the last depth map."""
    depth_all = scene.render(single=True, floor=False, render_floor=False, depth_max=depth_max)
    depth = depth_all[-1]
    bbox = find_nonzero_bounding_box(depth)
    if bbox is None:
        raise ValueError("rendered depth map is empty; no entity is visible")
    return bbox, depth.shape


def generate_scene(ans_json, max_iters=40):
    """Build a DiffusionScene from a layout dict and iteratively adjust the camera.

    Args:
        ans_json: layout dict. Reads ``ans_json['scene_parameters']['scene_size']``
            and, for each item of ``ans_json['entity_layout']``, the keys
            ``'position'``, ``'size'`` and ``'entity_name'``. It is not mutated.
        max_iters: maximum number of camera-adjustment steps (default 40).

    Returns:
        ``(depth_all, depth_all_bas, total_move)``: the final ``scene.render``
        output, the ``scene.render_bas()`` output, and the sum of the camera
        translations applied during the adjustment loop.

    Raises:
        ValueError: if the layout has no entities, or a render shows nothing.
    """
    entities = ans_json['entity_layout']
    if not entities:
        raise ValueError("ans_json['entity_layout'] is empty")

    scene_size = ans_json['scene_parameters']['scene_size'] / 2
    depth_max = 4 * scene_size
    cam_pitch_angle = 90
    floor_scale_x = 1
    floor_scale_y = 1

    # Floor offset centres the span of position[1] +/- size[2]/2. Real min/max
    # are used instead of the old 100/0 sentinels, which were wrong for
    # negative or large coordinates.
    y_low, y_high = _axis_extent(entities, pos_axis=1, size_axis=2)
    floor_offset = -(y_high + y_low) / 2

    # Horizontal shift that centres the span of position[0] +/- size[0]/2.
    x_low, x_high = _axis_extent(entities, pos_axis=0, size_axis=0)
    x_mean = (x_high + x_low) / 2

    scene = DiffusionScene(scene_size=scene_size, fov=(60, 60))
    # Axis convention noted by the original author: rotation_axis (x, z, y), translation (x, z, y).
    scene.move_camera(rotation_angle=cam_pitch_angle, rotation_axis=[1, 0, 0], translation=[0, 0, 0])
    scene.build_floor(scale_x=floor_scale_x, scale_y=floor_scale_y, floor_offset=floor_offset)

    for i, entity in enumerate(entities):
        # Shifted copy of the position, so the caller's ans_json is left untouched.
        origin = list(entity['position'])
        origin[0] -= x_mean
        scene.add_box(id=f"box_{i}", size=entity['size'], origin=origin, prompt=entity['entity_name'])

    num = 0
    total_move = 0
    bbox, shape = _render_entity_bbox(scene, depth_max)
    move = entity_center(*bbox, shape)
    while move != 0 and num < max_iters:
        scene.move_camera(rotation_angle=0, rotation_axis=[1, 0, 0], translation=[0, move, 0])
        # Accumulate the move actually applied. The previous code added the
        # *next* move instead, which dropped the first step from the total.
        total_move += move
        bbox, shape = _render_entity_bbox(scene, depth_max)
        move = entity_center(*bbox, shape)
        num += 1
    print(num)

    depth_all = scene.render(single=True, floor=False, depth_max=depth_max)
    depth_all_bas = scene.render_bas()
    return depth_all, depth_all_bas, total_move

if __name__ == "__main__":
    data_path = 'data/render_1'
    json_path = 'data/json'
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(json_path, exist_ok=True)

    # Sorted so repeated runs process the files in a reproducible order.
    json_list = sorted(os.listdir(json_path))
    for file in tqdm(json_list, desc="🚀 Loading conditions", total=len(json_list)):
        try:
            with open(f'{json_path}/{file}', 'r', encoding='utf-8') as f:
                condition = json.load(f)
            caption = condition['caption']
            ans_json = condition['ans_json']

            # NOTE(review): caption is used verbatim as a directory name; a
            # caption containing '/' will create nested dirs or fail — confirm.
            out_dir = f'{data_path}/{caption}'
            # move.txt is written last, so its presence marks a finished
            # sample. Checking only the directory would permanently skip
            # captions whose earlier attempt crashed after mkdir.
            if os.path.exists(f'{out_dir}/move.txt'):
                continue
            os.makedirs(out_dir, exist_ok=True)

            depth_all, depth_all_bas, total_move = generate_scene(ans_json)

            for j, depth in enumerate(depth_all):
                Image.fromarray(depth).save(f'{out_dir}/render_depth_{j}.png')

            # Save the render_bas() output. The previous code saved depth_all twice.
            for j, depth in enumerate(depth_all_bas):
                Image.fromarray(depth).save(f'{out_dir}/bas_depth_{j}.png')

            with open(f'{out_dir}/move.txt', 'w', encoding='utf-8') as f:
                f.write(str(total_move) + '\n')

        except Exception as e:
            # Best-effort batch job: report which file failed and continue.
            print(f"{file}: {e}")
        