import os
import json
import numpy as np
from PIL import Image
from shapely.geometry import Polygon, Point
from scipy.ndimage import label as nd_label
from collections import defaultdict

def format2(x):
    """Return ``x`` as a fixed-point string with exactly two decimals (e.g. ``"1.50"``).

    ``x`` may be anything ``float()`` accepts (int, float, numpy scalar, numeric string).
    """
    value = float(x)
    return f"{value:.2f}"

# ===== Configuration paths =====
root_dir = "/media/sig/1CC4CB86C4CB611E/sig/qianluo/InteriorGS"
output_root = "/media/sig/1CC4CB86C4CB611E/sig/qianluo/2D_Semantic_Map/results_complete_v3"
os.makedirs(output_root, exist_ok=True)

# Categories with fixed ids 1..N; any other label found in labels.json is
# numbered consecutively after these.
PREDEFINED_CLASSES = (
    "door", "window", "chair", "table", "sofa", "bed", "wardrobe", "plant",
    "floor", "wall", "ceiling",
)
# Wall top height used when a scene has neither "ceiling" nor "wall" boxes.
DEFAULT_WALL_MAX_Z = 3.0
# Connected "Unable Area" components smaller than this many pixels are dropped.
MIN_UNABLE_AREA_PX = 5


def _detect_wall_value(occupancy):
    """Return the occupancy pixel value treated as the wall class.

    Picks the most frequent value strictly between 0 and 250 (ties resolve to
    the smaller value). If no such value exists, falls back to the smallest
    pixel value present in the image.
    """
    pixels, counts = np.unique(occupancy, return_counts=True)
    candidates = (pixels > 0) & (pixels < 250)
    if candidates.any():
        return int(pixels[candidates][np.argmax(counts[candidates])])
    return int(pixels[0])


def _build_label2id(labels):
    """Map every category label to a 1-based integer id.

    PREDEFINED_CLASSES get ids 1..len(PREDEFINED_CLASSES); other labels from
    ``labels`` are numbered after them in first-seen order. Matching is
    case-sensitive.
    """
    label2id = {cls: idx for idx, cls in enumerate(PREDEFINED_CLASSES, start=1)}
    for obj in labels:
        label2id.setdefault(obj['label'], len(label2id) + 1)
    return label2id


def _default_wall_max_z(labels):
    """Return the top height assigned to walls extracted from the occupancy map.

    Uses the median max-z of "ceiling" boxes, otherwise the median max-z of
    "wall" boxes, otherwise DEFAULT_WALL_MAX_Z. Label matching is
    case-insensitive; objects without "bounding_box" are ignored.
    """
    ceiling_maxs, wall_max_zs = [], []
    for obj in labels:
        if "bounding_box" not in obj:
            continue
        top = max(v['z'] for v in obj['bounding_box'])
        kind = obj['label'].lower()
        if kind == "ceiling":
            ceiling_maxs.append(top)
        elif kind == "wall":
            wall_max_zs.append(top)
    for tops in (ceiling_maxs, wall_max_zs):
        if tops:
            return float(np.median(np.asarray(tops, dtype=float)))
    return DEFAULT_WALL_MAX_Z


