import os
import json
import numpy as np
import matplotlib.pyplot as plt
import heapq
import math
from matplotlib import cm
from scipy.ndimage import binary_dilation, generate_binary_structure, distance_transform_edt
from skimage.morphology import disk


# ===== Paths and data =====
scene_dir = '/home/sig/sig/qianluo/qianluo/3DGS_VLN_Benchmark/test/839873'
# JSON text is UTF-8; pass the encoding explicitly instead of relying on the
# platform default so category labels are decoded identically everywhere.
with open(os.path.join(scene_dir, '2D_Semantic_Map_839873_Complete_v2.json'), encoding='utf-8') as f:
    data = json.load(f)


# ===== Global extent of all instance masks =====
# Every entry of 'mask_coords_m' is unpacked as a (y, x) pair.
all_y = [float(y) for inst in data for y, x in inst['mask_coords_m']]
all_x = [float(x) for inst in data for y, x in inst['mask_coords_m']]
if not all_y:
    # min()/max() below would otherwise fail with an opaque ValueError.
    raise SystemExit("语义地图中没有任何 mask 坐标，无法构建栅格地图")
min_y, max_y = min(all_y), max(all_y)
min_x, max_x = min(all_x), max(all_x)
scale = 0.05  # grid resolution: world units per pixel
# +1 so that the largest coordinate still maps to a valid row/column index.
h = int(np.ceil((max_y - min_y) / scale)) + 1
w = int(np.ceil((max_x - min_x) / scale)) + 1

def world2pix(x, y):
    """Map a world coordinate to grid indices.

    The offset from the module-level origin (``min_x``, ``min_y``) is divided
    by ``scale`` and rounded with Python's ``round()``.

    Args:
        x: world x coordinate (anything ``float()`` accepts).
        y: world y coordinate (anything ``float()`` accepts).

    Returns:
        ``(py, px)`` -- row index first, then column index.
    """
    col = round((float(x) - min_x) / scale)
    row = round((float(y) - min_y) / scale)
    return int(row), int(col)


# ===== Pre-classified category id sets =====
def _category_ids_for(label):
    """Collect the category ids of all instances whose lower-cased label equals ``label``."""
    return {
        entry['category_id']
        for entry in data
        if str(entry.get('category_label', '')).lower() == label
    }


door_cids = _category_ids_for('door')
ceiling_cids = _category_ids_for('ceiling')
wall_cids = _category_ids_for('wall')
unable_cids = _category_ids_for('unable area')
skip_cids = door_cids | ceiling_cids  # doors and ceilings are never treated as obstacles


# ===== Fill the semantic map / instance map =====
semantic_map = np.zeros((h, w), dtype=int)
instance_map = np.zeros((h, w), dtype=int)
# Instance numbers stored in instance_map are 1-based; 0 stays "no instance".
for inst_number, inst in enumerate(data, start=1):
    cid = inst['category_id']
    inside = []
    for y_m, x_m in inst['mask_coords_m']:
        row, col = world2pix(x_m, y_m)
        if not (0 <= row < h and 0 <= col < w):
            continue
        # Later instances overwrite earlier ones on shared pixels.
        semantic_map[row, col] = cid
        instance_map[row, col] = inst_number
        inside.append((row, col))
    # Cache the (row, col) pixel coordinates on the instance for later stages.
    inst['mask_coords'] = inside


# ===== Build the color mapping =====
# All category ids except the background id 0
all_cids = sorted({inst['category_id'] for inst in data} - {0})
visual_cids = [cid for cid in all_cids if cid not in skip_cids]
bg_color = (31/255, 119/255, 180/255, 1.0)  # background color (RGBA)

# Large palette: tab20 + tab20b + tab20c.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported lookup and also exists in older releases.
tab_colors = np.vstack([
    plt.get_cmap(cmap_name).colors
    for cmap_name in ('tab20', 'tab20b', 'tab20c')
])

# Drop palette colors that are too close to the background color
def color_dist(c1, c2):
    """Return the Euclidean distance between the first three components of two colors.

    Components past index 2 (e.g. an alpha channel) are ignored.
    """
    delta = np.array(c1[:3]) - np.array(c2[:3])
    return np.sqrt((delta ** 2).sum())
unique_colors = [rgb for rgb in tab_colors if color_dist(rgb, bg_color) > 0.1]
# Palette entries are assigned in order and reused cyclically when there are
# more categories than colors.
cid2color = {
    cid: tuple(unique_colors[slot % len(unique_colors)])
    for slot, cid in enumerate(visual_cids)
}

