import os
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import label as nd_label
from PIL import Image

scene_dir = 'test/839919'

# ====== Load the occupancy metadata ======
# JSON is read as UTF-8 explicitly so decoding does not depend on the
# platform's default locale encoding.
with open(os.path.join(scene_dir, 'occupancy.json'), encoding='utf-8') as f:
    meta = json.load(f)

scale = meta['scale']
# Every grid computation further down divides by ``scale``; fail early with
# a clear message instead of a ZeroDivisionError / nonsense grid size later.
if not scale > 0:
    raise ValueError(f"occupancy.json: 'scale' must be positive, got {scale!r}")
# Only the first two components of 'min' are used (x, then y).
occ_min_x, occ_min_y = meta['min'][:2]

# ====== Load the occupancy image as 8-bit greyscale ======
# The context manager closes the underlying file handle; ``convert`` returns
# an independent in-memory image, so ``occ_img`` stays usable afterwards.
with Image.open(os.path.join(scene_dir, 'occupancy.png')) as occ_src:
    occ_img = occ_src.convert("L")
occupancy = np.array(occ_img)
occ_h, occ_w = occupancy.shape

# ====== Load the previously saved semantic-map JSON ======
json_path = os.path.join(scene_dir, '2D_Semantic_Map_839919_with_Walls.json')
with open(json_path, encoding='utf-8') as f:
    data = json.load(f)

# ====== Infer the world extent from the JSON ======
# Each entry of 'mask_coords_m' is unpacked as a (y, x) pair; values are
# passed through float() so numeric strings are accepted as well.
coords = [
    (float(y), float(x))
    for inst in data
    for y, x in inst['mask_coords_m']
]
if not coords:
    # Without any coordinate there is no extent to infer; min()/max() would
    # otherwise fail with an unhelpful "empty sequence" message.
    raise ValueError(f"{json_path} contains no mask coordinates; cannot infer the map extent")
all_y = [c[0] for c in coords]
all_x = [c[1] for c in coords]
min_y, max_y = min(all_y), max(all_y)
min_x, max_x = min(all_x), max(all_x)

# ====== Grid resolution (number of rows / columns) ======
# +1 so that the maximum coordinate itself still maps to a valid cell.
h = int(np.ceil((max_y - min_y) / scale)) + 1
w = int(np.ceil((max_x - min_x) / scale)) + 1
print(f"统一网格: min_x={min_x}, min_y={min_y}, w={w}, h={h}, scale={scale}")

# ====== World coordinates -> unified-grid pixel ======
def world2pix_map(x, y):
    """Convert a world-space point to a (row, col) cell of the unified grid.

    Reads the module-level ``min_x``, ``min_y`` and ``scale``.

    Args:
        x: World x coordinate (anything ``float()`` accepts).
        y: World y coordinate (anything ``float()`` accepts).

    Returns:
        ``(py, px)`` as Python ints -- note the row index comes first even
        though ``x`` is the first argument. No bounds check is done here, so
        the indices may be negative or exceed the grid size.
    """
    col = np.floor((float(x) - min_x) / scale)
    row = np.floor((float(y) - min_y) / scale)
    return int(row), int(col)

# ====== Build the map raster from the JSON (door instances are skipped so they stay walkable) ======
map_mask = np.zeros((h, w), dtype=np.uint8)
for inst in data:
    label = inst['category_label'].lower()
    # The category belongs to the whole instance, so decide once here rather
    # than re-testing it (and converting coordinates) for every single point.
    if label == "door":
        continue
    for y, x in inst['mask_coords_m']:
        py, px = world2pix_map(x, y)
        # Ignore points that fall outside the unified grid.
        if 0 <= py < h and 0 <= px < w:
            map_mask[py, px] = 1
map_walkable = (map_mask == 0)  # True = walkable, False = obstacle

# ====== Build the occupancy raster aligned to the unified grid ======
# Vectorised replacement for a per-pixel Python double loop: the occupancy
# column only depends on the grid column and the occupancy row only on the
# grid row, so both index vectors are computed once and combined.
# NOTE(review): each cell is sampled at min + index * scale (the cell's
# minimum corner), exactly as the loop did -- confirm whether the cell
# centre was intended.
grid_wx = min_x + np.arange(w) * scale
grid_wy = min_y + np.arange(h) * scale
occ_cols = np.floor((grid_wx - occ_min_x) / scale).astype(np.intp)
occ_rows = np.floor((grid_wy - occ_min_y) / scale).astype(np.intp)