def _rasterize_footprint(poly3d, x_min, y_min, scale, h, w):
    """Rasterize a box footprint into an (h, w) boolean mask.

    The footprint is the polygon through the x/y of the first four corners of
    ``poly3d``. A pixel is set when its centre is covered by that polygon; only
    pixels inside the footprint's bounding rectangle (clipped to the grid) are
    tested.

    NOTE(review): a covered cell (i, j), with i counted from y_min and j from
    x_min, is written at row h-1-i and column w-1-j. That is a flip on BOTH
    axes, whereas occupancy walls are only flipped vertically (np.flipud). The
    two agree only if occupancy.png stores row 0 at y_min and column 0 at x_max.
    Confirm against the dataset convention before changing either side.
    """
    poly2d = [[v['x'], v['y']] for v in poly3d[:4]]
    poly = Polygon(poly2d)
    xys = np.array(poly2d)
    j_lo = np.clip(int(np.floor((xys[:, 0].min() - x_min) / scale)), 0, w - 1)
    j_hi = np.clip(int(np.floor((xys[:, 0].max() - x_min) / scale)), 0, w - 1)
    i_lo = np.clip(int(np.floor((xys[:, 1].min() - y_min) / scale)), 0, h - 1)
    i_hi = np.clip(int(np.floor((xys[:, 1].max() - y_min) / scale)), 0, h - 1)

    mask = np.zeros((h, w), dtype=bool)
    for j in range(j_lo, j_hi + 1):
        cx = x_min + (j + 0.5) * scale
        for i in range(i_lo, i_hi + 1):
            cy = y_min + (i + 0.5) * scale
            if poly.covers(Point(cx, cy)):
                mask[h - 1 - i, w - 1 - j] = True
    return mask


def _instance_record(category_id, label, instance_id, item_id, mask,
                     x_min, y_min, scale, min_z, max_z):
    """Build one JSON-ready instance dict from a non-empty boolean pixel mask.

    Coordinates come straight from the mask's row (-> y) and column (-> x)
    indices, so they are in the same frame as the mask. Values are formatted
    with two decimals. ``mask_coords_m`` lists pixel centres as [y, x] pairs,
    and ``height_m`` is ``max_z - min_z``.
    """
    ys, xs = np.where(mask)
    x_left = x_min + xs.min() * scale
    x_right = x_min + (xs.max() + 1) * scale
    y_bottom = y_min + ys.min() * scale
    y_top = y_min + (ys.max() + 1) * scale
    return {
        "category_id": int(category_id),
        "category_label": label,
        "instance_id": instance_id,
        "item_id": item_id,
        "bbox_m": [format2(x_left), format2(y_bottom), format2(x_right), format2(y_top)],
        "bbox_xywh_m": [format2(x_left), format2(y_bottom),
                        format2(x_right - x_left), format2(y_top - y_bottom)],
        "area": int(mask.sum()),
        "height_m": format2(max_z - min_z),
        "min_z_m": format2(min_z),
        "max_z_m": format2(max_z),
        "mask_coords_m": [
            [format2(y_min + (y + 0.5) * scale), format2(x_min + (x + 0.5) * scale)]
            for y, x in zip(ys, xs)
        ],
    }


