import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap

scene_dir = 'test/839920'

# --- Load the map resolution. `scale` is the world-space size of one pixel
#     (the plot below labels the axes in meters).
# JSON is read as UTF-8 explicitly so decoding does not depend on the
# platform's default locale encoding.
with open(os.path.join(scene_dir, 'occupancy.json'), encoding='utf-8') as f:
    meta = json.load(f)
scale = meta['scale']
if not scale > 0:
    # A zero/negative scale would otherwise surface later as a confusing
    # ZeroDivisionError or a negative array dimension.
    raise ValueError(f"occupancy.json: 'scale' must be positive, got {scale!r}")

# --- Load the previously saved semantic instances
with open(os.path.join(scene_dir, '2D_Semantic_Map_839920_with_Walls.json'),
          encoding='utf-8') as f:
    data = json.load(f)

# --- Infer the world-space bounding range from every mask coordinate.
# Each entry of 'mask_coords_m' is unpacked as a (y, x) pair.
all_y, all_x = [], []
for inst in data:
    for y, x in inst['mask_coords_m']:
        all_y.append(float(y))
        all_x.append(float(x))
if not all_y:
    raise ValueError('semantic map JSON contains no mask coordinates; nothing to render')
min_y, max_y = min(all_y), max(all_y)
min_x, max_x = min(all_x), max(all_x)

# --- Raster resolution (rows, columns); +1 so the max coordinate gets a pixel
h = int(np.ceil((max_y - min_y) / scale)) + 1
w = int(np.ceil((max_x - min_x) / scale)) + 1

img = np.zeros((h, w), dtype=np.int32)  # holds category IDs, 0 = background

# --- World (physical) coordinates -> pixel indices
def world2pix(x, y):
    """Map a world-space point to integer raster indices.

    The point is offset by the module-level origin (``min_x``, ``min_y``),
    divided by the module-level ``scale`` and rounded with Python's
    ``round()``.

    Args:
        x: World x coordinate (anything ``float()`` accepts).
        y: World y coordinate (anything ``float()`` accepts).

    Returns:
        A ``(row, col)`` tuple, i.e. the y index first and the x index second.
    """
    offset_x = float(x) - min_x
    offset_y = float(y) - min_y
    col = int(round(offset_x / scale))
    row = int(round(offset_y / scale))
    return row, col

# ----------- Category IDs are taken straight from the JSON (no re-indexing) ------------
cat_ids = [int(inst['category_id']) for inst in data]
max_cat_id = max(cat_ids)

# --- Rasterize every instance mask into the image matrix.
# Instances are written in file order, so a later instance overwrites an
# earlier one wherever their pixels coincide.
for inst in data:
    cid = int(inst['category_id'])
    for y, x in inst['mask_coords_m']:
        py, px = world2pix(x, y)
        if 0 <= py < h and 0 <= px < w:
            img[py, px] = cid

# --- Build the colour table: one entry per possible category ID (0..max_cat_id).
# plt.get_cmap is used instead of matplotlib.cm.get_cmap, which was deprecated
# in Matplotlib 3.7 and removed in 3.9.
base_colors = plt.get_cmap('tab20', max_cat_id + 1).colors
# Soften each colour by blending 25% white into RGB; alpha is forced to 1.0.
colors = [tuple(c[i] * 0.75 + 0.25 for i in range(3)) + (1.0,) for c in base_colors]

# Fixed wall colour: apply it to every category ID whose label is 'wall'
# (case-insensitive). Instances without a label are treated as non-walls.
wall_rgb = (158/255, 218/255, 229/255, 1.0)  # light blue
wall_ids = sorted({
    int(inst['category_id'])
    for inst in data
    if str(inst.get('category_label', '')).lower() == 'wall'
})
for wall_id in wall_ids:
    # Guard both ends: a negative ID would otherwise index from the end of
    # the list and recolour an unrelated category.
    if 0 <= wall_id <= max_cat_id:
        colors[wall_id] = wall_rgb

new_cmap = ListedColormap(colors)

# --- Plot
plt.figure(figsize=(12, 12))
ax = plt.gca()
# NOTE(review): world2pix rounds to the nearest pixel, which suggests pixel
# *centres* sit at min + k * scale, whereas this extent places pixel *edges*
# there (a half-pixel shift). Left unchanged; confirm the intended alignment.
img_extent = [min_x, min_x + w * scale, min_y, min_y + h * scale]
ax.imshow(img, cmap=new_cmap, extent=img_extent, origin='lower', vmin=0, vmax=max_cat_id)

# --- Draw bounding boxes ('bbox_m' is unpacked as x0, y0, x1, y1 and plotted
#     directly in axis coordinates).
for inst in data:
    x0, y0, x1, y1 = map(float, inst['bbox_m'])
    # NOTE(review): only y is shifted by half a pixel, x is not -- presumably
    # an empirical alignment tweak; confirm against how bbox_m was produced.
    ax.add_patch(
        plt.Rectangle((x0, y0 + 0.5 * scale), x1 - x0, y1 - y0,
                      edgecolor='red', fill=False, linewidth=1)
    )

plt.title('2D Semantic Map (Restored from Physical JSON)')
plt.xlabel('X (meters)')
plt.ylabel('Y (meters)')
plt.savefig(os.path.join(scene_dir, '2D_Semantic_Map_839920_Physic_Restored_v2.png'),
            bbox_inches='tight', dpi=300)
plt.show()