import os
import json
import numpy as np
import matplotlib.pyplot as plt
import heapq
from matplotlib import cm
from scipy.ndimage import binary_dilation, generate_binary_structure

scene_dir = 'test/839919'
# JSON is UTF-8 by specification; pass the encoding explicitly so non-ASCII
# labels load correctly regardless of the platform's default locale encoding.
with open(os.path.join(scene_dir, '2D_Semantic_Map_839919_Merged_Unable.json'), encoding='utf-8') as f:
    data = json.load(f)

# Each entry of mask_coords_m is unpacked as a (y, x) pair in world meters.
all_y = [float(y) for inst in data for y, x in inst['mask_coords_m']]
all_x = [float(x) for inst in data for y, x in inst['mask_coords_m']]
if not all_x:
    # min()/max() below would fail with an opaque "empty sequence" message.
    raise ValueError(f"no mask coordinates found in {scene_dir!r}; cannot size the grid")
min_y, max_y = min(all_y), max(all_y)
min_x, max_x = min(all_x), max(all_x)
scale = 0.05  # meters per grid cell
# +1 so that the largest coordinate, after rounding in world2pix, is still in bounds.
h = int(np.ceil((max_y - min_y) / scale)) + 1
w = int(np.ceil((max_x - min_x) / scale)) + 1

def world2pix(x, y):
    """Convert a world point (x, y) in meters to integer grid indices (row, col).

    The grid origin is (min_x, min_y) and each cell is ``scale`` wide. Values are
    rounded to the nearest cell and are NOT bounds-checked; callers must do that.
    """
    col = (float(x) - min_x) / scale
    row = (float(y) - min_y) / scale
    return int(round(row)), int(round(col))

semantic_map = np.zeros((h, w), dtype=int)  # category_id per pixel, initialized to 0
instance_map = np.zeros((h, w), dtype=int)  # (instance index + 1) per pixel; 0 = no instance
idx2inst = {}

# Rasterize every instance; where instances share a pixel, the later one wins.
for idx, inst in enumerate(data):
    label_value = inst['category_id']
    kept = []
    for wy, wx in inst['mask_coords_m']:
        row, col = world2pix(wx, wy)
        inside = 0 <= row < h and 0 <= col < w
        if not inside:
            continue
        semantic_map[row, col] = label_value
        instance_map[row, col] = idx + 1
        kept.append((row, col))
    inst['mask_coords'] = kept  # in-bounds pixel (row, col) pairs
    idx2inst[idx] = inst

# ===== Color scheme: blue background, fixed wall color, transparent doors =====
cat_ids = sorted(set(int(inst['category_id']) for inst in data))
max_cat_id = max(cat_ids)
# One palette entry per category id. matplotlib.cm.get_cmap was removed in
# Matplotlib 3.9; pyplot.get_cmap(name, lut) is the supported equivalent.
base_colors = plt.get_cmap('tab20', max_cat_id + 1)(range(max_cat_id + 1))[:, :3]
# Blend every color 25% towards white to soften the palette.
colors = [tuple(ch * 0.75 + 0.25 for ch in c) for c in base_colors]
wall_cids = {inst['category_id'] for inst in data if str(inst.get('category_label', '')).lower() == 'wall'}
for wc in wall_cids:
    if wc <= max_cat_id:
        colors[wc] = (158/255, 218/255, 229/255)  # #9EDAE5
door_cids = {inst['category_id'] for inst in data if str(inst.get('category_label', '')).lower() == 'door'}
bg_color = (31/255, 119/255, 180/255, 1.0)  # blue background
color_map_img = np.zeros((h, w, 4), dtype=float)
color_map_img[:, :] = bg_color  # fill with the background color
# semantic_map is zero-initialized, so "semantic_map == 0" would also match every
# background pixel; only paint pixels that actually belong to some instance.
occupied = instance_map > 0
for cid in cat_ids:
    mask = (semantic_map == cid) & occupied
    if cid in door_cids:
        # Alpha 0: doors are drawn fully transparent so the axes background shows.
        color_map_img[mask] = (bg_color[0], bg_color[1], bg_color[2], 0.0)
    else:
        col = colors[cid % len(colors)] + (1.0,)
        color_map_img[mask] = col

# ===== Visualization with instance indices =====
img_extent = [min_x, min_x + w * scale, min_y, min_y + h * scale]
plt.figure(figsize=(12, 12))
plt.gca().set_facecolor(bg_color[:3])
plt.imshow(color_map_img, extent=img_extent, origin='lower')
for index, record in idx2inst.items():
    pix = np.array(record['mask_coords'])
    if not pix.size:
        continue
    mean_row, mean_col = pix.mean(axis=0)
    # Under img_extent, pixel k spans [min + k*scale, min + (k+1)*scale]; +0.5 is its center.
    plt.text(min_x + (mean_col + 0.5) * scale,
             min_y + (mean_row + 0.5) * scale,
             f"{index}",
             color='black', fontsize=10, ha='center', va='center', fontweight='bold')
