import numpy as np
import json
from PIL import Image
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, Point

# ====== 1. Load the occupancy-map parameters ======
# occupancy.json provides the grid resolution and the world-space bounds of
# the map; JSON is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform default encoding.
with open('test/839920/occupancy.json', encoding='utf-8') as f:
    meta = json.load(f)
scale = meta['scale']  # grid resolution (unit: meters)
# Only the first two components (x, y) of the bounds are used.
x_min, y_min = meta['min'][:2]
x_max, y_max = meta['max'][:2]
# Copy the pixels into an array inside the context manager so the underlying
# file handle is released deterministically.
with Image.open('test/839920/occupancy.png') as occ_img:
    occupancy = np.array(occ_img)
# The first two axes are (rows, cols); slicing keeps this working when the PNG
# carries extra channels (e.g. RGB/RGBA), which add a third axis.
h, w = occupancy.shape[:2]

# ====== 2. Load the 3DGS object annotations ======
# Decode explicitly as UTF-8 (the JSON standard encoding) so that non-ASCII
# label names do not depend on the platform's default encoding.
with open('test/839920/labels.json', encoding='utf-8') as f:
    labels = json.load(f)
# Build the label <-> integer id lookup tables. Ids start at 1 and are handed
# out in order of first appearance.
label2id = {}
id2label = {}
label_cur = 1
for obj in labels:
    label = obj['label']
    if label in label2id:
        continue
    label2id[label] = label_cur
    id2label[label_cur] = label
    label_cur += 1
# One dict per rasterized object is appended here by the loop further down.
result_list = []
# Per-pixel category id; 0 means no object has been drawn on that pixel
# (category ids start at 1).
visual_map = np.zeros((h, w), dtype=np.int32)

def format2(x):
    """Return *x* converted to ``float`` and rendered with two decimal places."""
    return f"{float(x):.2f}"

