import os
import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, Point
from scipy.ndimage import label as nd_label
from collections import defaultdict

# ====== Scene selection ======
# Default scene directory; set the SCENE_DIR environment variable to process a
# different scene without editing the script.
scene_dir = os.environ.get(
    'SCENE_DIR',
    '/home/sig/sig/qianluo/qianluo/3DGS_VLN_Benchmark/test/839873',
)
# normpath strips a trailing separator: basename('/a/839873/') would be '' and
# every output filename would silently lose the scene name.
scene_name = os.path.basename(os.path.normpath(scene_dir))

def format2(x):
    """Render *x* as a string with exactly two digits after the decimal point.

    The argument is coerced with ``float`` first, so numeric strings and
    NumPy scalars are accepted as well as plain ints/floats.
    """
    value = float(x)
    return f"{value:.2f}"

# ====== Load grid metadata & occupancy image ======
with open(os.path.join(scene_dir, 'occupancy.json')) as meta_file:
    meta = json.load(meta_file)
scale = meta['scale']              # world units per grid cell
x_min, y_min = meta['min'][:2]     # world (x, y) of the grid's lower-left corner

# Greyscale occupancy raster. The rest of the script treats image row 0 as the
# max-y edge, hence the vertical flips further down.
occ_img = Image.open(os.path.join(scene_dir, 'occupancy.png')).convert("L")
occupancy = np.array(occ_img)
# The semantic map reuses the occupancy grid's resolution one-to-one.
occ_h, occ_w = h, w = occupancy.shape

# ====== Detect the wall pixel value ======
# Walls are taken to be the most frequent grey level strictly between 0 and 250
# (ties resolve to the smaller level). If the image has no such level, fall back
# to the smallest grey level present.
pixels, counts = np.unique(occupancy.reshape(-1), return_counts=True)
in_wall_range = (pixels > 0) & (pixels < 250)
candidate_walls = [int(p) for p in pixels[in_wall_range]]
if not candidate_walls:
    wall_value = int(pixels[0])
else:
    # pixels is sorted ascending, so argmax's first-hit rule picks the smaller
    # level on a tie.
    wall_value = int(pixels[in_wall_range][np.argmax(counts[in_wall_range])])
print(f"[墙类别确认] wall_value = {wall_value}")

# ====== Load the 3DGS object annotations ======
# JSON is UTF-8 by specification; state the encoding explicitly so non-ASCII
# labels decode identically regardless of the machine's locale.
with open(os.path.join(scene_dir, 'labels.json'), encoding='utf-8') as f:
    labels = json.load(f)

# ====== Pre-allocate category IDs ======
# The fixed classes get stable IDs 1..len(predefined_classes); any other label
# seen in labels.json is appended with the next free ID, in first-seen order.
predefined_classes = [
    "door", "window", "chair", "table", "sofa", "bed", "wardrobe", "plant",
    "floor", "wall", "ceiling"
]
label2id = {name: position for position, name in enumerate(predefined_classes, start=1)}
id2label = {position: name for name, position in label2id.items()}
cur_max_id = len(label2id) + 1

for lbl in (entry['label'] for entry in labels):
    if lbl in label2id:
        continue
    label2id[lbl] = cur_max_id
    id2label[cur_max_id] = lbl
    cur_max_id += 1

# ====== Build visual_map and the per-instance record list ======
# All rasters here share one "world grid" layout: row i covers
# y in [y_min + i*scale, y_min + (i+1)*scale), column j likewise for x.
# (The occupancy PNG stores max-y in its top row, so it is flipped vertically
# wherever it is used.)
visual_map = np.zeros((h, w), dtype=np.int32)
result_list = []
item_counters = defaultdict(int)


def _to_cell(coord, origin, n_cells):
    """Return the grid index containing world coordinate *coord*, clamped to [0, n_cells - 1]."""
    return int(np.clip(np.floor((coord - origin) / scale), 0, n_cells - 1))


