import os
import json
import numpy as np
import heapq
import math
import matplotlib
matplotlib.use('Agg')  # 无显示环境时也能保存图片
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.ndimage import distance_transform_edt

# ===== Input / output locations =====
# Scene folder that holds the 2D semantic map; visualisations are written below it.
scene_dir = '/home/sig/sig/qianluo/qianluo/3DGS_VLN_Benchmark/test/839873'
semantic_map_json = os.path.join(
    scene_dir,
    '2D_Semantic_Map_839873_Complete_v3.json',
)
# Trajectory file whose start/end objects are re-planned; its file name is reused
# to name the output JSON.
input_traj_json = '/home/sig/sig/qianluo/qianluo/github/GSNav-Bench/Code/Benchmark_Environment/Trajectory/trajectories_0033_839873.json'
input_traj = os.path.split(input_traj_json)[1]
# Output directory for the per-trajectory PNG plots (created if missing).
vis_dir = os.path.join(scene_dir, 'recollected_nav_vis')
os.makedirs(vis_dir, exist_ok=True)

# ===== Load the 2D semantic map =====
# JSON text is UTF-8 by specification; state the encoding explicitly so non-ASCII
# category labels decode correctly regardless of the system locale.
with open(semantic_map_json, 'r', encoding='utf-8') as f:
    data = json.load(f)

# ===== Global extent of all instance masks (metres) =====
# Each mask_coords_m entry is a (y, x) pair in world metres.
all_y = [float(y) for inst in data for y, x in inst['mask_coords_m']]
all_x = [float(x) for inst in data for y, x in inst['mask_coords_m']]
if not all_y:
    # Without any coordinate the extent (and hence the grid) is undefined; fail
    # with a clear message instead of "min() arg is an empty sequence".
    raise ValueError(f"语义地图中没有任何 mask 坐标: {semantic_map_json}")
min_y, max_y = min(all_y), max(all_y)
min_x, max_x = min(all_x), max(all_x)

# Grid resolution (metres per pixel)
scale = 0.05

# Grid size in pixels; the +1 keeps the index of the max coordinate, which
# world2pix rounds to at most ceil((max - min) / scale), inside the grid.
h = int(np.ceil((max_y - min_y) / scale)) + 1
w = int(np.ceil((max_x - min_x) / scale)) + 1

def world2pix(x, y):
    """Map a world coordinate (metres) to the grid cell with the nearest index.

    Returns ``(py, px)``, i.e. (row, column) in numpy indexing order.
    """
    col = int(round((float(x) - min_x) / scale))
    row = int(round((float(y) - min_y) / scale))
    return row, col

def pix2world(px, py):
    """Map a grid cell (``px`` = column, ``py`` = row) to world coordinates (metres).

    This is the inverse of ``world2pix``, which assigns a world point to the cell
    with index ``round((coord - min) / scale)``. Cell k is therefore centred at
    ``min + k * scale``. Adding an extra half pixel here would bias every
    converted point by ``scale / 2`` in both x and y, and
    ``world2pix(pix2world(...))`` would no longer round-trip.
    Fractional pixel indices are accepted (e.g. mask centroids or cell edges).

    Returns:
        (wx, wy) in metres.
    """
    wx = min_x + px * scale
    wy = min_y + py * scale
    return wx, wy

# ===== Category-id sets that receive special handling =====
def _cids_with_label(label):
    """Return the category_ids of all instances whose lower-cased category_label equals *label*."""
    return {
        inst['category_id']
        for inst in data
        if str(inst.get('category_label', '')).lower() == label
    }


door_cids = _cids_with_label('door')
ceiling_cids = _cids_with_label('ceiling')
wall_cids = _cids_with_label('wall')
unable_cids = _cids_with_label('unable area')
# Doors and ceilings do not block the robot; these ids are also left out of the colour map.
skip_cids = door_cids | ceiling_cids

# ===== Rasterise instances into semantic / instance maps and index them by item_id =====
semantic_map = np.zeros((h, w), dtype=int)   # category_id per pixel, 0 where nothing was rasterised
instance_map = np.zeros((h, w), dtype=int)   # 1-based position of the instance in `data`, 0 = none
itemid2inst = {}

