import json
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from PIL import Image
from scipy.ndimage import find_objects
from scipy.ndimage import label as nd_label

# ==== Scene configuration ====
scene_dir = "test/839919"
occ_img_path = os.path.join(scene_dir, "occupancy.png")
occ_json_path = os.path.join(scene_dir, "occupancy.json")

# ==== Read map metadata ====
# "min" holds the world coordinates that pixel column 0 / row 0 map to, and
# "scale" is the world size of one pixel (see the pixel -> world conversions
# further down). Units are presumably metres given the *_m output keys --
# confirm against the exporter.
with open(occ_json_path, encoding="utf-8") as f:
    meta = json.load(f)

scale = meta["scale"]
x_min, y_min = meta["min"][0], meta["min"][1]

# ==== Read occupancy.png as 8-bit greyscale ====
# Open the file once, inside a context manager. Previously it was opened twice
# (once only to read .size) and neither handle was closed. convert() returns an
# in-memory copy, so occ_img stays usable after the source file is closed.
with Image.open(occ_img_path) as src_img:
    occ_img = src_img.convert("L")
occ = np.array(occ_img)
h, w = occ.shape  # height, width -- same as the PNG's size[::-1]

print("Occupancy图像分析:")
print(f"  尺寸: {occ.shape}")
print(f"  灰度值范围: {occ.min()} - {occ.max()}")

# ==== Strictly separate the three pixel classes ====
# Pure white (255) -> walkable, pure black (0) -> object,
# every grey level strictly in between -> wall.
walkable_mask = np.equal(occ, 255)
object_mask = np.equal(occ, 0)
wall_mask = np.logical_and(np.greater(occ, 0), np.less(occ, 255))

# One print call, newline-separated: same stdout bytes as four separate prints.
print(
    "\n区域统计:",
    f"  可行走区域: {walkable_mask.sum()} 像素",
    f"  物体区域: {object_mask.sum()} 像素",
    f"  墙体区域: {wall_mask.sum()} 像素",
    sep="\n",
)

# ==== Connectivity kernel used for component labelling ====
# 4-connectivity: a pixel joins only its up/down/left/right neighbours, so two
# regions touching only diagonally stay separate components.
strict_structure = np.array([[0, 1, 0],
                             [1, 1, 1],
                             [0, 1, 0]], dtype=np.int32)

# Kept for backward compatibility only. It used to be described as an even
# stricter "1-connectivity" kernel, but it was (and is) an identical
# 4-connectivity copy; in 2-D there is nothing stricter than 4-connectivity.
# It is not used by the labelling below.
ultra_strict = strict_structure.copy()

# One dict per emitted wall/object instance; dumped to JSON at the end.
result_list = []
# Placeholder labels: objects cycle through every entry except "wall", and each
# record's category_id is the 1-based position of its label in this list.
fake_categories = ["door", "table", "chair", "wall", "sofa", "window", "cabinet", "bed", "desk", "lamp"]

print("\n=== 开始精细分割 ===")

# ==== Shared helper: component pixels -> world-space geometry ====


