import heapq
import json
import math

import matplotlib.pyplot as plt
import numpy as np

# 1. Load the semantic JSON and rasterise every instance mask.
# NOTE(review): each record is expected to carry 'category_id' and
# 'mask_coords' (a list of [y, x] pairs); confirm against the exporter.
# JSON is UTF-8 by spec; without an explicit encoding open() falls back to the
# platform locale (e.g. cp936/cp1252 on Windows) and non-ASCII text may fail.
with open('2D_Semantic_Map_839920.json', encoding='utf-8') as f:
    data = json.load(f)

# Map size = largest row / column index used by any mask, plus one.
# The outer default=0 keeps an empty instance list from raising ValueError.
h = max(
    (max((coords[0] for coords in inst['mask_coords']), default=0) for inst in data),
    default=0,
) + 1
w = max(
    (max((coords[1] for coords in inst['mask_coords']), default=0) for inst in data),
    default=0,
) + 1
semantic_map = np.zeros((h, w), dtype=int)
instance_map = np.zeros((h, w), dtype=int)

# Lookup table: instance index (position in the JSON list) -> JSON record.
idx2inst = {}
for idx, inst in enumerate(data):
    # reshape(-1, 2) gives an empty (0, 2) array for an empty mask.
    pixels = np.asarray(inst['mask_coords'], dtype=np.intp).reshape(-1, 2)
    rows, cols = pixels[:, 0], pixels[:, 1]
    # Later instances overwrite earlier ones where masks overlap.
    semantic_map[rows, cols] = inst['category_id']
    instance_map[rows, cols] = idx + 1   # 1-based so 0 means "no instance"
    idx2inst[idx] = inst



# 2. Draw the semantic map mirrored left-right and stamp every instance's
#    index at its mask centroid, then save and show the figure.
plt.figure(figsize=(12,12))
mirrored_map = semantic_map[:, ::-1]  # horizontal mirror of the whole map
plt.imshow(mirrored_map, cmap='tab20')
for inst_no in idx2inst:
    pixels = np.array(idx2inst[inst_no]['mask_coords'])
    if pixels.size == 0:
        continue
    mean_y, mean_x = pixels.mean(axis=0)
    # Mirror the x coordinate the same way the image was mirrored.
    label_x = w - 1 - mean_x
    plt.text(
        label_x, mean_y, f"{inst_no}",
        color='black', fontsize=10, ha='center', va='center', fontweight='bold',
    )
plt.title('2D Semantic Map with Instance Index')
plt.savefig('2D_Semantic_Map_839920_with_idx.png', bbox_inches='tight', dpi=300)
plt.show()



# 3. List the selectable instances and ask for the start / goal indices.
print("\n可选物体列表：")
for idx, inst in idx2inst.items():
    showid = inst.get('instance_id', '')
    print(f"[{idx}] {inst['category_label']}|instance_id:{showid}")


def _prompt_instance_index(prompt):
    """Ask repeatedly until the user types an index that exists in ``idx2inst``.

    Non-integer input or an index that is not listed is rejected and the
    question is asked again. Without this check the script would crash with a
    ValueError here or a KeyError later.
    """
    while True:
        raw = input(prompt)
        try:
            choice = int(raw)
        except ValueError:
            choice = None
        if choice in idx2inst:
            return choice
        print(f"无效编号: {raw!r}，请从上面的列表中选择")


if not idx2inst:
    # Nothing to choose from; re-prompting would loop forever.
    raise SystemExit("语义地图中没有可选物体")
start_idx = _prompt_instance_index("请输入导航起点物体编号 (如 7): ")
goal_idx = _prompt_instance_index("请输入导航终点物体编号 (如 21): ")



# 4. 找mask外缘的自由像素点作为起点终点
def get_freedom_pixel(instance_mask, base_map, offset=3):
    """Pick a free pixel near an instance mask to use as a path endpoint.

    Two passes:
      1. Scan the square window of half-width ``offset`` around the rounded
         mask centroid, nearest candidate first, and return the first pixel
         whose ``base_map`` value is 0.
      2. Otherwise, for every mask pixel (in input order), probe the four
         points exactly ``offset`` pixels up/down/left/right and return the
         first free one.

    Args:
        instance_mask: sequence of ``(y, x)`` pixel coordinates of the mask.
        base_map: 2-D array indexed ``[y, x]``; 0 is free, anything else is
            treated as occupied.
        offset: search radius in pixels.

    Returns:
        ``(x, y)`` as plain Python ints, or ``None`` when the mask is empty
        or no free pixel is found.
    """
    mask = np.array(instance_mask)
    if mask.size == 0:
        return None
    n_rows, n_cols = base_map.shape[0], base_map.shape[1]
    center = np.round(mask.mean(axis=0)).astype(int)
    cy, cx = int(center[0]), int(center[1])

    # Raster-order offsets re-sorted by squared distance. The sort is stable,
    # so equidistant candidates keep raster order. Scanning nearest-first
    # returns the free pixel closest to the object, not the window corner.
    window = [(dx, dy)
              for dx in range(-offset, offset + 1)
              for dy in range(-offset, offset + 1)]
    window.sort(key=lambda d: d[0] * d[0] + d[1] * d[1])
    for dx, dy in window:
        y, x = cy + dy, cx + dx
        if 0 <= y < n_rows and 0 <= x < n_cols and base_map[y, x] == 0:
            return (x, y)

    # Fallback for masks larger than the window: probe `offset` pixels
    # beyond each mask pixel in the four axis directions.
    for y, x in mask:
        for dy, dx in ((-offset, 0), (offset, 0), (0, -offset), (0, offset)):
            y1, x1 = y + dy, x + dx
            if 0 <= y1 < n_rows and 0 <= x1 < n_cols and base_map[y1, x1] == 0:
                # Plain ints so callers print "(3, 5)" rather than np.int64 reprs.
                return (int(x1), int(y1))
    return None