# Fixed colors override the palette assignment
cid2color.update(dict.fromkeys(wall_cids, (158/255, 218/255, 229/255)))  # walls: light blue
cid2color.update(dict.fromkeys(unable_cids, (1.0, 0.0, 0.0)))  # "unable area": red


# ===== Paint every instance and show the map =====
color_map_img = np.full((h, w, 4), bg_color, dtype=float)
for inst in data:
    cid = inst['category_id']
    if cid in skip_cids:
        continue
    is_unable_area = str(inst.get('category_label', '')).lower() == 'unable area'
    cells = np.asarray(inst['mask_coords'], dtype=int)
    if cells.size:
        # Unknown categories fall back to grey; "unable area" is drawn semi-transparent.
        rgba = cid2color.get(cid, (0.6, 0.6, 0.6)) + ((0.6 if is_unable_area else 1.0),)
        color_map_img[cells[:, 0], cells[:, 1]] = rgba
img_extent = [min_x, min_x + w * scale, min_y, min_y + h * scale]
plt.figure(figsize=(12, 12))
plt.gca().set_facecolor(bg_color[:3])
plt.imshow(color_map_img, extent=img_extent, origin='lower')

# Instance numbers (ceiling and "unable area" instances are not labelled)
idx2inst = {}
unlabelled = ('ceiling', 'unable area')
for idx, inst in enumerate(data):
    if str(inst.get('category_label', '')).lower() in unlabelled:
        continue
    pixels = np.array(inst.get('mask_coords', []))
    if pixels.size == 0:
        continue
    mean_row, mean_col = pixels.mean(axis=0)
    # The label is placed at the mask centroid, converted to the axes' world units.
    plt.text(
        min_x + (mean_col + 0.5) * scale,
        min_y + (mean_row + 0.5) * scale,
        f"{idx}",
        color='black', fontsize=10, ha='center', va='center', fontweight='bold',
    )
    # Only labelled instances can be selected as start/goal later on.
    idx2inst[idx] = inst
plt.title('2D Semantic Map (auto-colored categories)')
plt.xlabel('X (meters)')
plt.ylabel('Y (meters)')
plt.show()


# ===== Start / goal selection =====
if not idx2inst:
    # Nothing can be selected; prompting would never succeed.
    raise SystemExit("没有可选的物体，无法选择起点/终点")

print("\n可选物体列表（已跳过 ceiling 与 Unable Area）：")
for idx, inst in idx2inst.items():
    print(f"[{idx}] {inst.get('category_label','')} | instance_id: {inst.get('instance_id','')}")


def _ask_instance_index(prompt):
    """Prompt until the user types an integer that is a key of ``idx2inst``.

    A non-integer answer or an index that is not listed triggers a re-prompt
    instead of a ValueError here or a KeyError further down the script.
    """
    while True:
        raw = input(prompt).strip()
        try:
            choice = int(raw)
        except ValueError:
            print(f"无效输入 '{raw}'，请输入整数编号。")
            continue
        if choice in idx2inst:
            return choice
        print(f"编号 {choice} 不在可选物体列表中，请重新输入。")


start_idx = _ask_instance_index("请输入导航起点物体编号 (如 7): ")
goal_idx = _ask_instance_index("请输入导航终点物体编号 (如 21): ")


# ===== Obstacle grid construction =====
robot_radius_m = 0.18
buffer_pixel = int(np.ceil(robot_radius_m / scale))  # only reported in the message below
print(f"【体积膨胀】机器人半径={robot_radius_m}m，像素buffer={buffer_pixel}")
grid_map = np.zeros_like(semantic_map, dtype=np.uint8)
for inst in data:
    cid = inst['category_id']
    lab = str(inst.get('category_label', '')).lower()
    passable = cid in door_cids or cid in ceiling_cids or lab in ('door', 'ceiling')
    if not passable:
        cells = np.asarray(inst['mask_coords'], dtype=int)
        if cells.size:
            # Every remaining instance pixel becomes an obstacle cell.
            grid_map[cells[:, 0], cells[:, 1]] = 1

# Safety inflation with the Euclidean distance transform (EDT): every cell whose
# distance to the nearest obstacle is <= robot_radius_m is marked as blocked.
if robot_radius_m > 0:
    # sampling=scale makes the returned distances use the same unit as `scale`
    # (world units per pixel) instead of pixels.
    dist_m = distance_transform_edt(grid_map == 0, sampling=scale)
    inflated_obstacle = (dist_m <= robot_radius_m).astype(np.uint8)
    grid_map = inflated_obstacle
    print(f"已根据机器人半径进行EDT安全膨胀，最小安全距离={robot_radius_m}米 (使用连续米制阈值)!")