# Rasterize every annotated object onto the grid: fill visual_map with the
# object's category id and append its mask / bounding box / area to result_list.
for obj in labels:
    if "bounding_box" not in obj:
        continue
    label = obj['label']
    cat_id = label2id[label]
    poly3d = obj['bounding_box']
    # Use the first four vertices as the 2D footprint.
    # NOTE(review): assumes these are the bottom face of the 3D box, listed in
    # ring order -- confirm against the annotation format.
    poly2d = [[v['x'], v['y']] for v in poly3d[:4]]
    if len(poly2d) < 3:
        # Fewer than three vertices cannot form a polygon; skip the object
        # instead of crashing while building the geometry.
        continue
    poly = Polygon(poly2d)
    xys = np.array(poly2d)
    # Inclusive pixel range covered by the footprint's axis-aligned extent.
    min_x_pixel = int(np.floor((np.min(xys[:, 0]) - x_min) / scale))
    max_x_pixel = int(np.floor((np.max(xys[:, 0]) - x_min) / scale))
    min_y_pixel = int(np.floor((np.min(xys[:, 1]) - y_min) / scale))
    max_y_pixel = int(np.floor((np.max(xys[:, 1]) - y_min) / scale))
    # Clamp the range to the grid so the loops below never index out of bounds.
    min_x_pixel = np.clip(min_x_pixel, 0, w - 1)
    max_x_pixel = np.clip(max_x_pixel, 0, w - 1)
    min_y_pixel = np.clip(min_y_pixel, 0, h - 1)
    max_y_pixel = np.clip(max_y_pixel, 0, h - 1)
    mask = np.zeros((h, w), dtype=bool)
    # Test each candidate pixel centre against the footprint. Hits are stored
    # at the index mirrored on BOTH axes (the "double flip").
    for j in range(min_x_pixel, max_x_pixel + 1):  # +1: the range is inclusive
        # Everything that depends only on the column is computed once per column.
        j_flip = w - 1 - j
        cx = x_min + (j + 0.5) * scale  # world x of the pixel centre
        for i in range(min_y_pixel, max_y_pixel + 1):
            cy = y_min + (i + 0.5) * scale  # world y of the pixel centre
            # covers() also accepts points lying exactly on the polygon boundary.
            if poly.covers(Point(cx, cy)):
                i_flip = h - 1 - i
                mask[i_flip, j_flip] = True
                visual_map[i_flip, j_flip] = cat_id
    ys, xs = np.where(mask)
    # np.where returns equally sized index arrays, so one emptiness check is enough.
    if xs.size == 0:
        continue
    # Pixel extremes of the bounding box, taken from the (flipped) mask.
    xmin_pix, xmax_pix = xs.min(), xs.max()
    ymin_pix, ymax_pix = ys.min(), ys.max()
    # Convert to physical coordinates using the OUTER pixel edges (no 0.5
    # offset here): the box spans from the low edge of the first pixel to the
    # high edge of the last one. Because the mask indices are the flipped
    # ones, these coordinates are expressed in that flipped grid frame.
    x_left = x_min + xmin_pix * scale
    x_right = x_min + (xmax_pix + 1) * scale
    y_bottom = y_min + ymin_pix * scale
    y_top = y_min + (ymax_pix + 1) * scale
    w_box = x_right - x_left
    h_box = y_top - y_bottom
    # Corner form: [x_left, y_bottom, x_right, y_top].
    bbox_m = [
        format2(x_left),
        format2(y_bottom),
        format2(x_right),
        format2(y_top)
    ]
    # Origin + size form: [x_left, y_bottom, width, height].
    bbox_xywh_m = [
        format2(x_left),
        format2(y_bottom),
        format2(w_box),
        format2(h_box)
    ]
    # Pixel-centre coordinates of every mask pixel, stored as [y, x] pairs.
    mask_coords_m = [
        [format2(y_min + (y + 0.5) * scale), format2(x_min + (x + 0.5) * scale)]
        for y, x in zip(ys, xs)
    ]

    # Debug: check the error introduced by rasterization (compares the flipped
    # 3D extent with the extent of the 2D mask).
    # ins_id = obj.get('ins_id', '')
    #
    # def flip_x(x):
    #     return x_min + x_max - x
    #
    # def flip_y(y):
    #     return y_min + y_max - y
    #
    # orig_xs = xys[:, 0]
    # orig_ys = xys[:, 1]
    # flip_xs = flip_x(orig_xs)
    # flip_ys = flip_y(orig_ys)
    # mask_xs = [float(x) for y, x in mask_coords_m]
    # mask_ys = [float(y) for y, x in mask_coords_m]
    # print(
    #     f'[ID {ins_id}] 3D flip x = [{flip_xs.min():.2f}, {flip_xs.max():.2f}], 2D mask x = [{min(mask_xs):.2f}, {max(mask_xs):.2f}]')
    # print(
    #     f'[ID {ins_id}] 3D flip y = [{flip_ys.min():.2f}, {flip_ys.max():.2f}], 2D mask y = [{min(mask_ys):.2f}, {max(mask_ys):.2f}]')


    result_list.append({
        "category_id": int(cat_id),
        "category_label": label,
        "instance_id": obj.get('ins_id', ''),
        "bbox_m": bbox_m,
        "bbox_xywh_m": bbox_xywh_m,
        "area": int(mask.sum()),  # number of mask pixels
        "mask_coords_m": mask_coords_m
    })



# =========== Plot the map in physical world coordinates ===========
# extent = [left, right, bottom, top] of the image in meters.
extent = [float(x_min), float(x_min) + w * scale, float(y_min), float(y_min) + h * scale]
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(visual_map, cmap='tab20', extent=extent, origin='lower')
ax.set_xlabel("X (meters)")
ax.set_ylabel("Y (meters)")
# --- Draw each bounding box in the same world coordinates as visual_map.
for inst in result_list:
    # The box is stored as formatted strings; convert back to numbers to draw it.
    x_left, y_bottom, w_box, h_box = (float(v) for v in inst['bbox_xywh_m'])
    outline = plt.Rectangle(
        (x_left, y_bottom),
        w_box,
        h_box,
        edgecolor='red',
        fill=False,
        linewidth=1,
    )
    ax.add_patch(outline)
ax.set_title('2D Semantic Map')
fig.savefig('2D_Semantic_Map_839920_Physic.png', bbox_inches='tight', dpi=300)
plt.show()

# ========== Write the physical-world JSON ==========
with open('2D_Semantic_Map_839920_Physic.json', 'w') as out_fp:
    json.dump(obj=result_list, fp=out_fp, indent=2)
# Report how many semantic objects were written.
print(f"写出语义物体数：{len(result_list)}")