def _build_scene_records(occ_json_path, labels_json_path, occ_png_path):
    """Build the semantic-map instance records for one scene.

    Returns ``(result_list, visual_map)``:

    - ``result_list`` holds the labelled objects, then the occupancy walls, then
      the "Unable Area" regions.
    - ``visual_map`` is an (h, w) int32 category-id image (not saved yet).
    """
    with open(occ_json_path) as f:
        meta = json.load(f)
    scale = meta['scale']  # used as world-units-per-pixel below
    x_min, y_min = meta['min'][:2]

    with Image.open(occ_png_path) as occ_img:
        occupancy = np.array(occ_img.convert("L"))
    h, w = occupancy.shape
    wall_value = _detect_wall_value(occupancy)

    with open(labels_json_path) as f:
        labels = json.load(f)
    label2id = _build_label2id(labels)
    default_wall_max_z = _default_wall_max_z(labels)

    visual_map = np.zeros((h, w), dtype=np.int32)
    # Pixels covered by any non-door instance; used to derive unable areas
    # directly from the masks instead of re-parsing rounded coordinate strings.
    blocked = np.zeros((h, w), dtype=bool)
    result_list = []
    item_counters = defaultdict(int)

    def next_item_id(label):
        item_counters[label] += 1
        return f"{label}_{item_counters[label]}"

    # ---- Objects from labels.json ----
    for obj in labels:
        if "bounding_box" not in obj:
            continue
        label = obj['label']
        cat_id = label2id[label]
        poly3d = obj['bounding_box']
        z_values = [v['z'] for v in poly3d]
        min_z, max_z = min(z_values), max(z_values)
        if label.lower() == "wall":
            # Ground labelled walls at z=0 while keeping their original height.
            min_z, max_z = 0.0, max_z - min_z

        mask = _rasterize_footprint(poly3d, x_min, y_min, scale, h, w)
        if not mask.any():
            continue
        visual_map[mask] = cat_id
        if label.lower() != "door":
            blocked |= mask
        result_list.append(_instance_record(
            cat_id, label, obj.get('ins_id', ''), next_item_id(label),
            mask, x_min, y_min, scale, min_z, max_z))

    # ---- Walls from the occupancy image (one record per 8-connected component) ----
    wall_cat_id = label2id["wall"]
    wall_mask = np.flipud(occupancy == wall_value)
    visual_map[wall_mask] = wall_cat_id
    blocked |= wall_mask
    wall_labels, wall_count = nd_label(wall_mask, structure=np.ones((3, 3), dtype=np.int32))
    wall_max_z = float(default_wall_max_z)
    for idx in range(1, wall_count + 1):
        result_list.append(_instance_record(
            wall_cat_id, "wall", f"wall_{idx}", next_item_id("wall"),
            wall_labels == idx, x_min, y_min, scale, 0.0, wall_max_z))

    # ---- Unable areas: not blocked by an instance, yet not free (255) in occupancy ----
    # The map grid IS the occupancy grid, so alignment is just the same vertical
    # flip used for walls. The previous per-pixel float floor() could land one
    # pixel off.
    occ_walkable = np.flipud(occupancy == 255)
    unable_mask = ~blocked & ~occ_walkable
    labeled, num = nd_label(unable_mask, structure=np.ones((3, 3)))
    for idx in range(1, num + 1):
        component = labeled == idx
        if component.sum() < MIN_UNABLE_AREA_PX:
            continue
        label = "Unable Area"
        result_list.append(_instance_record(
            -1, label, f"unable_area_{idx}", next_item_id(label),
            component, x_min, y_min, scale, 0.0, 0.0))

    return result_list, visual_map


def _write_json_atomic(path, data):
    """Write ``data`` as indented JSON to ``path`` via a temp file plus rename.

    An interrupted write leaves only ``path + ".tmp"`` behind. The "skip if the
    output exists" check therefore never mistakes a truncated file for a
    finished scene.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w') as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, path)


# ===== Process every scene directory =====
for subdir in sorted(os.listdir(root_dir)):
    scene_dir = os.path.join(root_dir, subdir)
    if not os.path.isdir(scene_dir):
        continue

    scene_name = os.path.basename(scene_dir)
    out_json = os.path.join(output_root, f"2D_Semantic_Map_{scene_name}_Complete.json")
    # Reserved for a rendered map image; visual_map is built but not saved yet.
    out_png = os.path.join(output_root, f"2D_Semantic_Map_{scene_name}_Complete.png")

    # Skip scenes that were already processed (lets an interrupted batch resume).
    if os.path.exists(out_json):
        print(f"[跳过] {out_json} 已存在")
        continue

    occ_json_path = os.path.join(scene_dir, "occupancy.json")
    labels_json_path = os.path.join(scene_dir, "labels.json")
    occ_png_path = os.path.join(scene_dir, "occupancy.png")

    if not all(os.path.isfile(p) for p in (occ_json_path, labels_json_path, occ_png_path)):
        print(f"[缺文件] {scene_name} 缺少必要文件")
        continue

    result_list, visual_map = _build_scene_records(occ_json_path, labels_json_path, occ_png_path)
    _write_json_atomic(out_json, result_list)
    print(f"[保存] {out_json}")

print("✅ 批量处理完成!")