def _component_geometry(ys, xs, origin_x, origin_y, pixel_size):
    """Convert one connected component's pixel indices into world-space geometry.

    Args:
        ys: Row indices of the component's pixels (non-empty integer array).
        xs: Column indices, same length and order as ``ys``.
        origin_x: World x that pixel column 0 maps to (``meta["min"][0]``).
        origin_y: World y that pixel row 0 maps to (``meta["min"][1]``).
        pixel_size: World size of one pixel (``meta["scale"]``).

    Returns:
        ``(bbox_m, bbox_xywh_m, mask_coords_m)``. The three parts are
        ``[left, bottom, right, top]``, ``[left, bottom, width, height]``, and
        one ``[x, y]`` pair per pixel (in the order given). All values are
        strings with two decimals. The box spans the extreme pixel *indices*,
        so a component one pixel wide has width 0.
        NOTE(review): whether row 0 is the world's low-y edge depends on the
        exporter -- confirm.
    """
    x_left = origin_x + xs.min() * pixel_size
    x_right = origin_x + xs.max() * pixel_size
    y_bottom = origin_y + ys.min() * pixel_size
    y_top = origin_y + ys.max() * pixel_size
    width = x_right - x_left
    height = y_top - y_bottom

    bbox_m = [f"{x_left:.2f}", f"{y_bottom:.2f}", f"{x_right:.2f}", f"{y_top:.2f}"]
    bbox_xywh_m = [f"{x_left:.2f}", f"{y_bottom:.2f}", f"{width:.2f}", f"{height:.2f}"]

    # Vectorised pixel -> world transform. The per-element arithmetic is the
    # same (origin + index * pixel_size) as converting each pixel one by one.
    x_world = (origin_x + xs * pixel_size).tolist()
    y_world = (origin_y + ys * pixel_size).tolist()
    mask_coords_m = [[f"{xw:.2f}", f"{yw:.2f}"] for xw, yw in zip(x_world, y_world)]
    return bbox_m, bbox_xywh_m, mask_coords_m


# ==== 1. Wall regions ====
# Same 1-based id scheme the objects use (evaluates to 4).
wall_category_id = fake_categories.index("wall") + 1

if wall_mask.sum() > 0:
    print("处理墙体区域...")
    wall_components, wall_num = nd_label(wall_mask, structure=strict_structure)
    print(f"墙体分割为 {wall_num} 个组件")

    # find_objects yields each label's bounding slices (label 1 first), so each
    # component is located inside its own window instead of comparing the whole
    # label image against every label (O(components * pixels)). Pixel order is
    # unchanged: row-major inside the window, which is row-major overall.
    for i, region in enumerate(find_objects(wall_components), start=1):
        if region is None:  # label missing (not expected right after nd_label)
            continue
        rows, cols = np.nonzero(wall_components[region] == i)
        area = rows.size
        if area < 2:  # drop single-pixel wall specks
            continue
        ys = rows + region[0].start
        xs = cols + region[1].start
        print(f"  墙体组件 {i}: {area} 像素, {len(xs)} 个点")

        bbox_m, bbox_xywh_m, mask_coords_m = _component_geometry(ys, xs, x_min, y_min, scale)
        result_list.append({
            "category_id": wall_category_id,
            "category_label": "wall",
            "instance_id": f"wall_{i}",
            "bbox_m": bbox_m,
            "bbox_xywh_m": bbox_xywh_m,
            "area": int(area),
            "mask_coords_m": mask_coords_m
        })

# ==== 2. Object regions ====
if object_mask.sum() > 0:
    print("处理物体区域...")
    object_components, object_num = nd_label(object_mask, structure=strict_structure)
    print(f"物体分割为 {object_num} 个组件")

    # Fake labels cycle round-robin through every category except "wall".
    # Count the non-wall records already present once, then keep a running
    # counter. Previously result_list was rescanned for every component (O(n^2)).
    non_wall_cats = [cat for cat in fake_categories if cat != "wall"]
    non_wall_count = sum(1 for r in result_list if r["category_label"] != "wall")

    # Every object component is kept, including single pixels. The old
    # `area < 1` / empty-coordinate checks could never trigger.
    for i, region in enumerate(find_objects(object_components), start=1):
        if region is None:  # label missing (not expected right after nd_label)
            continue
        rows, cols = np.nonzero(object_components[region] == i)
        area = rows.size
        ys = rows + region[0].start
        xs = cols + region[1].start
        print(f"  物体组件 {i}: {area} 像素, {len(xs)} 个点")

        bbox_m, bbox_xywh_m, mask_coords_m = _component_geometry(ys, xs, x_min, y_min, scale)

        category_label = non_wall_cats[non_wall_count % len(non_wall_cats)]
        cat_id = fake_categories.index(category_label) + 1

        result_list.append({
            "category_id": int(cat_id),
            "category_label": category_label,
            "instance_id": f"object_{i}",
            "bbox_m": bbox_m,
            "bbox_xywh_m": bbox_xywh_m,
            "area": int(area),
            "mask_coords_m": mask_coords_m
        })
        non_wall_count += 1