# ==== Regular objects ====
for obj in labels:
    if "bounding_box" not in obj:
        continue
    label = obj['label']
    cat_id = label2id[label]
    poly3d = obj['bounding_box']

    # Vertical extent over all box vertices.
    z_values = [v['z'] for v in poly3d]
    min_z = min(z_values)
    max_z = max(z_values)
    height = max_z - min_z

    # 2D footprint from the first four vertices -- assumes they form one
    # horizontal face of the box in ring order (TODO confirm with the exporter).
    poly2d = [[v['x'], v['y']] for v in poly3d[:4]]
    poly = Polygon(poly2d)
    xys = np.array(poly2d)

    # Candidate cells: the footprint's axis-aligned bounds, clamped to the grid.
    min_x_pixel = _to_cell(xys[:, 0].min(), x_min, w)
    max_x_pixel = _to_cell(xys[:, 0].max(), x_min, w)
    min_y_pixel = _to_cell(xys[:, 1].min(), y_min, h)
    max_y_pixel = _to_cell(xys[:, 1].max(), y_min, h)

    # Rasterise: a cell belongs to the object if its centre lies in (or on) the
    # footprint. The cell must be stored at its own (row=i, col=j) -- the same
    # layout as the wall masks -- so that the world coordinates reported below
    # are exactly the centres that were tested. Storing at (h-1-i, w-1-j)
    # would mirror the object through the map centre.
    mask = np.zeros((h, w), dtype=bool)
    for i in range(min_y_pixel, max_y_pixel + 1):
        cy = y_min + (i + 0.5) * scale
        for j in range(min_x_pixel, max_x_pixel + 1):
            cx = x_min + (j + 0.5) * scale
            if poly.covers(Point(cx, cy)):
                mask[i, j] = True
    visual_map[mask] = cat_id

    ys, xs = np.where(mask)
    if xs.size == 0:
        continue  # no cell centre inside the footprint (tiny or off-map object)

    # Tight world-space bounds of the covered cells (cell edges, not centres).
    xmin_pix, xmax_pix = xs.min(), xs.max()
    ymin_pix, ymax_pix = ys.min(), ys.max()
    x_left = x_min + xmin_pix * scale
    x_right = x_min + (xmax_pix + 1) * scale
    y_bottom = y_min + ymin_pix * scale
    y_top = y_min + (ymax_pix + 1) * scale
    w_box = x_right - x_left
    h_box = y_top - y_bottom

    bbox_m = [format2(x_left), format2(y_bottom), format2(x_right), format2(y_top)]
    bbox_xywh_m = [format2(x_left), format2(y_bottom), format2(w_box), format2(h_box)]

    # Cell centres as [y, x] pairs in world units.
    mask_coords_m = [
        [format2(y_min + (y + 0.5) * scale), format2(x_min + (x + 0.5) * scale)]
        for y, x in zip(ys, xs)
    ]

    item_counters[label] += 1
    item_id = f"{label}_{item_counters[label]}"
    result_list.append({
        "category_id": int(cat_id),
        "category_label": label,
        "instance_id": obj.get('ins_id', ''),
        "item_id": item_id,
        "bbox_m": bbox_m,
        "bbox_xywh_m": bbox_xywh_m,
        "area": int(mask.sum()),  # number of grid cells, not square metres
        "height_m": format2(height),
        "min_z_m": format2(min_z),
        "max_z_m": format2(max_z),
        "mask_coords_m": mask_coords_m,
    })

# ==== Walls ====
# Wall cells come straight from the occupancy image; flipping it vertically puts
# them in the world-grid layout used by visual_map (row i <-> y_min + i*scale).
wall_cat_id = label2id["wall"]
wall_mask = occupancy == wall_value
wall_mask_flip = np.flipud(wall_mask)
visual_map[wall_mask_flip] = wall_cat_id

# Each 8-connected component of wall cells becomes its own wall instance.
structure = np.ones((3, 3), dtype=np.int32)
wall_label_mask, wall_count = nd_label(wall_mask_flip, structure=structure)

for idx in range(1, wall_count + 1):
    block_mask = wall_label_mask == idx
    ys, xs = np.where(block_mask)
    if not (xs.size and ys.size):
        continue
    left = x_min + xs.min() * scale
    right = x_min + (xs.max() + 1) * scale
    bottom = y_min + ys.min() * scale
    top = y_min + (ys.max() + 1) * scale

    label = "wall"
    item_counters[label] += 1
    result_list.append({
        "category_id": int(wall_cat_id),
        "category_label": label,
        "instance_id": f"wall_{idx}",
        "item_id": f"{label}_{item_counters[label]}",
        "bbox_m": [format2(left), format2(bottom), format2(right), format2(top)],
        "bbox_xywh_m": [format2(left), format2(bottom),
                        format2(right - left), format2(top - bottom)],
        "area": int(block_mask.sum()),
        # Walls have no height information here; a fixed 0..3 extent is used.
        "height_m": format2(3.0),
        "min_z_m": format2(0.0),
        "max_z_m": format2(3.0),
        "mask_coords_m": [
            [format2(y_min + (row + 0.5) * scale), format2(x_min + (col + 0.5) * scale)]
            for row, col in zip(ys, xs)
        ],
    })