def get_freedom_pixel(instance_mask, base_map, offset=3):
    """Pick a free cell (value 0) of ``base_map`` close to an instance mask.

    Search order:
      1. the ``(2*offset+1)`` x ``(2*offset+1)`` window around the rounded mask
         centroid, scanned column by column;
      2. the four axis-aligned cells exactly ``offset`` pixels away from each
         mask pixel;
      3. the free cell closest (Euclidean, in pixels) to the centroid anywhere
         on the map.

    Step 3 is needed because steps 1 and 2 never look further than ``offset``
    pixels from the object: when obstacles were inflated by more than that,
    every probed cell is blocked and the search would otherwise give up even
    though free space exists a little further out.

    Args:
        instance_mask: sequence of ``(row, col)`` pixel coordinates.
        base_map: 2-D array in which 0 marks a free cell.
        offset: probe distance in pixels for steps 1 and 2.

    Returns:
        ``(x, y)`` -- i.e. ``(col, row)`` -- as plain ints, or ``None`` when the
        mask is empty or the map has no free cell at all.
    """
    mask = np.asarray(instance_mask)
    if mask.shape[0] == 0:
        return None
    n_rows, n_cols = base_map.shape[0], base_map.shape[1]
    center = np.round(mask.mean(axis=0)).astype(int)

    def is_free(row, col):
        return 0 <= row < n_rows and 0 <= col < n_cols and base_map[row, col] == 0

    # Step 1: window around the centroid.
    for dx in range(-offset, offset + 1):
        for dy in range(-offset, offset + 1):
            row, col = center[0] + dy, center[1] + dx
            if is_free(row, col):
                return (int(col), int(row))

    # Step 2: fixed-distance probes around every mask pixel.
    probes = ((-offset, 0), (offset, 0), (0, -offset), (0, offset))
    for m_row, m_col in mask:
        for d_row, d_col in probes:
            row, col = m_row + d_row, m_col + d_col
            if is_free(row, col):
                return (int(col), int(row))

    # Step 3: global fallback -- nearest free cell to the centroid.
    free_cells = np.argwhere(np.asarray(base_map) == 0)
    if free_cells.shape[0] == 0:
        return None
    sq_dist = ((free_cells - center) ** 2).sum(axis=1)
    row, col = free_cells[int(np.argmin(sq_dist))]
    return (int(col), int(row))

# Start / goal pixels
if start_idx not in idx2inst or goal_idx not in idx2inst:
    # Unknown indices would otherwise surface as a bare KeyError traceback.
    print(f"无效的物体编号: start={start_idx}, goal={goal_idx}")
    raise SystemExit(1)
start_inst = idx2inst[start_idx]
goal_inst = idx2inst[goal_idx]
start_px = get_freedom_pixel(start_inst['mask_coords'], grid_map)
goal_px = get_freedom_pixel(goal_inst['mask_coords'], grid_map)
if start_px is None or goal_px is None:
    print(f"自动找物体外的起点或终点时失败: start={start_px}, goal={goal_px}")
    # Report the failure through a non-zero exit status; SystemExit is used
    # because the `exit` helper is only provided for interactive sessions.
    raise SystemExit(1)
print(f"A*起点像素坐标: {start_px}\n终点像素坐标: {goal_px}")


# ===== A* =====
def astar_pixel(grid, start, goal):
    """Find a shortest 8-connected path on an occupancy grid with A*.

    Straight moves cost 1, diagonal moves cost sqrt(2); the heuristic is the
    Euclidean distance to the goal.

    Args:
        grid: 2-D array indexed as ``grid[y, x]``; cells equal to 1 are blocked.
        start: ``(x, y)`` start cell. It is not required to be free.
        goal: ``(x, y)`` goal cell.

    Returns:
        List of ``(x, y)`` tuples from ``start`` to ``goal`` (both included),
        or ``None`` when the goal cannot be reached.
    """
    H, W = grid.shape
    start = (int(start[0]), int(start[1]))
    goal = (int(goal[0]), int(goal[1]))
    if start == goal:
        return [start]
    gx, gy = goal
    # A blocked or out-of-bounds goal can never be entered; skip the full search.
    if not (0 <= gx < W and 0 <= gy < H) or grid[gy, gx] == 1:
        return None

    diag = math.sqrt(2.0)
    # (dx, dy, step cost) -- costs are precomputed instead of building NumPy
    # arrays and calling a norm for every neighbor.
    moves = (
        (-1, 0, 1.0), (1, 0, 1.0), (0, -1, 1.0), (0, 1, 1.0),
        (-1, -1, diag), (-1, 1, diag), (1, -1, diag), (1, 1, diag),
    )
    open_set = [(math.hypot(start[0] - gx, start[1] - gy), start)]
    came_from = {}
    g_score = {start: 0.0}
    closed = set()
    while open_set:
        _, cur = heapq.heappop(open_set)
        if cur == goal:
            path = [cur]
            while cur in came_from:
                cur = came_from[cur]
                path.append(cur)
            return path[::-1]
        if cur in closed:
            # Stale heap entry: this cell was already expanded with a cost at
            # least as good, so expanding it again cannot improve anything.
            continue
        closed.add(cur)
        cx, cy = cur
        g_cur = g_score[cur]
        for dx, dy, step in moves:
            nx, ny = cx + dx, cy + dy
            if not (0 <= nx < W and 0 <= ny < H):
                continue
            if grid[ny, nx] == 1:
                continue
            nb = (nx, ny)
            if nb in closed:
                continue
            tg = g_cur + step
            if tg < g_score.get(nb, math.inf):
                came_from[nb] = cur
                g_score[nb] = tg
                heapq.heappush(open_set, (tg + math.hypot(nx - gx, ny - gy), nb))
    return None
