import numpy as np
import json
from PIL import Image
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, Point
from skimage.measure import label, regionprops

# ====== 1. Load the occupancy-map metadata and image ======
# `scale` is used below as world units per pixel, and (`x_min`, `y_min`) as the
# world coordinate of the corner of pixel (0, 0).
with open('test/839920/occupancy.json', encoding='utf-8') as f:
    meta = json.load(f)
scale = meta['scale']
x_min, y_min = meta['min'][:2]

# The context manager releases the underlying file handle as soon as the
# pixel data has been copied into the array.
with Image.open('test/839920/occupancy.png') as occ_img:
    occupancy = np.array(occ_img)
# Only the first two axes are (rows, cols); slicing keeps the unpack working
# when the PNG is multi-channel (RGB/RGBA) instead of single-channel.
h, w = occupancy.shape[:2]

# ====== 2. Load the 3DGS object annotations ======
# Explicit UTF-8 so non-ASCII label text decodes the same on every platform
# instead of depending on the locale's default encoding.
with open('test/839920/labels.json', encoding='utf-8') as f:
    labels = json.load(f)

# ====== 3. Build the label2id / id2label maps ======
# Ids are handed out from 1 in first-seen order, which leaves 0 free for
# "no object" in the zero-initialised visual map.
label2id = {}
id2label = {}
label_cur = 1
for obj in labels:
    # Named `label_name` so the `label` function imported from
    # skimage.measure is not shadowed.
    label_name = obj['label']
    if label_name not in label2id:
        label2id[label_name] = label_cur
        id2label[label_cur] = label_name
        label_cur += 1

# ====== 4. Rasterize each object's footprint onto the pixel grid ======
result_list = []
# Per-pixel category id; 0 means no object. Where footprints overlap, the
# object processed last overwrites earlier ones.
visual_map = np.zeros((h, w), dtype=np.int32)

for obj in labels:
    if "bounding_box" not in obj:
        continue
    # Named `label_name` so the `label` function imported from
    # skimage.measure is not shadowed.
    label_name = obj['label']
    cat_id = label2id[label_name]

    # Only the x/y of the first four box vertices are used as the footprint.
    # NOTE(review): presumably these four are one face of the 3D box in ring
    # order - confirm against the tool that writes labels.json.
    poly2d = [[v['x'], v['y']] for v in obj['bounding_box'][:4]]
    poly = Polygon(poly2d)
    xys = np.array(poly2d)

    # Pixel window enclosing the footprint, clamped to the map extent.
    min_x_pixel = int(np.clip(np.floor((xys[:, 0].min() - x_min) / scale), 0, w - 1))
    max_x_pixel = int(np.clip(np.ceil((xys[:, 0].max() - x_min) / scale), 0, w - 1))
    min_y_pixel = int(np.clip(np.floor((xys[:, 1].min() - y_min) / scale), 0, h - 1))
    max_y_pixel = int(np.clip(np.ceil((xys[:, 1].max() - y_min) / scale), 0, h - 1))

    mask = np.zeros((h, w), dtype=bool)
    for i in range(min_x_pixel, max_x_pixel + 1):
        i_flip = w - 1 - i  # horizontal mirror: grid column i is stored at w-1-i
        # Pixel-centre x depends only on the column, so compute it once here.
        cx = x_min + i * scale + scale / 2
        for j in range(min_y_pixel, max_y_pixel + 1):
            cy = y_min + j * scale + scale / 2
            # A pixel belongs to the object when its centre is strictly inside.
            if poly.contains(Point(cx, cy)):
                mask[j, i_flip] = True
                visual_map[j, i_flip] = cat_id

    ys, xs = np.where(mask)
    if xs.size == 0:
        # No pixel centre fell inside the footprint; nothing to export.
        continue

    # NOTE(review): `xs` are already mirrored columns (the mask is written at
    # i_flip), so mirroring them again yields the un-mirrored column index i.
    # The exported bbox / mask_coords are therefore in un-mirrored columns
    # while visual_map is mirrored. Kept as-is; confirm which frame the
    # consumers of the JSON expect.
    xs_flip = (w - 1) - xs
    xmin, xmax = int(xs_flip.min()), int(xs_flip.max())
    ymin, ymax = int(ys.min()), int(ys.max())

    result_list.append({
        "category_id": int(cat_id),
        "category_label": label_name,
        "instance_id": obj.get('ins_id', ''),
        # Inclusive pixel indices: [xmin, ymin, xmax, ymax].
        "bbox": [xmin, ymin, xmax, ymax],
        # Extents are max - min, so a single-pixel mask has width/height 0.
        "bbox_xywh": [xmin, ymin, xmax - xmin, ymax - ymin],
        "area": int(mask.sum()),
        # [row, column] pairs in the row-major order returned by np.where.
        "mask_coords": [[int(y), int(x)] for y, x in zip(ys, xs_flip)],
    })


# Visualization: render the category-id map and save it as a PNG.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(visual_map, cmap='tab20')
# Optional bbox overlay (disabled):
# for inst in result_list:
#     xmin, ymin, xmax, ymax = inst['bbox']
#     ax.add_patch(
#         plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
#                       edgecolor='red', fill=False, linewidth=1)
#     )
ax.set_title('2D Semantic Map')
fig.savefig('2D_Semantic_Map_839920.png', bbox_inches='tight', dpi=300)
plt.show()

# Write the 2D instance records to JSON and report how many were exported.
serialized = json.dumps(result_list, indent=2)
with open('2D_Semantic_Map_839920.json', 'w') as out_fp:
    out_fp.write(serialized)
print(f"写出语义物体数：{len(result_list)}")