print(f"\n总共处理了 {len(result_list)} 个物体")

# ==== Debug visualisation (2 x 3 panel figure) ====
plt.figure(figsize=(20, 15))

# Panel 1: the raw occupancy image.
plt.subplot(2, 3, 1)
plt.imshow(occ, cmap='gray', origin='upper')
plt.title('Original Occupancy')

# Panel 2: wall components. These are re-labelled with the same kernel as the
# processing step, so label k here is the component behind instance "wall_k".
plt.subplot(2, 3, 2)
if wall_mask.sum() > 0:
    wall_components, wall_num = nd_label(wall_mask, structure=strict_structure)
    plt.imshow(wall_components, cmap='tab20', origin='upper')
    plt.title('Wall Components')
else:
    plt.imshow(np.zeros_like(occ), origin='upper')
    plt.title('No Wall Components')

# Panel 3: object components (label k <-> instance "object_k").
plt.subplot(2, 3, 3)
if object_mask.sum() > 0:
    object_components, object_num = nd_label(object_mask, structure=strict_structure)
    plt.imshow(object_components, cmap='tab20', origin='upper')
    plt.title('Object Components')
else:
    plt.imshow(np.zeros_like(occ), origin='upper')
    plt.title('No Object Components')

# Panel 4: every emitted instance, painted with (its index in result_list + 1).
# Pixels of components that were filtered out (or background) stay 0.
# The map is built from the label images through a per-label lookup table.
# Converting the 2-decimal strings in mask_coords_m back to pixels lands on the
# wrong pixel once `scale` is small (around 0.01 or less). That old path also
# needed a bare `except:` and a second, fully redundant paint pass.
# Relies on the "wall_<label>" / "object_<label>" instance_id format used when
# the records were built.
plt.subplot(2, 3, 4)
segmentation_map = np.zeros((h, w), dtype=np.int32)
instance_index = {inst["instance_id"]: idx for idx, inst in enumerate(result_list, start=1)}
if wall_mask.sum() > 0:
    wall_lut = np.array([instance_index.get(f"wall_{c}", 0) for c in range(wall_num + 1)],
                        dtype=np.int32)
    segmentation_map[wall_mask] = wall_lut[wall_components[wall_mask]]
if object_mask.sum() > 0:
    object_lut = np.array([instance_index.get(f"object_{c}", 0) for c in range(object_num + 1)],
                          dtype=np.int32)
    segmentation_map[object_mask] = object_lut[object_components[object_mask]]

plt.imshow(segmentation_map, cmap='tab20', origin='upper')
plt.title(f'All Segments ({len(result_list)})')

# ==== Panel 5: segmentation map with one red box per instance ====
plt.subplot(2, 3, 5)
plt.imshow(segmentation_map, cmap='tab20', origin='upper')

# Each box is read straight from the record's bbox_m. It holds the same min/max
# extremes as its mask_coords_m, because both are formatted from
# min + pixel_index * scale, so all coordinates no longer need re-parsing.
# Boxes outline the outer edges of the extreme pixels: imshow centres pixel k
# at coordinate k, hence the -0.5 offset and the +1 on each side. The old
# centre-to-centre boxes had zero width or height for components one pixel
# thick, and those instances were silently skipped.
bbox_count = 0
for inst in result_list:
    left, bottom, right, top = (float(v) for v in inst["bbox_m"])
    x_pix = (left - x_min) / scale - 0.5
    y_pix = (bottom - y_min) / scale - 0.5
    w_pix = (right - left) / scale + 1
    h_pix = (top - bottom) / scale + 1

    rect = plt.Rectangle((x_pix, y_pix), w_pix, h_pix,
                         edgecolor='red', facecolor='none', linewidth=1.0)
    plt.gca().add_patch(rect)
    bbox_count += 1

    # Label only the first 30 boxes to keep the panel legible.
    if bbox_count <= 30:
        plt.text(x_pix, y_pix, f"{inst['instance_id']}", fontsize=6, color='white',
                 backgroundcolor='red', verticalalignment='bottom')