for inst_no, inst in enumerate(data, start=1):
    category = inst['category_id']
    in_bounds = []
    for wy, wx in inst['mask_coords_m']:
        row, col = world2pix(wx, wy)
        if not (0 <= row < h and 0 <= col < w):
            continue
        # Later instances overwrite earlier ones where masks overlap.
        semantic_map[row, col] = category
        instance_map[row, col] = inst_no
        in_bounds.append((row, col))
    # Pixel-space mask as (row, col) pairs, used by the planning helpers.
    inst['mask_coords'] = in_bounds
    if 'item_id' in inst:
        itemid2inst[str(inst['item_id'])] = inst

# ===== Colour table and RGBA base image for the saved 2D navigation maps =====
def _get_cmap(name):
    """Look up a Matplotlib colormap by name.

    ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    Prefer the ``matplotlib.colormaps`` registry (3.5+) and fall back for older versions.
    """
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None:
        return registry[name]
    return cm.get_cmap(name)


# All non-background category ids, and the subset that is actually drawn.
all_cids = sorted({inst['category_id'] for inst in data} - {0})
visual_cids = [cid for cid in all_cids if cid not in skip_cids]
bg_color = (31/255, 119/255, 180/255, 1.0)  # background RGBA
# Large palette: tab20 + tab20b + tab20c
tab_colors = np.vstack([
    _get_cmap('tab20').colors,
    _get_cmap('tab20b').colors,
    _get_cmap('tab20c').colors,
])


def color_dist(c1, c2):
    """Euclidean distance between two colours in RGB space (alpha ignored)."""
    return np.sqrt(((np.array(c1[:3]) - np.array(c2[:3])) ** 2).sum())


# Drop palette entries that are nearly the background colour so every instance stays visible.
unique_colors = [c for c in tab_colors if color_dist(c, bg_color) > 0.1]
cid2color = {
    cid: tuple(unique_colors[i % len(unique_colors)])
    for i, cid in enumerate(visual_cids)
}
# Fixed colours for special categories
for wc in wall_cids:
    cid2color[wc] = (158/255, 218/255, 229/255)  # walls: light blue
for uc in unable_cids:
    cid2color[uc] = (1.0, 0.0, 0.0)  # "unable area": red

# RGBA base image: background everywhere, then each drawn instance in order
# (later instances overwrite earlier ones where masks overlap).
color_map_img = np.zeros((h, w, 4), dtype=float)
color_map_img[:, :] = bg_color
for inst in data:
    cid = inst['category_id']
    if cid in skip_cids or not inst['mask_coords']:
        continue
    label_lower = str(inst.get('category_label', '')).lower()
    color_rgb = cid2color.get(cid, (0.6, 0.6, 0.6))
    alpha = 0.6 if label_lower == 'unable area' else 1.0
    rows, cols = zip(*inst['mask_coords'])
    color_map_img[list(rows), list(cols)] = color_rgb + (alpha,)
# NOTE(review): world2pix rounds to the nearest index, so pixel k is centred at
# min + k*scale. This extent draws pixel k centred at min + (k + 0.5)*scale,
# i.e. half a pixel off. Confirm which convention is intended.
img_extent = [min_x, min_x + w * scale, min_y, min_y + h * scale]

# ===== Obstacle grid with EDT-based safety inflation =====
def build_inflated_grid(robot_radius_m):
    """Build the binary obstacle grid used for planning.

    Every instance's mask pixels become obstacles except:
    - doors and ceilings;
    - objects whose ``min_z_m`` (bottom height) is above 2.0 m.

    When ``robot_radius_m > 0``, obstacles are inflated with a metric Euclidean
    distance transform. Every pixel whose distance to the nearest obstacle pixel is
    ``<= robot_radius_m`` also becomes an obstacle.

    Args:
        robot_radius_m: robot radius in metres; ``<= 0`` disables inflation.

    Returns:
        ``np.uint8`` array shaped like ``semantic_map``: 1 = obstacle, 0 = free.
    """
    grid_map = np.zeros_like(semantic_map, dtype=np.uint8)
    for inst in data:
        cid = inst['category_id']
        lab = str(inst.get('category_label', '')).lower()
        # Doors and ceilings never block the robot.
        if cid in door_cids or cid in ceiling_cids or lab in ('door', 'ceiling'):
            continue
        # Skip overhead objects (bottom above 2.0 m). A missing or malformed
        # min_z_m counts as 0.0, i.e. the object remains an obstacle.
        try:
            min_z_val = float(inst.get('min_z_m', '0.0'))
        except (TypeError, ValueError, OverflowError):
            min_z_val = 0.0
        if min_z_val > 2.0:
            continue
        mask = inst['mask_coords']
        if mask:
            rows, cols = zip(*mask)
            grid_map[list(rows), list(cols)] = 1
    # With no obstacle pixel there is nothing to inflate. Skipping also avoids
    # running the EDT on an input that has no background (zero) element.
    if robot_radius_m > 0 and grid_map.any():
        dist_m = distance_transform_edt(grid_map == 0, sampling=scale)
        grid_map = (dist_m <= robot_radius_m).astype(np.uint8)
    return grid_map

