import os
import numpy as np
import json
from PIL import Image
from shapely.geometry import Polygon, Point


def format2(x):
    """Coerce *x* to float and render it with exactly two digits after the decimal point."""
    value = float(x)
    return f"{value:.2f}"

root_dir = "./Dataset/InteriorGS"
output_root = "./results"
os.makedirs(output_root, exist_ok=True)


def _pixel_span(lo, hi, origin, scale, size):
    """Return the inclusive grid-index range ``(first, last)`` covering ``[lo, hi]``.

    Grid index ``k`` covers world coordinates
    ``[origin + k * scale, origin + (k + 1) * scale)``. Both ends are clipped to
    ``[0, size - 1]`` so the range always addresses valid grid cells.
    """
    first = int(np.floor((lo - origin) / scale))
    last = int(np.floor((hi - origin) / scale))
    return int(np.clip(first, 0, size - 1)), int(np.clip(last, 0, size - 1))


def _covered_cells(poly, xys, x_min, y_min, scale, h, w):
    """Return ``(rows, cols)`` int arrays of the grid cells whose centres ``poly`` covers.

    Cell ``(i, j)`` is centred at
    ``(x_min + (j + 0.5) * scale, y_min + (i + 0.5) * scale)``; indices are
    world-aligned (row ``i`` grows with world y, column ``j`` with world x).
    Only cells inside the axis-aligned bounds of the vertex array ``xys``
    (clipped to the ``h`` x ``w`` grid) are tested, and ``covers`` counts
    centres lying exactly on the polygon boundary as inside.

    Cells are returned with ``i`` descending, then ``j`` descending: the
    row-major order of an (h, w) mask flipped along both axes, which is the
    order this script has always emitted ``mask_coords_m`` in.
    """
    j_first, j_last = _pixel_span(xys[:, 0].min(), xys[:, 0].max(), x_min, scale, w)
    i_first, i_last = _pixel_span(xys[:, 1].min(), xys[:, 1].max(), y_min, scale, h)
    rows, cols = [], []
    for i in range(i_last, i_first - 1, -1):
        cy = y_min + (i + 0.5) * scale
        for j in range(j_last, j_first - 1, -1):
            cx = x_min + (j + 0.5) * scale
            if poly.covers(Point(cx, cy)):
                rows.append(i)
                cols.append(j)
    return np.array(rows, dtype=int), np.array(cols, dtype=int)


# For every scene folder under root_dir, rasterise each labelled object's
# footprint onto the occupancy-map grid and write per-object boxes, areas and
# covered cell centres (in world coordinates) to results/2D_Semantic_Map_<id>.json.
for subdir in sorted(os.listdir(root_dir)):
    dir_path = os.path.join(root_dir, subdir)
    if not os.path.isdir(dir_path):
        continue

    # Scene id is the suffix after the last "_". str.split never raises; a name
    # without "_" simply yields the whole name.
    id_str = subdir.split("_")[-1]

    out_json = os.path.join(output_root, f"2D_Semantic_Map_{id_str}.json")
    if os.path.exists(out_json):
        print(f"已存在 {out_json}，跳过")
        continue

    occ_json_path = os.path.join(dir_path, "occupancy.json")
    label_json_path = os.path.join(dir_path, "labels.json")
    if not (os.path.isfile(occ_json_path) and os.path.isfile(label_json_path)):
        print(f"跳过 {subdir}, 缺少必需的json文件")
        continue

    # ====== 1. Read occupancy-map grid parameters ======
    with open(occ_json_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    scale = meta["scale"]  # world units per grid cell
    x_min, y_min = meta["min"][:2]  # world coordinate of the grid's lower corner

    # Only the grid size is needed. Image.size is (width, height) for any image
    # mode, unlike unpacking np.array(img).shape, which fails for RGB/RGBA.
    with Image.open(os.path.join(dir_path, "occupancy.png")) as occ_img:
        w, h = occ_img.size

    # ====== 2. Read the 3DGS object annotations ======
    with open(label_json_path, "r", encoding="utf-8") as f:
        labels = json.load(f)

    # Category ids 1, 2, ... in order of each label's first appearance.
    label2id = {}
    for obj in labels:
        label2id.setdefault(obj["label"], len(label2id) + 1)

    result_list = []
    for obj in labels:
        if "bounding_box" not in obj:
            continue
        label = obj["label"]
        # The footprint is taken from the first four box corners' x/y; this
        # assumes they outline one horizontal face in order — TODO confirm
        # against the labels.json corner ordering.
        poly2d = [[v["x"], v["y"]] for v in obj["bounding_box"][:4]]
        rows, cols = _covered_cells(
            Polygon(poly2d), np.array(poly2d, dtype=float), x_min, y_min, scale, h, w
        )
        if rows.size == 0:
            continue

        # Convert back to world coordinates with the SAME world-aligned indices
        # that were tested above, so every reported coordinate lies where the
        # object actually is. (Converting indices of a flipped mask without
        # undoing the flip would mirror each box about the map centre.)
        x_left = x_min + cols.min() * scale
        x_right = x_min + (cols.max() + 1) * scale
        y_bottom = y_min + rows.min() * scale
        y_top = y_min + (rows.max() + 1) * scale

        result_list.append({
            "category_id": int(label2id[label]),
            "category_label": label,
            "instance_id": obj.get("ins_id", ""),
            "bbox_m": [format2(v) for v in (x_left, y_bottom, x_right, y_top)],
            "bbox_xywh_m": [
                format2(v)
                for v in (x_left, y_bottom, x_right - x_left, y_top - y_bottom)
            ],
            "area": int(rows.size),
            # [y, x] centre of every covered cell.
            "mask_coords_m": [
                [format2(y_min + (i + 0.5) * scale), format2(x_min + (j + 0.5) * scale)]
                for i, j in zip(rows, cols)
            ],
        })

    # Write-then-rename: an interrupted run must not leave a truncated file
    # that the exists-check above would then skip on every later run.
    tmp_json = out_json + ".tmp"
    with open(tmp_json, "w", encoding="utf-8") as f:
        json.dump(result_list, f, indent=2)
    os.replace(tmp_json, out_json)
    print(f"已处理: {out_json}")