# 5. Occupancy grid: 0 = traversable, 1 = obstacle (any labelled pixel).
# NOTE(review): category_id 0 is treated as free space; confirm no real
# category uses id 0.
grid_map = (semantic_map != 0).astype(np.uint8)

start_inst = idx2inst[start_idx]
goal_inst = idx2inst[goal_idx]
start_px = get_freedom_pixel(start_inst['mask_coords'], grid_map)
goal_px = get_freedom_pixel(goal_inst['mask_coords'], grid_map)
if start_px is None or goal_px is None:
    print(f"自动找物体外的起点或终点时失败: start={start_px}, goal={goal_px}")
    # Non-zero status: failing to find endpoints is an error, not success.
    raise SystemExit(1)

print(f"A*起点像素坐标: {start_px}\n终点像素坐标: {goal_px}")



# 6. 像素转路径点并A*搜索
def astar_pixel(grid, start_px, goal_px):
    """Shortest 8-connected path between two pixels using A* search.

    Straight moves cost 1 and diagonal moves cost sqrt(2). The heuristic is
    the Euclidean distance to the goal, which never overestimates these step
    costs, so the returned path has minimal total cost.

    Args:
        grid: 2-D array indexed ``grid[y, x]``; cells equal to 1 are
            obstacles, every other value is traversable.
        start_px: ``(x, y)`` start pixel (not itself checked for being free).
        goal_px: ``(x, y)`` goal pixel.

    Returns:
        List of ``(x, y)`` tuples from ``start_px`` to ``goal_px``
        inclusive, or ``None`` if the goal cannot be reached.
    """
    h, w = grid.shape
    if start_px == goal_px:
        return [start_px]
    goal_x, goal_y = goal_px
    # An off-map or blocked goal can never be entered as a neighbour, so fail
    # fast instead of flooding the whole reachable region first.
    if not (0 <= goal_x < w and 0 <= goal_y < h) or grid[goal_y, goal_x] == 1:
        return None

    diag_cost = math.sqrt(2)
    dirs = [(-1,0),(1,0),(0,-1),(0,1),(-1,-1),(-1,1),(1,-1),(1,1)]

    def heuristic(px):
        return math.hypot(px[0] - goal_x, px[1] - goal_y)

    open_set = [(heuristic(start_px), start_px)]
    came_from = {}
    g_score = {start_px: 0.0}
    closed = set()
    while open_set:
        _, current = heapq.heappop(open_set)
        # Entries are not removed from the heap when a node's cost improves,
        # so an already-expanded node can be popped again; skip it.
        if current in closed:
            continue
        if current == goal_px:
            # Walk parent links back to the start.
            path = [current]
            while current in came_from:
                current = came_from[current]
                path.append(current)
            return path[::-1]
        closed.add(current)
        cx, cy = current
        current_g = g_score[current]
        for dx, dy in dirs:
            nx, ny = cx + dx, cy + dy
            if not (0 <= nx < w and 0 <= ny < h):
                continue
            if grid[ny, nx] == 1:  # grid is indexed [y][x]
                continue
            neighbor = (nx, ny)
            if neighbor in closed:
                continue
            tentative_g = current_g + (diag_cost if dx and dy else 1.0)
            if tentative_g < g_score.get(neighbor, math.inf):
                came_from[neighbor] = current
                g_score[neighbor] = tentative_g
                heapq.heappush(open_set, (tentative_g + heuristic(neighbor), neighbor))
    return None

path = astar_pixel(grid_map, start_px, goal_px)



# 7. Visualise the final path on the mirrored semantic map.
plt.figure(figsize=(12,12))
plt.imshow(semantic_map[:, ::-1], cmap='tab20')
for idx, inst in idx2inst.items():
    mask = np.array(inst['mask_coords'])
    if mask.size == 0: continue
    centroid = mask.mean(axis=0)[::-1]   # (y, x) mean -> (x, y)
    centroid[0] = w - 1 - centroid[0]    # mirror x like the image
    plt.text(centroid[0], centroid[1], f"{idx}", color='black', fontsize=10, ha='center', va='center', fontweight='bold')
if path:
    xs, ys = zip(*path)
    # The map is drawn mirrored left-right, so mirror path x the same way.
    xs_flip = [w - 1 - x for x in xs]
    plt.plot(xs_flip, ys, '-', color='red', linewidth=3, alpha=0.8)
    plt.scatter([xs_flip[0], xs_flip[-1]], [ys[0], ys[-1]], color='red', s=80, label='start/goal')
    # Without legend() the 'start/goal' label above is never displayed.
    plt.legend()
    print("A*路径点数:", len(path))
else:
    print("A*未找到路径！")
plt.title('2D Semantic Map - Mask-based Navigation')
plt.show()