# ===== Find a walkable pixel next to an instance =====
def get_freedom_pixel(instance_mask, base_map, offset=5):
    """
    Find a free pixel (``base_map == 0``) near an instance.

    Search order:
    1. The (2*offset+1)^2 window around the rounded mask centroid. The free pixel
       closest to the centroid is returned.
    2. Otherwise, each mask pixel shifted by ``offset`` in the four axis directions.

    ``offset`` can be increased for a larger robot radius.

    Args:
        instance_mask: sequence of (y, x) pixel coordinates.
        base_map: 2-D grid indexed ``[y, x]``; 0 = free.
        offset: search radius in pixels.

    Returns:
        (x, y) pixel coordinates, or None if nothing free was found.
    """
    mask = np.asarray(instance_mask)
    if mask.shape[0] == 0:
        return None
    H, W = base_map.shape[:2]
    cy, cx = np.round(mask.mean(axis=0)).astype(int)

    # Visit window offsets nearest-first. A raster scan would start at the window
    # corner (-offset, -offset) and could return a pixel ~offset*sqrt(2) away even
    # when a direct neighbour of the centroid is free. The sort is stable, so ties
    # keep the raster order (dx outer, dy inner).
    window = sorted(
        ((dx, dy) for dx in range(-offset, offset + 1) for dy in range(-offset, offset + 1)),
        key=lambda d: d[0] * d[0] + d[1] * d[1],
    )
    for dx, dy in window:
        y, x = cy + dy, cx + dx
        if 0 <= y < H and 0 <= x < W and base_map[y, x] == 0:
            return (int(x), int(y))

    # Fallback: probe outward from every mask pixel in the four axis directions.
    for y, x in mask:
        for dy, dx in ((-offset, 0), (offset, 0), (0, -offset), (0, offset)):
            y1, x1 = y + dy, x + dx
            if 0 <= y1 < H and 0 <= x1 < W and base_map[y1, x1] == 0:
                return (int(x1), int(y1))
    return None

# ===== Helpers for choosing the walkable pixel on the side that faces the other
#       instance (so planning does not go round to the back of an object) =====
def instance_centroid_px(mask_coords):
    """Return the rounded centroid of a (y, x) pixel mask as ``(x, y)``, or None if the mask is empty."""
    pts = np.array(mask_coords)
    if pts.size == 0:
        return None
    mean_y, mean_x = pts.mean(axis=0)
    return int(round(mean_x)), int(round(mean_y))

def boundary_pixels(mask_coords):
    """Return the mask pixels that have at least one 4-neighbour outside the mask.

    Input and output are (y, x) pairs; duplicate input pixels are collapsed.
    """
    cells = set((int(y), int(x)) for (y, x) in mask_coords)
    neighbour_steps = ((-1, 0), (1, 0), (0, -1), (0, 1))
    return [
        (y, x)
        for (y, x) in cells
        if not all((y + dy, x + dx) in cells for dy, dx in neighbour_steps)
    ]

def bresenham_line(x0, y0, x1, y1):
    """Rasterise the segment (x0, y0) -> (x1, y1) with Bresenham's algorithm.

    Returns the integer grid points ``[(x, y), ...]`` from start to end, both included.
    """
    span_x = abs(x1 - x0)
    span_y = -abs(y1 - y0)
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    error = span_x + span_y
    cur_x, cur_y = x0, y0
    line = [(cur_x, cur_y)]
    while (cur_x, cur_y) != (x1, y1):
        doubled = 2 * error
        if doubled >= span_y:
            error += span_y
            cur_x += step_x
        if doubled <= span_x:
            error += span_x
            cur_y += step_y
        line.append((cur_x, cur_y))
    return line

from collections import deque