plt.title('2D Semantic Map with Instance Index')
plt.xlabel('X (meters)')
plt.ylabel('Y (meters)')
plt.savefig(os.path.join(scene_dir, '2D_Semantic_Map_839919_with_idx.png'), bbox_inches='tight', dpi=300)
plt.show()

# ===== Let the user choose the start and goal objects =====
print("\n可选物体列表：")
for idx, inst in idx2inst.items():
    print(f"[{idx}] {inst.get('category_label','')}|instance_id:{inst.get('instance_id','')}")


def _ask_instance_index(prompt):
    """Prompt until the user enters a key of ``idx2inst``; return it as an int.

    Non-numeric input or an index that is not listed is rejected and the prompt
    is repeated, instead of crashing with ValueError/KeyError later on.
    """
    while True:
        answer = input(prompt).strip()
        try:
            choice = int(answer)
        except ValueError:
            choice = None
        if choice in idx2inst:
            return choice
        print(f"无效编号: {answer!r}，请输入上面列表中的编号。")


start_idx = _ask_instance_index("请输入导航起点物体编号 (如 7): ")
goal_idx = _ask_instance_index("请输入导航终点物体编号 (如 21): ")

# ===== Build the obstacle grid; door instances are skipped so doors stay traversable =====
robot_radius_m = 0.25
buffer_pixel = int(np.ceil(robot_radius_m / scale))
print(f"【体积膨胀】机器人半径={robot_radius_m}m，像素buffer={buffer_pixel}")
grid_map = np.zeros_like(semantic_map, dtype=np.uint8)  # 1 = obstacle, 0 = free
for inst in data:
    if inst['category_id'] in door_cids:
        continue
    for y, x in inst['mask_coords_m']:
        py, px = world2pix(x, y)
        if 0 <= py < h and 0 <= px < w:
            grid_map[py, px] = 1
if buffer_pixel > 0:
    # Inflate obstacles by a disk of radius buffer_pixel, so that every remaining
    # free pixel is more than buffer_pixel cells (i.e. farther than robot_radius_m)
    # from any obstacle pixel. The disk is built explicitly because
    # binary_dilation returns an array shaped like its input: dilating a 3x3
    # structure can never yield a footprint wider than one pixel.
    r = np.arange(-buffer_pixel, buffer_pixel + 1)
    struct = (r[:, None] ** 2 + r[None, :] ** 2) <= buffer_pixel ** 2
    grid_map = binary_dilation(grid_map, structure=struct).astype(np.uint8)
    print(f'已对障碍膨胀buffer {buffer_pixel}格!')


def get_freedom_pixel(instance_mask, base_map, offset=3):
    """Pick a traversable pixel close to an instance, for use as an A* endpoint.

    Args:
        instance_mask: sequence of (row, col) pixel coordinates of the instance.
        base_map: 2-D occupancy grid indexed [row, col]; 0 means free.
        offset: half-width, in pixels, of the first square window searched
            around the rounded mask centroid. The window grows (doubling) until
            it contains a free pixel, so an obstacle inflation wider than
            ``offset`` does not make the search fail.

    Returns:
        ``(x, y)`` (i.e. ``(col, row)``) of the free pixel nearest, in Euclidean
        distance, to the centroid among those in the first window that contains
        any; None if the mask is empty or the map has no free pixel at all.
    """
    mask = np.asarray(instance_mask)
    if mask.size == 0:
        return None
    rows, cols = base_map.shape
    cy, cx = np.round(mask.mean(axis=0)).astype(int)
    # Keep the centroid on the map so the window slices below are always valid.
    cy = int(min(max(cy, 0), rows - 1))
    cx = int(min(max(cx, 0), cols - 1))
    radius = max(int(offset), 0)
    full = max(rows, cols)  # a window of this half-width covers the whole map
    while True:
        y0, y1 = max(cy - radius, 0), min(cy + radius + 1, rows)
        x0, x1 = max(cx - radius, 0), min(cx + radius + 1, cols)
        free = np.argwhere(base_map[y0:y1, x0:x1] == 0)
        if free.size:
            free += (y0, x0)  # window-local -> map coordinates
            d2 = (free[:, 0] - cy) ** 2 + (free[:, 1] - cx) ** 2
            y, x = free[np.argmin(d2)]
            return (int(x), int(y))
        if radius >= full:
            return None
        radius = min(max(2 * radius, radius + 1), full)

