import json

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from shapely.geometry import MultiPoint, Point, Polygon
from shapely.prepared import prep

# 1. Read the occupancy map and its metadata.
with open('test/839920/occupancy.json', encoding='utf-8') as f:
    meta = json.load(f)
# World units per pixel: world offsets are divided by this to get pixel indices.
scale = meta['scale']
# World x/y of the map origin; pixel column i covers
# [x_min + i * scale, x_min + (i + 1) * scale), and likewise for rows and y.
x_min, y_min = meta['min'][:2]

# Image.open is lazy; the context manager closes the file once the pixel data
# has been copied into the array.
with Image.open('test/839920/occupancy.png') as occ_img:
    occupancy = np.array(occ_img)
# Use only the spatial dimensions so RGB/RGBA or grayscale+alpha PNGs (which
# load as 3-D arrays) don't break the unpacking.
h, w = occupancy.shape[:2]

# Semantic map with one integer label id per occupancy pixel; 0 = unlabeled.
# (An object/str dtype could store label names directly instead.)
semantic_map = np.zeros((h, w), dtype=np.int32)

# 2. Read every annotated object and rasterize its footprint into semantic_map.
with open('test/839920/labels.json', encoding='utf-8') as f:
    labels = json.load(f)

# Map each label string to a positive integer id, assigned in order of first
# appearance; 0 stays reserved for "no label".
label2id = {}
label_cur = 1

for obj in labels:
    if "bounding_box" not in obj:
        continue
    label = obj['label']
    if label not in label2id:
        label2id[label] = label_cur
        label_cur += 1

    # Project every box corner onto the XY plane and take the convex hull.
    # This is the box's exact top-down footprint regardless of how the corners
    # are ordered. Slicing the first four corners only works when they happen
    # to form one face in cyclic order; otherwise the polygon self-intersects
    # and contains() labels the wrong pixels.
    corners_xy = [(v['x'], v['y']) for v in obj['bounding_box']]
    footprint = MultiPoint(corners_xy).convex_hull
    if footprint.geom_type != 'Polygon':
        # Degenerate box (missing, coincident or collinear corners): no area.
        continue
    # Prepared geometry: same contains() result, much faster when queried
    # for many points.
    poly = prep(footprint)

    # Pixel-index range covered by the footprint's axis-aligned bounds.
    fx_min, fy_min, fx_max, fy_max = footprint.bounds
    min_x_pixel = int(np.floor((fx_min - x_min) / scale))
    max_x_pixel = int(np.ceil((fx_max - x_min) / scale))
    min_y_pixel = int(np.floor((fy_min - y_min) / scale))
    max_y_pixel = int(np.ceil((fy_max - y_min) / scale))

    if max_x_pixel < 0 or max_y_pixel < 0 or min_x_pixel >= w or min_y_pixel >= h:
        continue  # footprint lies entirely outside the map

    min_x_pixel = max(min_x_pixel, 0)
    max_x_pixel = min(max_x_pixel, w - 1)
    min_y_pixel = max(min_y_pixel, 0)
    max_y_pixel = min(max_y_pixel, h - 1)

    # Label every pixel whose center lies strictly inside the footprint.
    label_id = label2id[label]
    for i in range(min_x_pixel, max_x_pixel + 1):
        cx = x_min + i * scale + scale / 2
        for j in range(min_y_pixel, max_y_pixel + 1):
            cy = y_min + j * scale + scale / 2
            if poly.contains(Point(cx, cy)):
                semantic_map[j, i] = label_id  # row index j is y, column i is x

# 3. Save the results first, then visualize. plt.show() can block until the
# window is closed, and saving first means the outputs exist even if the
# process is killed at the plot window.
np.save("semantic_map.npy", semantic_map)

# label -> integer id as stored in semantic_map (0 means "no label").
with open('semantic_label2id.json', 'w', encoding='utf-8') as f:
    json.dump(label2id, f, indent=2)
# Reverse lookup id -> label. JSON object keys are always strings, so the ids
# are written as "1", "2", ...; convert them back with int() when loading.
with open('semantic_id2label.json', 'w', encoding='utf-8') as f:
    id2label = {v: k for k, v in label2id.items()}
    json.dump(id2label, f, indent=2)

# 'nearest' keeps every displayed pixel tied to one label id. The default
# smoothing/antialiasing interpolation blends neighbouring integer ids into
# colours that correspond to no real label.
plt.imshow(semantic_map, interpolation='nearest')
plt.show()