import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap

scene_dir = 'test/839919'

# --- Grid resolution (metres per pixel), read from the scene's occupancy metadata.
# JSON is UTF-8 by specification; pass the encoding explicitly so the result
# does not depend on the platform's locale default.
with open(os.path.join(scene_dir, 'occupancy.json'), encoding='utf-8') as f:
    meta = json.load(f)
scale = meta['scale']

# --- Load the physical-world 2D mask instances. This script reads
# 'category_id', 'mask_coords_m' (a list of [y, x] pairs in metres) and
# 'bbox_m' from each instance.
with open(os.path.join(scene_dir, '2D_Semantic_Map_839919_Physic_Walls.json'),
          encoding='utf-8') as f:
    data = json.load(f)

# --- Infer the world-coordinate bounds of every mask point.
all_y = [float(y) for inst in data for y, x in inst['mask_coords_m']]
all_x = [float(x) for inst in data for y, x in inst['mask_coords_m']]
if not all_y:
    # Without points there are no bounds and min()/max() would fail cryptically.
    raise ValueError('no mask coordinates found in the physical-world mask JSON')
min_y, max_y = min(all_y), max(all_y)
min_x, max_x = min(all_x), max(all_x)
# +1 so that the furthest point (index round(span / scale)) still lands in the grid.
h = int(np.ceil((max_y - min_y) / scale)) + 1
w = int(np.ceil((max_x - min_x) / scale)) + 1

# Label raster: 0 everywhere until instances are painted in.
img = np.zeros((h, w), dtype=np.int32)

# --- Physical world -> mask pixel indices
def world2pix(x, y):
    """Snap a world-space point (metres) to the nearest pixel of the raster.

    Args:
        x: world x coordinate; anything ``float()`` accepts.
        y: world y coordinate; anything ``float()`` accepts.

    Returns:
        ``(py, px)``: the image row and column, measured from the module-level
        origin ``(min_y, min_x)`` at ``scale`` metres per pixel. The indices
        are not clipped to the image bounds. Python's ``round`` is used, so an
        exact half-pixel tie goes to the even index.
    """
    def _nearest_index(value, origin):
        return int(round((float(value) - origin) / scale))

    col = _nearest_index(x, min_x)
    row = _nearest_index(y, min_y)
    return row, col


# ----------- Key step: remap category ids to contiguous colour indices -------
# The sorted unique category ids map to 0..K-1, so colour indices stay dense
# however sparse the raw ids are.
cat_set = {int(inst['category_id']) for inst in data}
cat_list = sorted(cat_set)
cat2idx = dict(zip(cat_list, range(len(cat_list))))
n_color = len(cat2idx)

# Paint every instance's mask points into `img`. Points that fall outside the
# grid are skipped, and where instances overlap the later one wins.
# NOTE(review): colour index 0 equals the background value of `img`
# (np.zeros), so pixels of the first category cannot be told apart from empty
# space. Confirm whether indices were meant to start at 1.
for instance in data:
    label = cat2idx[int(instance['category_id'])]
    for coord_y, coord_x in instance['mask_coords_m']:
        row, col = world2pix(coord_x, coord_y)
        if not (0 <= row < h and 0 <= col < w):
            continue
        img[row, col] = label


# --- Copy tab20's colour table so individual entries can be edited.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported accessor and also exists in older releases.
tab20 = plt.get_cmap('tab20')
# .colors is a tuple of 20 colour tuples; convert it to a list so it can be edited.
colors = list(tab20.colors)

# Replace slot 6 with a custom yellow, RGB(255, 237, 111), written as RGBA.
r, g, b = 255, 237, 111
colors[6] = (r/255, g/255, b/255, 1.0)
# Shift slots 2..10 down by one. This drops the original slot-1 colour, moves
# the yellow into slot 5, and leaves slots 9 and 10 both holding the old
# slot-10 colour.
for i in range(1, 10):
    colors[i] = colors[i + 1]

# Build the customised colormap.
new_cmap = ListedColormap(colors)


# --- Render the label raster in world coordinates.
plt.figure(figsize=(12, 12))
# world2pix snaps each point to the NEAREST pixel index, so pixel i is centred
# on min + i * scale. imshow's `extent` gives the outer EDGES of the image, so
# pad by half a pixel on every side. Using [min, min + n * scale] instead would
# draw the whole raster shifted by +scale/2 and mislabel the metre axes.
half_px = 0.5 * scale
img_extent = [min_x - half_px, min_x + (w - 0.5) * scale,
              min_y - half_px, min_y + (h - 0.5) * scale]
# NOTE(review): with vmin=1, the background value 0 (below vmin, so drawn in
# the colormap's 'under' colour, by default its first entry) and index 1 (at
# vmin) both get the first colour. Category index 0 is also 0. Confirm the
# intended index/colour mapping.
plt.imshow(img, cmap=new_cmap, extent=img_extent, origin='lower', vmin=1, vmax=n_color+1)

# --- Overlay each instance's bbox (already in world metres, the same frame as
# the raster). Assumes bbox_m is [x0, y0, x1, y1]; TODO confirm against the
# exporter, since mask_coords_m is stored as [y, x].
# With the centred extent above, no half-pixel nudge is needed on either axis.
ax = plt.gca()
for inst in data:
    bbox = inst['bbox_m']
    x0, y0, x1, y1 = map(float, bbox)
    ax.add_patch(
        plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                      edgecolor='red', fill=False, linewidth=1)
    )
plt.title('2D Semantic Map (Restored from Physical JSON)')
plt.xlabel('X (meters)')
plt.ylabel('Y (meters)')
plt.savefig(os.path.join(scene_dir, '2D_Semantic_Map_839919_Physic_Restored.png'), bbox_inches='tight', dpi=300)
plt.show()