import json
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.colors import ListedColormap
from PIL import Image
from scipy.ndimage import label as nd_label
from shapely.geometry import Polygon, Point
from shapely.prepared import prep

# Directory holding occupancy.json, occupancy.png and labels.json for one
# scene; the JSON and PNG outputs are written here as well.
scene_dir = "/".join(
    ["/home/sig/sig/qianluo/qianluo", "3DGS_VLN_Benchmark", "test", "839873"]
)

def format2(x):
    """Convert ``x`` to float and render it with exactly two decimal places."""
    return f"{float(x):.2f}"

# ====== Map metadata and occupancy image ======
# occupancy.json provides `scale` (metres per pixel) and `min`, of which only
# the x and y components are used as the map's lower corner.
with open(os.path.join(scene_dir, 'occupancy.json')) as f:
    meta = json.load(f)
scale = meta['scale']
x_min, y_min = meta['min'][:2]

# Grey-scale occupancy grid with h rows and w columns.
occ_path = os.path.join(scene_dir, 'occupancy.png')
occ_img = Image.open(occ_path).convert("L")
occupancy = np.array(occ_img)
h, w = occupancy.shape

# ====== Pick the grey level that represents walls ======
pixels, counts = np.unique(occupancy.ravel(), return_counts=True)
# Only grey levels strictly between 0 and 250 are wall candidates.
in_range = (pixels > 0) & (pixels < 250)
candidate_walls = [int(p) for p in pixels[in_range]]
if not candidate_walls:
    # No intermediate grey level at all: fall back to the smallest value present.
    wall_value = int(pixels[0])
else:
    # Most frequent candidate; np.unique returns sorted values and np.argmax
    # keeps the first maximum, so ties go to the lowest grey level.
    wall_value = int(pixels[in_range][np.argmax(counts[in_range])])
print(f"[墙类别确认] wall_value = {wall_value}")

# Number of columns the wall layer is moved towards lower column indices.
shift = 1

# ====== 3DGS object annotations ======
with open(os.path.join(scene_dir, 'labels.json')) as f:
    labels = json.load(f)

# ====== Category ids: fixed ids for predefined classes, then any new label
# in order of first appearance ======
predefined_classes = [
    "door", "window", "chair", "table", "sofa", "bed", "wardrobe", "plant",
    "floor", "wall", "ceiling",
]
# Ids start at 1 because 0 is the background value of the category map.
label2id = {name: i for i, name in enumerate(predefined_classes, start=1)}
id2label = dict(zip(label2id.values(), label2id.keys()))
cur_max_id = len(label2id) + 1  # next unused id
for obj in labels:
    lbl = obj['label']
    if lbl in label2id:
        continue
    label2id[lbl] = cur_max_id
    id2label[cur_max_id] = lbl
    cur_max_id += 1

# ====== Category map and per-object records ======
# visual_map holds one category id per pixel (0 = background); result_list
# collects one JSON-ready record per object (wall records are appended later).
visual_map = np.zeros((h, w), dtype=np.int32)  # 0 = background
result_list = []