# ====== Obstacle mask from the instances collected so far ======
# Every non-door instance (objects and walls) blocks the cells it covers; doors
# stay passable. Cell indices are recovered from the 2-decimal [y, x] centre
# coordinates stored in mask_coords_m.
map_mask = np.zeros((h, w), dtype=np.uint8)
for inst in result_list:
    if inst['category_label'].lower() == "door":
        continue
    coords = np.array(
        [[float(yy), float(xx)] for yy, xx in inst['mask_coords_m']],
        dtype=float,
    ).reshape(-1, 2)
    rows = np.floor((coords[:, 0] - y_min) / scale).astype(int)
    cols = np.floor((coords[:, 1] - x_min) / scale).astype(int)
    on_grid = (rows >= 0) & (rows < h) & (cols >= 0) & (cols < w)
    map_mask[rows[on_grid], cols[on_grid]] = 1

map_walkable = map_mask == 0

# ====== Align the occupancy walkable area with the map grid ======
# The map grid is the occupancy grid itself: same origin (meta['min']), same
# scale, same (h, w). Alignment is therefore exactly a vertical flip (image
# row 0 is the max-y edge). Recomputing indices via
# floor(((x_min + px*scale) - occ_min_x) / scale) in a per-cell Python loop
# was O(h*w) interpreted work and could land on px-1 through floating-point
# cancellation, shifting cells by one.
occ_min_x, occ_min_y = meta['min'][:2]
occ_walkable_aligned = np.flipud(occupancy == 255)  # 255 == free / walkable

# ====== Unable areas: free in the semantic map but not walkable in occupancy ======
# Cells no instance covers (map_walkable) that the occupancy image does not mark
# as free. Each 8-connected patch of at least 5 cells becomes an "Unable Area"
# record with category_id -1.
unable_mask = map_walkable & ~occ_walkable_aligned
labeled, num = nd_label(unable_mask, structure=np.ones((3, 3)))
print(f"检测到 Unable Area 分块: {num}")

for idx in range(1, num + 1):
    block = labeled == idx
    area = block.sum()
    if area < 5:  # ignore specks smaller than 5 cells
        continue
    ys, xs = np.where(block)
    left = x_min + xs.min() * scale
    right = x_min + (xs.max() + 1) * scale
    bottom = y_min + ys.min() * scale
    top = y_min + (ys.max() + 1) * scale

    label = "Unable Area"
    item_counters[label] += 1
    result_list.append({
        "category_id": -1,
        "category_label": label,
        "instance_id": f"unable_area_{idx}",
        "item_id": f"{label}_{item_counters[label]}",
        "bbox_m": [format2(left), format2(bottom), format2(right), format2(top)],
        "bbox_xywh_m": [format2(left), format2(bottom),
                        format2(right - left), format2(top - bottom)],
        "area": int(area),
        "height_m": format2(0.0),
        "min_z_m": format2(0.0),
        "max_z_m": format2(0.0),
        "mask_coords_m": [
            [format2(y_min + (row + 0.5) * scale), format2(x_min + (col + 0.5) * scale)]
            for row, col in zip(ys, xs)
        ],
    })

# ====== Save all instance records as one JSON file ======
out_json = os.path.join(scene_dir, f'2D_Semantic_Map_{scene_name}_Complete_v3.json')
serialized = json.dumps(result_list, indent=2)
with open(out_json, 'w') as out_file:
    out_file.write(serialized)
print(f"已保存: {out_json}")

# ====== Visualisation ======
# Walkable map in grey with unable areas overlaid in semi-transparent red,
# drawn in world coordinates (origin='lower' puts row 0 at y_min).
x_lo, y_lo = float(x_min), float(y_min)
extent = [x_lo, x_lo + w * scale, y_lo, y_lo + h * scale]

overlay = np.zeros((h, w, 4))
overlay[unable_mask] = [1, 0, 0, 0.5]  # RGBA: red at 50% opacity

plt.figure(figsize=(12, 12))
plt.imshow(map_walkable, cmap='gray', origin='lower', extent=extent)
plt.imshow(overlay, origin='lower', extent=extent)
plt.title('Map Walkable + Unable Area Overlay')
plt.xlabel("X (meters)")
plt.ylabel("Y (meters)")
png_path = os.path.join(scene_dir, f'2D_Semantic_Map_{scene_name}_Complete_v3.png')
plt.savefig(png_path, bbox_inches='tight', dpi=300)
plt.show()