def get_nearest_free_pixel_on_side(instance_mask, base_map, towards_px=None, max_search_dist=50):
    """
    Breadth-first search outward from an instance's boundary for a free pixel.

    The BFS starts from every boundary pixel of the mask and expands through
    4-neighbours. Obstacle pixels are traversed, not treated as walls. The first
    accepted pixel is therefore the nearest free one in 4-connected steps.

    When ``towards_px`` is given, a free pixel is accepted only if two vectors make
    an angle of at most 90 degrees (non-negative dot product):
    - mask centroid -> candidate;
    - mask centroid -> ``towards_px``.
    This keeps the chosen pixel on the side facing the other instance rather than
    behind the object.

    Args:
        instance_mask: sequence of (y, x) pixel coordinates.
        base_map: 2-D grid indexed ``[y, x]``; 0 = free.
        towards_px: optional (x, y) pixel the chosen point should face.
        max_search_dist: maximum BFS depth in pixels.

    Returns:
        (x, y) of the selected pixel, or None if none is found within the limit.
    """
    H, W = base_map.shape[:2]
    b_pixels = boundary_pixels(instance_mask)  # [(y, x)]
    if not b_pixels:
        return None

    # The centroid and the target direction are loop-invariant. Compute them once
    # rather than once per rejected free pixel, which was O(mask) each time.
    cx = cy = 0.0
    to_target = None
    if towards_px is not None:
        cx, cy = np.mean([(px, py) for (py, px) in instance_mask], axis=0)
        to_target = np.array([towards_px[0] - cx, towards_px[1] - cy])

    visited = set()
    q = deque()
    for (by, bx) in b_pixels:
        visited.add((bx, by))
        q.append((bx, by, 0))  # (x, y, BFS depth)

    while q:
        x, y, d = q.popleft()
        if d > max_search_dist:
            break
        if base_map[y, x] == 0:
            if to_target is None:
                return (x, y)
            if np.dot(np.array([x - cx, y - cy]), to_target) >= 0:
                return (x, y)
        # Not free, or free but on the wrong side: keep expanding.
        for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < W and 0 <= ny < H and (nx, ny) not in visited:
                visited.add((nx, ny))
                q.append((nx, ny, d + 1))
    return None


# ===== A* path search on the pixel grid =====
def astar_pixel(grid, start, goal, *, allow_corner_cutting=False):
    """
    8-connected A* search on a binary occupancy grid.

    Step costs are Euclidean (1 or sqrt(2)) and the heuristic is the Euclidean
    distance to the goal. Because this heuristic is consistent, a node's first
    expansion is final, which allows a closed set.

    Args:
        grid: HxW array indexed ``grid[y, x]``; 1 = obstacle, anything else = free.
        start: (x, y) start pixel.
        goal: (x, y) goal pixel.
        allow_corner_cutting: if False (default), a diagonal step is taken only when
            both orthogonally adjacent cells are free. Otherwise the path could slip
            past an obstacle corner or squeeze between two diagonal obstacles.

    Returns:
        List of (x, y) pixels from start to goal inclusive, or None if unreachable.
    """
    H, W = grid.shape[:2]

    def is_free(x, y):
        return 0 <= x < W and 0 <= y < H and grid[y, x] != 1

    if start == goal:
        return [start]
    # A blocked / out-of-range goal can never be entered; do not flood the grid.
    if not is_free(*goal):
        return None

    dirs = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
    open_set = [(0.0, start)]
    came_from = {}
    g_score = {start: 0.0}
    closed = set()
    while open_set:
        _, cur = heapq.heappop(open_set)
        if cur in closed:
            continue  # stale heap entry; cur was already expanded at its best cost
        if cur == goal:
            path = [cur]
            while cur in came_from:
                cur = came_from[cur]
                path.append(cur)
            return path[::-1]
        closed.add(cur)
        cx, cy = cur
        for dx, dy in dirs:
            nx, ny = cx + dx, cy + dy
            if not is_free(nx, ny):
                continue
            if dx and dy and not allow_corner_cutting and not (is_free(nx, cy) and is_free(cx, ny)):
                continue
            nb = (nx, ny)
            if nb in closed:
                continue
            tg = g_score[cur] + math.hypot(dx, dy)
            if tg < g_score.get(nb, math.inf):
                came_from[nb] = cur
                g_score[nb] = tg
                f = tg + math.hypot(nx - goal[0], ny - goal[1])
                heapq.heappush(open_set, (f, nb))
    return None

# ===== Load the input trajectory file =====
# JSON text is UTF-8 by specification; do not depend on the system locale.
with open(input_traj_json, 'r', encoding='utf-8') as f:
    input_trajs = json.load(f)