# ==== Labelled objects ====
for obj in labels:
    if "bounding_box" not in obj:
        continue
    label = obj['label']
    cat_id = label2id[label]  # category id
    poly3d = obj['bounding_box']
    # Footprint polygon from the x/y of the first four bounding-box vertices.
    # NOTE(review): assumes those four vertices form one face listed in cyclic
    # order; any other order yields a self-intersecting polygon - confirm.
    poly2d = [[v['x'], v['y']] for v in poly3d[:4]]
    # A prepared geometry answers covers() exactly like the plain polygon but
    # is much cheaper for the many per-pixel queries below.
    poly = prep(Polygon(poly2d))
    xys = np.array(poly2d)

    # Candidate pixel range in the unflipped frame (column j <-> x,
    # row i <-> y), clipped to the map.
    min_x_pixel = int(np.floor((np.min(xys[:, 0]) - x_min) / scale))
    max_x_pixel = int(np.floor((np.max(xys[:, 0]) - x_min) / scale))
    min_y_pixel = int(np.floor((np.min(xys[:, 1]) - y_min) / scale))
    max_y_pixel = int(np.floor((np.max(xys[:, 1]) - y_min) / scale))
    min_x_pixel = int(np.clip(min_x_pixel, 0, w - 1))
    max_x_pixel = int(np.clip(max_x_pixel, 0, w - 1))
    min_y_pixel = int(np.clip(min_y_pixel, 0, h - 1))
    max_y_pixel = int(np.clip(max_y_pixel, 0, h - 1))

    # Unflipped pixel (i, j) is stored at (h-1-i, w-1-j), i.e. the object is
    # mirrored on both axes. Only that mirrored candidate rectangle is
    # allocated and scanned, instead of a full h x w mask per object.
    # NOTE(review): the double flip looks like an empirical alignment with the
    # flipped/shifted wall layer - confirm the intended frame.
    row0 = h - 1 - max_y_pixel
    col0 = w - 1 - max_x_pixel
    window = np.zeros((max_y_pixel - min_y_pixel + 1,
                       max_x_pixel - min_x_pixel + 1), dtype=bool)
    for j in range(min_x_pixel, max_x_pixel + 1):
        cx = x_min + (j + 0.5) * scale
        for i in range(min_y_pixel, max_y_pixel + 1):
            cy = y_min + (i + 0.5) * scale
            if poly.covers(Point(cx, cy)):
                # (h-1-i) - row0 == max_y_pixel - i; likewise for columns.
                window[max_y_pixel - i, max_x_pixel - j] = True

    if not window.any():
        # No pixel centre falls inside the footprint: nothing painted or recorded.
        continue
    # Paint the object; later objects overwrite earlier ones where they overlap.
    visual_map[row0:row0 + window.shape[0], col0:col0 + window.shape[1]][window] = cat_id

    # Full-map indices of the covered pixels, in the same row-major order a
    # full-map np.where would produce.
    ys, xs = np.nonzero(window)
    ys = ys + row0
    xs = xs + col0

    # Bounding box in metres, in the same (flipped) frame the map is plotted in.
    xmin_pix, xmax_pix = xs.min(), xs.max()
    ymin_pix, ymax_pix = ys.min(), ys.max()
    x_left = x_min + xmin_pix * scale
    x_right = x_min + (xmax_pix + 1) * scale
    y_bottom = y_min + ymin_pix * scale
    y_top = y_min + (ymax_pix + 1) * scale
    w_box = x_right - x_left
    h_box = y_top - y_bottom
    bbox_m = [format2(x_left), format2(y_bottom), format2(x_right), format2(y_top)]
    bbox_xywh_m = [format2(x_left), format2(y_bottom), format2(w_box), format2(h_box)]
    # Pixel-centre coordinates as [y, x] pairs.
    mask_coords_m = [
        [format2(y_min + (y + 0.5) * scale), format2(x_min + (x + 0.5) * scale)]
        for y, x in zip(ys, xs)
    ]
    result_list.append({
        "category_id": int(cat_id),
        "category_label": label,
        "instance_id": obj.get('ins_id', ''),
        "bbox_m": bbox_m,
        "bbox_xywh_m": bbox_xywh_m,
        "area": int(window.sum()),
        "mask_coords_m": mask_coords_m,
    })

# ==== Wall layer: mask, connected components and one record per component ====
wall_cat_id = label2id["wall"]
wall_mask = occupancy == wall_value
# Row-reversed view of wall_mask (equivalent to np.flipud); because it is a
# view, the in-place shift below also modifies wall_mask.
wall_mask_flip = wall_mask[::-1]
if shift > 0:
    # Move every wall pixel `shift` columns left and clear the vacated right edge.
    # NOTE(review): looks like an empirical alignment offset - confirm.
    wall_mask_flip[:, :-shift] = wall_mask_flip[:, shift:]
    wall_mask_flip[:, -shift:] = False