# Grid columns / rows whose sample lands inside the occupancy image.
cols_inside = (occ_cols >= 0) & (occ_cols < occ_w)
rows_inside = (occ_rows >= 0) & (occ_rows < occ_h)

# Cells that fall outside the occupancy image stay False (not walkable).
occ_walkable_aligned = np.zeros((h, w), dtype=bool)
# Flip the occupancy y direction to match origin='lower'.
src_rows = occ_h - 1 - occ_rows[rows_inside]
src_cols = occ_cols[cols_inside]
# A pixel value of exactly 255 is treated as walkable.
occ_walkable_aligned[np.ix_(rows_inside, cols_inside)] = (
    occupancy[np.ix_(src_rows, src_cols)] == 255
)

# ====== Difference detection: walkable in the map AND not walkable in the occupancy grid = "Unable Area" ======
unable_mask = np.logical_and(map_walkable, ~occ_walkable_aligned)

# ====== Connected-component analysis (denoising) ======
# A 3x3 structure of ones also joins diagonal neighbours.
labeled, num = nd_label(unable_mask, structure=np.ones((3,3)))
print(f"检测到 Unable 区域连通块: {num}")

new_instances = []
# NOTE(review): not referenced anywhere else in the visible script; kept in
# case code outside this view relies on it.
instance_id_start = len(data) + 1

# Pixel count of every label in one pass, so that tiny (noise) components
# can be rejected without first building a full-image mask for each of them.
areas = np.bincount(labeled.ravel(), minlength=num + 1)

for idx in range(1, num+1):
    area = areas[idx]
    if area < 5:  # ignore small components
        continue
    ys, xs = np.where(labeled == idx)
    xmin_pix, xmax_pix = xs.min(), xs.max()
    ymin_pix, ymax_pix = ys.min(), ys.max()
    # Bounding box in world coordinates; the +1 moves the right/top edge to
    # the far side of the last occupied cell.
    x_left = min_x + xmin_pix * scale
    x_right = min_x + (xmax_pix + 1) * scale
    y_bottom = min_y + ymin_pix * scale
    y_top = min_y + (ymax_pix + 1) * scale
    w_box = x_right - x_left
    h_box = y_top - y_bottom
    # Mask coordinates (world frame): cell centres, as [y, x] string pairs.
    center_y = min_y + (ys + 0.5) * scale
    center_x = min_x + (xs + 0.5) * scale
    mask_coords_m = [
        [f"{cy:.2f}", f"{cx:.2f}"]
        for cy, cx in zip(center_y, center_x)
    ]
    new_instances.append({
        "category_id": -1,  # special id, can be remapped later
        "category_label": "Unable Area",
        "instance_id": f"unable_area_{idx}",
        "bbox_m": [f"{x_left:.2f}", f"{y_bottom:.2f}", f"{x_right:.2f}", f"{y_top:.2f}"],
        "bbox_xywh_m": [f"{x_left:.2f}", f"{y_bottom:.2f}", f"{w_box:.2f}", f"{h_box:.2f}"],
        "area": int(area),
        "mask_coords_m": mask_coords_m
    })

# ====== Append the detected instances to the original JSON data ======
merged_data = list(data)
merged_data.extend(new_instances)

# ====== Write the merged JSON next to the inputs ======
out_json_path = os.path.join(
    scene_dir,
    '2D_Semantic_Map_839919_Merged_Unable.json',
)
with open(out_json_path, 'w') as f:
    # Serialise with a two-space indent, then write the text in one call.
    f.write(json.dumps(merged_data, indent=2))
print(f"已保存合并结果: {out_json_path}")

# ====== Visualisation: base map with the Unable Area overlay ======
# World-space rectangle covered by the unified grid: [left, right, bottom, top].
extent = [min_x, min_x + w*scale, min_y, min_y + h*scale]

# RGBA layer that is fully transparent except on the Unable Area cells.
overlay = np.zeros((h, w, 4))
overlay[unable_mask] = [1, 0, 0, 0.5]  # semi-transparent red

fig, ax = plt.subplots(figsize=(12, 12))
# Base layer: walkable cells of the semantic map in greyscale.
ax.imshow(map_walkable, cmap='gray', origin='lower', extent=extent)
# Unable Area layer drawn on top.
ax.imshow(overlay, origin='lower', extent=extent)
ax.set_title('Unable Area merged into JSON')
ax.set_xlabel('X (meters)')
ax.set_ylabel('Y (meters)')

png_path = os.path.join(scene_dir, '2D_Semantic_Map_839919_Merged_Unable.png')
fig.savefig(png_path, bbox_inches='tight', dpi=300)
plt.show()