path = astar_pixel(grid_map, start_px, goal_px)
if not path:
    # Fail before building a figure that would never be shown, and report the
    # failure through a non-zero exit status.
    print("未找到路径")
    raise SystemExit(1)

# ===== Draw the path =====
plt.figure(figsize=(12, 12))
plt.gca().set_facecolor(bg_color[:3])
plt.imshow(color_map_img, extent=img_extent, origin='lower')

# Re-draw the instance numbers at the mask centroids.
for idx, inst in idx2inst.items():
    mask = np.array(inst['mask_coords'])
    if mask.size == 0:
        continue
    centroid = mask.mean(axis=0)
    cx = min_x + (centroid[1] + 0.5) * scale
    cy = min_y + (centroid[0] + 0.5) * scale
    plt.text(cx, cy, f"{idx}", color='black', fontsize=10, ha='center', va='center', fontweight='bold')

# Path cells are (x, y) pixel pairs; convert them to the axes' world units.
xs, ys = zip(*path)
xs_w = [min_x + (x + 0.5) * scale for x in xs]
ys_w = [min_y + (y + 0.5) * scale for y in ys]
plt.plot(xs_w, ys_w, '-', color='red', linewidth=3, alpha=0.8)
plt.scatter([xs_w[0], xs_w[-1]], [ys_w[0], ys_w[-1]], color='red', s=80)
print("A*路径点数:", len(path))
world_path = list(zip(xs_w, ys_w))

plt.title('2D Semantic Map with Path')
plt.xlabel('X (meters)')
plt.ylabel('Y (meters)')
plt.show()


# ===== Save the trajectory =====
sample_step = 1
sampled_world_points = world_path[::sample_step]
# Keep the goal even when sample_step does not evenly divide the path length.
if sampled_world_points[-1] != world_path[-1]:
    sampled_world_points.append(world_path[-1])
fixed_z = 0.5
points_list = []

yaw = 0.0  # used as-is only for a single-point trajectory
last_i = len(sampled_world_points) - 1
for i, (wx, wy) in enumerate(sampled_world_points):
    if i < last_i:
        # Heading points towards the next waypoint.
        wx_n, wy_n = sampled_world_points[i + 1]
        yaw = math.atan2(wy_n - wy, wx_n - wx)
    # The final waypoint has no successor: keep the previous heading instead of
    # letting atan2(0, 0) reset it to 0.
    # Quaternion for a rotation of `yaw` about the z axis, ordered [qx, qy, qz, qw].
    qx, qy = 0.0, 0.0
    qz = math.sin(yaw / 2.0)
    qw = math.cos(yaw / 2.0)
    points_list.append({
        "point": str(i),
        "position": [float(wx), float(wy), fixed_z],
        "rotation": [qx, qy, qz, qw],
        "action": [],
        "camera_images": [],
        "focal_length": 7.0,
        "horizontal_aperture": 20.95,
        "vertical_aperture": 20.95,
        "focus_distance": 0.0,
        "clipping_range": [1.0, 1000000.0]
    })
traj_json = {
    "dataset_metadata": {"name": "GVLN","dataset_type": "dataset_type","dataset_description": "dataset_description"},
    "scenes": [{
        "scene_id": 0,
        "scene_name": os.path.basename(scene_dir),
        "samples": [{
            "trajectory_id": "0",
            "instructions": [],
            "points": points_list
        }]
    }]
}

out_path = os.path.join(scene_dir, "839873_trajectory_traj_1.json")
with open(out_path, "w", encoding="utf-8") as f:
    json.dump(traj_json, f, indent=2)
print("轨迹已保存 ->", out_path)