# Walls are painted after the objects, so they overwrite object pixels.
visual_map[wall_mask_flip] = wall_cat_id

# Full 3x3 structuring element: diagonally touching pixels share a component.
structure = np.ones((3, 3), dtype=np.int32)
wall_label_mask, wall_count = nd_label(wall_mask_flip, structure=structure)
print("墙体连通分块数量:", wall_count)

# One bounding box / pixel list per wall component, in component-label order.
for idx in range(1, wall_count + 1):
    ys, xs = np.nonzero(wall_label_mask == idx)
    if xs.size == 0:  # defensive guard
        continue
    x_left = x_min + xs.min() * scale
    x_right = x_min + (xs.max() + 1) * scale
    y_bottom = y_min + ys.min() * scale
    y_top = y_min + (ys.max() + 1) * scale
    # Pixel-centre coordinates, later emitted as [y, x] pairs.
    centre_y = y_min + (ys + 0.5) * scale
    centre_x = x_min + (xs + 0.5) * scale
    result_list.append({
        "category_id": int(wall_cat_id),
        "category_label": "wall",
        "instance_id": f"wall_{idx}",
        "bbox_m": [format2(v) for v in (x_left, y_bottom, x_right, y_top)],
        "bbox_xywh_m": [
            format2(v)
            for v in (x_left, y_bottom, x_right - x_left, y_top - y_bottom)
        ],
        "area": int(xs.size),
        "mask_coords_m": [
            [format2(cy), format2(cx)] for cy, cx in zip(centre_y, centre_x)
        ],
    })

# ====== Write the JSON ======
out_json = os.path.join(scene_dir, '2D_Semantic_Map_839873_with_Walls.json')
with open(out_json, 'w') as f:
    json.dump(result_list, f, indent=2)

# ====== Plot ======
# The extent spans w*scale by h*scale metres starting at (x_min, y_min);
# origin='lower' draws row 0 at the bottom.
extent = [float(x_min), float(x_min) + w * scale, float(y_min), float(y_min) + h * scale]
max_cat_id = int(visual_map.max())
# Palette has one entry per id 0..n_colors-1. It must include wall_cat_id,
# otherwise colors[wall_cat_id] below raises IndexError on a map whose largest
# id is smaller (e.g. no wall pixels were found).
n_colors = max(max_cat_id, wall_cat_id) + 1

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap with
# a lut size is the supported equivalent.
base_colors = plt.get_cmap('tab20', n_colors).colors
# Blend every colour 25% towards white and force full opacity.
colors = [tuple(c[i] * 0.75 + 0.25 for i in range(3)) + (1.0,) for c in base_colors]
wall_rgb = (158/255, 218/255, 229/255, 1.0)  # light blue, #9EDAE5
colors[wall_cat_id] = wall_rgb

new_cmap = ListedColormap(colors)

plt.figure(figsize=(12, 12))
# vmin..vmax covers exactly len(colors) integer levels, so id k is drawn with colors[k].
plt.imshow(visual_map, cmap=new_cmap, extent=extent, origin='lower', vmin=0, vmax=n_colors - 1)

# Red outline for every recorded bbox (objects and wall components).
for inst in result_list:
    x_left, y_bottom, w_box, h_box = map(float, inst['bbox_xywh_m'])
    plt.gca().add_patch(
        plt.Rectangle((x_left, y_bottom), w_box, h_box, edgecolor='red', fill=False, linewidth=1)
    )

plt.xlabel("X (meters)")
plt.ylabel("Y (meters)")
plt.title("2D Semantic Map - Wall #9EDAE5 Fixed + BBox")
plt.savefig(os.path.join(scene_dir, '2D_Semantic_Map_839873_WallFixed.png'), bbox_inches='tight', dpi=300)
plt.show()