plt.title(f'With BBoxes ({bbox_count} boxes)')

# ==== Panel 6: histogram of instance areas (pixels) ====
plt.subplot(2, 3, 6)
# No `[0]` placeholder. It made `areas` always truthy, so an empty result drew a
# bogus one-bar histogram and the "No objects found" branch could never run.
areas = [inst["area"] for inst in result_list]
if areas:
    plt.hist(areas, bins=min(30, len(areas)), alpha=0.7)
    plt.xlabel('Area (pixels)')
    plt.ylabel('Count')
    plt.title(f'Area Distribution ({len(areas)} objects)')
else:
    plt.text(0.5, 0.5, 'No objects found', ha='center', va='center',
             transform=plt.gca().transAxes)
    plt.title('Area Distribution')

plt.tight_layout()
plt.savefig(os.path.join(scene_dir, "detailed_analysis.png"), dpi=300, bbox_inches='tight')
plt.show()

# ==== Save results ====
# The output is named "<scene id>_occupancy_objects.json", where the scene id
# is the last path component of scene_dir ("839919" for "test/839919").
# Previously the id was hard-coded, so changing scene_dir wrote a file with the
# wrong name.
scene_id = os.path.basename(os.path.normpath(scene_dir))
out_json_path = os.path.join(scene_dir, f"{scene_id}_occupancy_objects.json")
with open(out_json_path, "w", encoding="utf-8") as f:
    json.dump(result_list, f, indent=2)
print(f"\n结果已保存到: {out_json_path}")

# ==== Final visualisation: segmentation map + boxes, saved per scene ====
if result_list:
    # Scene id = last path component of scene_dir (was hard-coded as "839919").
    scene_id = os.path.basename(os.path.normpath(scene_dir))

    plt.figure(figsize=(12, 12))
    plt.imshow(segmentation_map, cmap='tab20', origin='upper')

    # Draw at most this many boxes to keep the figure readable.
    max_boxes = 100
    final_bbox_count = 0
    for inst in result_list[:max_boxes]:
        # bbox_m already holds the min/max of the instance's mask coordinates.
        # Boxes outline the outer pixel edges (imshow centres pixel k at k).
        # There is no blanket `except Exception: continue`, which used to hide
        # any failure as a silently missing box.
        left, bottom, right, top = (float(v) for v in inst["bbox_m"])
        rect = plt.Rectangle(((left - x_min) / scale - 0.5, (bottom - y_min) / scale - 0.5),
                             (right - left) / scale + 1, (top - bottom) / scale + 1,
                             edgecolor='red', facecolor='none', linewidth=0.8)
        plt.gca().add_patch(rect)
        final_bbox_count += 1

    plt.title(f'Final Result ({len(result_list)} objects, {final_bbox_count} boxes shown)')
    plt.savefig(os.path.join(scene_dir, f"{scene_id}_occupancy_objects.png"), dpi=300, bbox_inches='tight')
    plt.show()

# ==== Summary statistics ====
# Nothing is printed when no instances were produced.
if result_list:
    # Every record is either a wall or a non-wall object, so the object count
    # is simply the remainder.
    wall_count = sum(r["category_label"] == "wall" for r in result_list)
    object_count = len(result_list) - wall_count
    areas = [r["area"] for r in result_list]

    # One newline-separated print: same stdout bytes as five separate prints.
    print(
        "\n=== 最终统计 ===",
        f"墙体: {wall_count}",
        f"物体: {object_count}",
        f"总计: {len(result_list)}",
        f"面积范围: {min(areas)} - {max(areas)} (平均: {np.mean(areas):.1f})",
        sep="\n",
    )