# Layout as consumed below: {"scenes": {scene_key: {traj_key: {...}, ...}, ...}}.
# Validate the container types here so a malformed file fails with a clear message
# instead of an AttributeError deep inside the loop.
scenes_in = input_trajs.get('scenes', {}) if isinstance(input_trajs, dict) else {}
if not isinstance(scenes_in, dict) or not scenes_in:
    raise ValueError("输入轨迹JSON中未找到 'scenes' 字段或为空。")

# Navigation parameters
robot_radius_m = 0.15  # metres; obstacles are inflated by this radius
grid_map = build_inflated_grid(robot_radius_m)

# ===== Re-plan every trajectory between its start/end item_ids, build new samples and save plots =====
samples_out = []
traj_counter = 0

# Values shared by every output point.
sample_step = 1   # keep every n-th pixel of the planned path
fixed_z = 0.5     # z coordinate written into every position


def instr_sort_key(k):
    """Sort key for instruction ids such as 'instruction_3'.

    Keys with a numeric suffix sort first, by that number; the rest follow, by
    their string form. A tuple is returned so an int is never compared with a str.
    Mixing the two made ``sorted`` raise TypeError.
    """
    try:
        return (0, int(str(k).split('_')[-1]), '')
    except ValueError:
        return (1, 0, str(k))


def inst_centroid_world(inst):
    """World (x, y) of the mean of an instance's pixel mask, or None if the mask is empty."""
    m = np.array(inst.get('mask_coords', []))
    if m.size == 0:
        return None
    cy, cx = m.mean(axis=0)  # mask pixels are (y, x)
    return pix2world(cx, cy)


def _path_to_points(world_path):
    """Convert a world-frame path into the per-point dicts of the output format.

    Each point's yaw faces the next point. The last point has no successor, so it
    keeps the heading of the final segment instead of atan2(0, 0) = 0.
    A single-point path gets yaw = 0. Rotation is a pure z-rotation quaternion
    written as [qx, qy, qz, qw].
    """
    sampled = world_path[::sample_step]
    points_list = []
    yaw = 0.0
    for i, (wx, wy) in enumerate(sampled):
        if i + 1 < len(sampled):
            wx_next, wy_next = sampled[i + 1]
            yaw = math.atan2(wy_next - wy, wx_next - wx)
        points_list.append({
            "point": str(i),
            "position": [wx, wy, fixed_z],
            "rotation": [0.0, 0.0, math.sin(yaw / 2.0), math.cos(yaw / 2.0)],
            "action": [],
            "camera_images": [],
            "focal_length": 7.0,
            "horizontal_aperture": 20.954999923706055,
            "vertical_aperture": 20.954999923706055,
            "focus_distance": 0.0,
            "clipping_range": [1.0, 1000000.0]
        })
    return points_list