start_inst = idx2inst[start_idx]
goal_inst = idx2inst[goal_idx]
# Object pixels (except doors) are obstacles in grid_map, so plan between free
# pixels found next to each object rather than the objects themselves.
start_px = get_freedom_pixel(start_inst['mask_coords'], grid_map)
goal_px = get_freedom_pixel(goal_inst['mask_coords'], grid_map)
if start_px is None or goal_px is None:
    print(f"自动找物体外的起点或终点时失败: start={start_px}, goal={goal_px}")
    # Report failure with a non-zero status; the bare `exit` builtin is injected
    # by the `site` module and is not guaranteed to exist.
    raise SystemExit(1)
print(f"A*起点像素坐标: {start_px}\n终点像素坐标: {goal_px}")

def astar_pixel(grid, start_px, goal_px, allow_corner_cutting=False):
    """Find a shortest 8-connected path on an occupancy grid with A*.

    Args:
        grid: 2-D array indexed ``grid[y, x]``; cells equal to 1 are obstacles,
            every other value is traversable.
        start_px: start cell as an ``(x, y)`` pair.
        goal_px: goal cell as an ``(x, y)`` pair.
        allow_corner_cutting: when False (default), a diagonal step is taken
            only if both orthogonally adjacent cells are free, so the path can
            not slip between two diagonally touching obstacle cells.

    Returns:
        List of ``(x, y)`` int tuples from start to goal inclusive, or None if
        the goal is unreachable.
    """
    h, w = grid.shape
    diag = 2 ** 0.5
    # (dx, dy, step cost): four orthogonal moves followed by four diagonal moves.
    moves = ((-1, 0, 1.0), (1, 0, 1.0), (0, -1, 1.0), (0, 1, 1.0),
             (-1, -1, diag), (-1, 1, diag), (1, -1, diag), (1, 1, diag))
    start = (int(start_px[0]), int(start_px[1]))
    goal = (int(goal_px[0]), int(goal_px[1]))
    gx, gy = goal

    def passable(x, y):
        return 0 <= x < w and 0 <= y < h and grid[y, x] != 1

    def heuristic(x, y):
        # Straight-line distance never overestimates the 1 / sqrt(2) step costs
        # and is consistent, so the first time the goal is popped it is optimal.
        return ((x - gx) ** 2 + (y - gy) ** 2) ** 0.5

    open_heap = [(heuristic(*start), start)]
    came_from = {}
    g_score = {start: 0.0}
    closed = set()
    while open_heap:
        _, current = heapq.heappop(open_heap)
        if current in closed:
            continue  # stale duplicate: this cell was already expanded more cheaply
        if current == goal:
            path = [current]
            while current in came_from:
                current = came_from[current]
                path.append(current)
            return path[::-1]
        closed.add(current)
        cx, cy = current
        base_g = g_score[current]
        for dx, dy, cost in moves:
            nx, ny = cx + dx, cy + dy
            if (nx, ny) in closed or not passable(nx, ny):
                continue
            if (dx and dy and not allow_corner_cutting
                    and not (passable(cx + dx, cy) and passable(cx, cy + dy))):
                continue
            neighbor = (nx, ny)
            tentative_g = base_g + cost
            if tentative_g < g_score.get(neighbor, float('inf')):
                came_from[neighbor] = current
                g_score[neighbor] = tentative_g
                heapq.heappush(open_heap, (tentative_g + heuristic(nx, ny), neighbor))
    return None

path = astar_pixel(grid_map, start_px, goal_px)

# ===== Draw the path together with the object indices =====
plt.figure(figsize=(12, 12))
plt.gca().set_facecolor(bg_color[:3])
plt.imshow(color_map_img, extent=img_extent, origin='lower')
# Place each instance's index at the centroid of its pixel mask (mask_coords are (row, col)).
for idx, inst in idx2inst.items():
    mask = np.array(inst['mask_coords'])
    if mask.size == 0:
        continue
    centroid = mask.mean(axis=0)
    centroid_world_x = min_x + (centroid[1] + 0.5) * scale
    centroid_world_y = min_y + (centroid[0] + 0.5) * scale
    plt.text(centroid_world_x, centroid_world_y, f"{idx}",
             color='black', fontsize=10, ha='center', va='center', fontweight='bold')

if path:
    # Path cells are (x, y) pixel indices; +0.5 places each vertex at its pixel center.
    xs, ys = zip(*path)
    xs_w = [min_x + (x + 0.5) * scale for x in xs]
    ys_w = [min_y + (y + 0.5) * scale for y in ys]
    plt.plot(xs_w, ys_w, '-', color='red', linewidth=3, alpha=0.8)
    plt.scatter([xs_w[0], xs_w[-1]], [ys_w[0], ys_w[-1]], color='red', s=80, label='start/goal')
    # Without a legend the 'start/goal' label above is never displayed.
    plt.legend(loc='upper right')
    print("A*路径点数:", len(path))
else:
    print("A*未找到路径！")

plt.title('2D Semantic Map - Mask-based Navigation (doors traversable)')
plt.xlabel('X (meters)')
plt.ylabel('Y (meters)')
plt.show()