def _save_path_figure(traj_key, world_path, start_inst, goal_inst,
                      start_item, end_item, start_px, goal_px):
    """Draw the colour map, the planned path and start/goal markers; save a PNG in vis_dir."""
    # Take the image extent from pix2world so the map, the path and the markers all
    # share one pixel->world convention: pixel k spans [k - 0.5, k + 0.5] in pixel units.
    x_lo, y_lo = pix2world(-0.5, -0.5)
    x_hi, y_hi = pix2world(w - 0.5, h - 0.5)

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.set_facecolor(bg_color[:3])
        ax.imshow(color_map_img, extent=[x_lo, x_hi, y_lo, y_hi], origin='lower')

        # Label the start / goal instances at their mask centroids.
        start_cw = inst_centroid_world(start_inst)
        goal_cw = inst_centroid_world(goal_inst)
        if start_cw is not None:
            ax.text(start_cw[0], start_cw[1], f"START: {start_item}", color='yellow', fontsize=12,
                    ha='center', va='center', fontweight='bold')
        if goal_cw is not None:
            ax.text(goal_cw[0], goal_cw[1], f"GOAL: {end_item}", color='yellow', fontsize=12,
                    ha='center', va='center', fontweight='bold')

        # Planned path with its end points.
        xs_w = [wp[0] for wp in world_path]
        ys_w = [wp[1] for wp in world_path]
        ax.plot(xs_w, ys_w, '-', color='red', linewidth=3, alpha=0.9)
        ax.scatter([xs_w[0], xs_w[-1]], [ys_w[0], ys_w[-1]], color='red', s=80)

        # Walkable pixels selected as start / goal (green).
        sx_w, sy_w = pix2world(start_px[0], start_px[1])
        gx_w, gy_w = pix2world(goal_px[0], goal_px[1])
        ax.scatter([sx_w, gx_w], [sy_w, gy_w], c=['lime', 'lime'], s=60, marker='o',
                   edgecolors='k', linewidths=0.5)

        scene_name = os.path.basename(scene_dir)
        ax.set_title(f'2D Navigation Map - {scene_name} | {traj_key}')
        ax.set_xlabel('X (meters)')
        ax.set_ylabel('Y (meters)')

        vis_path = os.path.join(vis_dir, f"{scene_name}_{traj_key}_replan.png")
        fig.savefig(vis_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


for scene_key, scene_obj in scenes_in.items():
    # scene_obj looks like {"trajectory_0": {...}, "trajectory_1": {...}, ...}
    for traj_key, traj_obj in scene_obj.items():
        # Start / end object item_ids.
        se = traj_obj.get('start_end', {})
        obj = se.get('object', {}) if isinstance(se, dict) else {}
        if not isinstance(obj, dict):
            obj = {}
        start_item = obj.get('start', None)
        end_item = obj.get('end', None)

        # Instruction texts, ordered by the numeric suffix of their keys.
        instructions_obj = traj_obj.get('instructions', {}) or {}
        instructions = [instructions_obj[k]['content']
                        for k in sorted(instructions_obj.keys(), key=instr_sort_key)
                        if isinstance(instructions_obj.get(k), dict) and 'content' in instructions_obj[k]]

        # Only missing / empty ids are rejected; an item_id of 0 is valid.
        if start_item in (None, '') or end_item in (None, ''):
            print(f"[跳过] 轨迹 {traj_key}: 缺少 start/end item_id。")
            continue

        start_inst = itemid2inst.get(str(start_item))
        goal_inst = itemid2inst.get(str(end_item))
        if start_inst is None or goal_inst is None:
            print(f"[跳过] 轨迹 {traj_key}: 在2D地图中找不到 item_id -> start:{start_item} goal:{end_item}")
            continue

        # Pixel centroids of both instances.
        start_cent_px = instance_centroid_px(start_inst['mask_coords'])
        goal_cent_px = instance_centroid_px(goal_inst['mask_coords'])
        if start_cent_px is None or goal_cent_px is None:
            print(f"[跳过] 轨迹 {traj_key}: 起点或终点实例无有效mask。")
            continue

        # Pick, for each instance, the nearest walkable pixel on the side facing the
        # other instance, so the route does not start/end behind an object.
        start_px = get_nearest_free_pixel_on_side(start_inst['mask_coords'], grid_map, towards_px=goal_cent_px)
        goal_px = get_nearest_free_pixel_on_side(goal_inst['mask_coords'], grid_map, towards_px=start_cent_px)
        if start_px is None or goal_px is None:
            print(f"[跳过] 轨迹 {traj_key}: 起点或终点附近找不到可走像素 start={start_px}, goal={goal_px}")
            continue

        path = astar_pixel(grid_map, start_px, goal_px)
        if not path:
            print(f"[跳过] 轨迹 {traj_key}: A* 未找到路径。")
            continue

        # Pixel path (x, y) -> world coordinates via the shared conversion.
        world_path = [pix2world(x, y) for x, y in path]
        points_list = _path_to_points(world_path)

        samples_out.append({
            "trajectory_id": str(traj_counter),
            "instructions": instructions,
            "points": points_list
        })
        print(f"[完成] 轨迹 {traj_key} -> 新路径点数: {len(points_list)}, 指令数: {len(instructions)}")

        _save_path_figure(traj_key, world_path, start_inst, goal_inst,
                          start_item, end_item, start_px, goal_px)

        traj_counter += 1

# ===== Assemble the output JSON =====
out_json = {
    "dataset_metadata": {
        "name": "GVLN",
        "dataset_type": "dataset_type",
        "dataset_description": "dataset_description"
    },
    "scenes": [
        {
            "scene_id": 0,
            "scene_name": os.path.basename(scene_dir),
            "samples": samples_out
        }
    ]
}

# ===== Save =====
# Build the output name from the input file's stem. str.replace('.json', ...)
# would also rewrite a '.json' in the middle of the name. It would also keep a name
# that lacks the suffix unchanged, i.e. reuse the input file's name.
_traj_stem, _ = os.path.splitext(input_traj)
out_path = os.path.join(scene_dir, _traj_stem + '_recollected_gvln.json')
with open(out_path, 'w', encoding='utf-8') as f:
    json.dump(out_json, f, indent=2)

print(f"新轨迹数据已保存 -> {out_path}")
print(f"可视化图片已保存至 -> {vis_dir}")