import numpy as np
from scipy import ndimage
from sklearn.cluster import DBSCAN
import open3d as o3d
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import convolve2d
import os
def get_instruction(objects, *, datapack_path=None, origin_pos=(2302, 50, 982), factor=7):
    """Turn detected objects into Minecraft ``setblock`` commands and write a datapack.

    Each object becomes one ``setblock`` command at its centroid. The centroid is
    divided by ``factor``, rounded half-up and shifted by ``origin_pos``. Its second
    component is used as Minecraft z and its third as Minecraft y.

    Commands are grouped per semantic class into ``<class>.mcfunction`` files under
    ``<datapack_path>/data/namespace/function``. A ``room2.mcfunction`` entry point
    calls class 0 immediately (if it has commands) and schedules every other
    non-empty class ``i`` to run ``i`` seconds later.

    Args:
        objects: iterable of dicts with an integer ``'category'`` (index into the
            class tables below) and a 3-element ``'centroid'``.
        datapack_path: datapack root directory. Defaults to the author's local save.
        origin_pos: world (x, y, z) that centroid (0, 0, 0) maps to.
        factor: number of centroid units per Minecraft block.

    Raises:
        ValueError: if the two class tables below get out of sync.
    """
    import math

    # Semantic class names (also used as function/file names). The misspelled
    # 'empthy' is kept because it is part of the generated file names.
    cls_real = ['empthy',
                'ceiling',
                'floor',
                'wall',
                'window',
                'chair',
                'bed',
                'sofa',
                'table',
                'tvs',
                'furniture',
                'object']
    # Block placed for each class; index i here corresponds to cls_real[i].
    cls_mc = ['minecraft:air',
              'minecraft:stripped_acacia_wood[axis=y]',
              'tmeov:baisehunningtuqupiyunshangmushang',
              'minecraft:stripped_acacia_wood[axis=y]',
              'tmeov:baiyechuangheise_2x_2kai[facing=south,waterlogged=false]',
              "tmeo_ultra:canzhuoyizi[facing=south]",
              'tmeo_ultra:bedblue_01[facing=north]',
              "tmeov:shafabuliao_1x_1[facing=south]",
              "tmeov:chanzhuomuban[facing=south]",
              'minecraft:grass_block',
              'minecraft:grass_block',
              'minecraft:grass_block',
              ]
    if len(cls_mc) != len(cls_real):
        raise ValueError("cls_mc and cls_real must have the same length")

    def normal_round(x):
        # Round half up; Python's round() would round half to even.
        return int(math.floor(x + 0.5))

    cls_inst = {name: [] for name in cls_real}
    for obj in objects:
        label = int(obj['category'])
        if not 0 <= label < len(cls_mc):
            # Negative labels would silently wrap around; larger ones would crash.
            print(f"skipping object with unknown category {label}")
            continue
        cx, cz, cy = obj['centroid']  # centroid[1] -> Minecraft z, centroid[2] -> Minecraft y
        bx = normal_round(cx / factor) + origin_pos[0]
        by = normal_round(cy / factor) + origin_pos[1]
        bz = normal_round(cz / factor) + origin_pos[2]
        cls_inst[cls_real[label]].append(f"setblock {bx} {by} {bz} {cls_mc[label]}")

    if datapack_path is None:
        datapack_path = r"D:\minecraft\tmeo\.minecraft\versions\TMEOv7.1 NeoForge\saves\【地图】TMEOv7.1 forge版本通用测试地图\datapacks\myfunc"
    data_path = os.path.join(datapack_path, "data", "namespace", "function")
    os.makedirs(data_path, exist_ok=True)

    # Entry point: run the class functions in cls_real order.
    main_function_path = os.path.join(data_path, "room2.mcfunction")
    with open(main_function_path, "w", encoding="utf-8") as f:
        f.write("# 开始建造房屋\n")
        f.write("say 开始建造房屋...\n")
        for i, category in enumerate(cls_real):
            if not cls_inst[category]:
                # No .mcfunction file is written for empty classes, and Minecraft
                # rejects functions that reference unknown functions.
                continue
            if i == 0:
                f.write(f"function namespace:{category}\n")
            else:
                # Class i starts i seconds after the entry point.
                f.write(f"schedule function namespace:{category} {i}s\n")

    # One function file per non-empty class.
    for category, cat_commands in cls_inst.items():
        if not cat_commands:
            continue
        category_function_path = os.path.join(data_path, f"{category}.mcfunction")
        with open(category_function_path, "w", encoding="utf-8") as f:
            f.write(f"# 执行{category}类别的命令\n")
            f.write(f"say 开始建造{category}...\n")
            f.write("\n".join(cat_commands))
            f.write(f"\nsay {category}建造完成！\n")

    print("生成完成！函数文件已保存到数据包目录")

# D:\minecraft\tmeo\.minecraft\versions\TMEOv7.1 NeoForge\datapacks\myfunc\data\namespace\function
# D:\minecraft\tmeo\.minecraft\versions\TMEOv7.1 NeoForge\saves\【地图】TMEOv7.1 forge版本通用测试地图\datapacks\myfunc\data\namespace\function\
import numpy as np
from scipy.ndimage import convolve


def preprocess(arr, window_size=5, threshold=3):
    """Zero out label-noisy square windows on every axis-aligned slice of a 3D array.

    The array is swept three times:
      1. XY slices, one per z index;
      2. YZ slices, one per x index;
      3. XZ slices, one per y index.

    In every slice a ``window_size`` x ``window_size`` window slides with stride 1.
    Whenever the window holds more than ``threshold`` distinct values, all of its
    cells are set to 0. Clearing is applied to a working copy as the sweep
    proceeds, so later windows see the zeros written by earlier ones.

    arr: 3D numpy array (left unmodified).
    window_size: side length of the square window, default 5.
    threshold: largest number of distinct values tolerated in a window, default 3.
    Returns the cleaned copy.
    """
    out = arr.copy()
    n0, n1, n2 = out.shape

    def sweep(planes):
        # Each plane is a 2D basic-slicing view of ``out``, so zeroing a patch
        # writes straight into ``out``.
        for plane in planes:
            rows, cols = plane.shape
            for r in range(rows - window_size + 1):
                for c in range(cols - window_size + 1):
                    patch = plane[r:r + window_size, c:c + window_size]
                    if np.unique(patch).size > threshold:
                        patch[...] = 0

    sweep(out[:, :, k] for k in range(n2))  # XY slices, stepping along Z
    sweep(out[i, :, :] for i in range(n0))  # YZ slices, stepping along X
    sweep(out[:, j, :] for j in range(n1))  # XZ slices, stepping along Y
    return out

def merge_points(points, eps=0.5):
    """Merge 3D points closer than ``eps`` into the mean of their cluster.

    Uses DBSCAN with ``min_samples=1``, so every point is a core point. No point
    is treated as noise, and points chained together by gaps below ``eps`` end up
    in the same cluster.

    :param points: (N, 3) array-like of points.
    :param eps: neighbourhood radius passed to DBSCAN.
    :return: (M, 3) array of cluster means (M <= N), ordered by ascending cluster
        label. Empty input yields an empty (0, 3) array instead of raising.
    """
    points = np.asarray(points)
    if points.size == 0:
        # DBSCAN.fit rejects empty input; callers may legitimately have no points.
        return np.empty((0, points.shape[1] if points.ndim == 2 else 3))

    labels = DBSCAN(eps=eps, min_samples=1).fit(points).labels_
    return np.array([points[labels == lab].mean(axis=0) for lab in np.unique(labels)])



def myconv(arr, kernel_size=7, stride=7, threshold=20):
    """Find densely occupied square patches on every axis-aligned slice of a 3D grid.

    Every slice is processed in this order: XY slices (one per z), then YZ
    slices (one per x), then XZ slices (one per y). For each slice:
      * the occupancy is summed over ``kernel_size`` x ``kernel_size`` windows,
        using a 'valid' 2D convolution with an all-ones kernel;
      * windows are sampled every ``stride`` cells in both directions;
      * any window whose sum exceeds ``threshold`` contributes the 3D
        coordinate of its centre cell (offset ``kernel_size // 2``).

    Args:
        arr: 3D array, presumably 0/1 occupancy (callers pass a binary mask).
        kernel_size: side of the square window.
        stride: sampling step over window positions.
        threshold: a window counts as occupied when its sum is strictly greater.

    Returns:
        List of (x, y, z) int tuples: XY hits first, then YZ, then XZ. The same
        location may appear more than once.
    """
    kernel = np.ones((kernel_size, kernel_size), dtype=int)
    half = kernel_size // 2
    X, Y, Z = arr.shape

    def scan(plane):
        """Yield (row, col) centres of windows in ``plane`` whose sum exceeds threshold."""
        # In 'valid' mode convolve2d needs one input at least as large as the other
        # in every dimension. It silently swaps them when the plane is smaller in
        # both, and raises when it is smaller in only one. Such slices cannot hold
        # a full window, so skip them.
        if plane.shape[0] < kernel_size or plane.shape[1] < kernel_size:
            return
        conv_map = convolve2d(plane, kernel, mode='valid')
        for i in range(0, conv_map.shape[0], stride):
            for j in range(0, conv_map.shape[1], stride):
                if conv_map[i, j] > threshold:
                    yield i + half, j + half

    occupied_coords = []
    for z in range(Z):  # XY planes along Z
        occupied_coords.extend((ci, cj, z) for ci, cj in scan(arr[:, :, z]))
    for x in range(X):  # YZ planes along X
        occupied_coords.extend((x, ci, cj) for ci, cj in scan(arr[x, :, :]))
    for y in range(Y):  # XZ planes along Y
        occupied_coords.extend((ci, y, cj) for ci, cj in scan(arr[:, y, :]))
    return occupied_coords

# Per-label RGB colours scaled to [0, 1]; row i is the colour for voxel label i.
# (The source palette was RGBA with alpha 255 on every row, so only RGB is kept.)
predefined_colors = np.array([
    [22, 191, 206],
    [214, 38, 40],
    [43, 160, 43],
    [158, 216, 229],
    [114, 158, 206],
    [204, 204, 91],
    [255, 186, 119],
    [147, 102, 188],
    [30, 119, 181],
    [188, 188, 33],
    [255, 127, 12],
    [196, 175, 214],
    [153, 153, 153],
    [0, 0, 0],
]) / 255.0
def get_grid_coords(dims, resolution):
    """Enumerate the integer voxel indices of a ``dims`` grid.

    :param dims: grid size [d0, d1, d2] (e.g. [256, 256, 32]).
    :param resolution: currently unused. The scaling to metric voxel centres is
        disabled, so the returned values are plain indices.
    :return: float32 array of shape (d0 * d1 * d2, 3). Each row is
        (index along d1, index along d0, index along d2). Rows are ordered with
        the d1 index varying slowest and the d2 index fastest.
    """
    axis0 = np.arange(dims[0])
    axis1 = np.arange(dims[1])
    axis2 = np.arange(dims[2])

    # An 'ij' meshgrid over (axis1, axis0, axis2) yields exactly the swapped
    # (d1, d0, d2) column layout and row order described above.
    slow, middle, fast = np.meshgrid(axis1, axis0, axis2, indexing="ij")
    return np.stack([slow.ravel(), middle.ravel(), fast.ravel()], axis=1).astype(np.float32)

def cluster_objects(occupancy_grid, min_voxels=5, eps=5, min_samples=50):
    """Split a labelled voxel grid into per-object records.

    Only labels > 0 are processed, in ascending order.

    * Labels 1 and 2: ``myconv`` finds dense patches on the label's binary mask.
      Each patch centre becomes one object ``{'size', 'centroid', 'category'}``.
    * Label 3: same as above, then nearby centres are merged with
      ``merge_points(eps=1)``.
    * Any other label: skipped if fewer than ``min_voxels`` voxels carry it.
      Otherwise each connected component (``scipy.ndimage.label`` default
      connectivity) becomes an object with ``centroid``, ``category``, ``size``
      (voxel count), ``bbox`` ``[xmin, xmax, ymin, ymax, zmin, zmax]`` and
      ``coords`` (the component's voxel indices).

    :param occupancy_grid: 3D array of integer labels.
    :param min_voxels: minimum voxel count for a connected-component label.
    :param eps: unused; kept for backward compatibility.
    :param min_samples: unused; kept for backward compatibility.
    :return: list of object dicts (empty if no voxel is > 0).
    """
    occupancy_grid = np.asarray(occupancy_grid)
    all_objects = []

    for category in np.unique(occupancy_grid[occupancy_grid > 0]):
        cat_mask = occupancy_grid == category

        if category in (1, 2, 3):
            centres = myconv(cat_mask.astype(int))
            if category == 3 and len(centres) > 0:
                # merge_points cannot cluster an empty point set.
                centres = merge_points(centres, eps=1)
            for centre in centres:
                all_objects.append({
                    'size': 50,  # fixed placeholder: patch centres have no voxel count
                    'centroid': tuple(centre),
                    'category': int(category),
                })
            continue

        if np.count_nonzero(cat_mask) < min_voxels:
            continue

        labeled_array, _ = ndimage.label(cat_mask)
        # find_objects gives each component's bounding box, so only that
        # sub-volume is scanned instead of the whole grid per component.
        for obj_id, box in enumerate(ndimage.find_objects(labeled_array), start=1):
            if box is None:
                continue
            offset = np.array([s.start for s in box])
            obj_coords = np.argwhere(labeled_array[box] == obj_id) + offset

            centroid = obj_coords.mean(axis=0)
            min_coord = obj_coords.min(axis=0)
            max_coord = obj_coords.max(axis=0)

            all_objects.append({
                'centroid': tuple(centroid),
                'category': int(category),
                'size': len(obj_coords),
                'bbox': [int(min_coord[0]), int(max_coord[0]),
                         int(min_coord[1]), int(max_coord[1]),
                         int(min_coord[2]), int(max_coord[2])],
                'coords': obj_coords,
            })
    return all_objects
    
def visualize_clusters(occupancy_grid, objects, resolution=0.08):  # resolution: kept for API compatibility
    """Display the labelled voxels of ``occupancy_grid`` as an Open3D point cloud.

    Voxels with value <= 0 or >= 255 are skipped. Every other voxel (x, y, z)
    becomes a point at (x, y, z), coloured by ``predefined_colors[value]``.
    Values without a palette row raise IndexError. ``objects`` and
    ``resolution`` are accepted but currently unused. An empty grid opens an
    empty viewer instead of failing.
    """
    grid = np.asarray(occupancy_grid)
    keep = (grid > 0) & (grid < 255)

    # argwhere and boolean-mask indexing both follow C order, i.e. the same
    # x-outer / z-inner order as a nested loop, so points and colours line up.
    coords = np.argwhere(keep).astype(np.float64)
    labels = grid[keep].astype(np.intp)  # float-typed grids become integer palette indices
    colors = predefined_colors[labels]

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coords)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd])

def visualize2(occupancy_grid, objects, resolution=0.08):  # resolution is currently unused
    """Show the voxel grid twice, plus one black sphere per object centroid.

    Every voxel whose value is neither <= 0 nor >= 255 becomes a point coloured
    by ``predefined_colors[value]``. Grid index (x, y, z) is displayed at
    (x, z, y). The cloud holds two copies of these voxels: one shifted by +100
    along the first display axis, and one unshifted.

    Each object's ``centroid`` (c0, c1, c2) gets a radius-1 black sphere at
    (c0, c2, c1), using the same axis swap as the unshifted copy. A coordinate
    frame with 50-unit axes is drawn at the origin.
    """
    dims = (occupancy_grid.shape[0], occupancy_grid.shape[1], occupancy_grid.shape[2])
    shown = []
    for x, y, z in np.ndindex(dims):
        label = occupancy_grid[x, y, z]
        if label <= 0 or label >= 255:
            continue
        shown.append((x, y, z, label))
    palette = [predefined_colors[label] for *_, label in shown]

    markers = []
    for obj in objects:
        cx, cy, cz = obj['centroid']
        ball = o3d.geometry.TriangleMesh.create_sphere(radius=1)
        ball.compute_vertex_normals()
        ball.paint_uniform_color([0, 0, 0])
        ball.translate((cx, cz, cy))
        markers.append(ball)

    shifted = [[x + 100, z, y] for x, y, z, _ in shown]
    unshifted = [[x, z, y] for x, y, z, _ in shown]

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(np.array(shifted + unshifted))
    cloud.colors = o3d.utility.Vector3dVector(np.array(palette + palette))

    # Coordinate frame; size controls the axis length.
    frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=50, origin=[0, 0, 0])
    o3d.visualization.draw_geometries(markers + [cloud, frame])

# Example usage: load one scene, extract objects, dump debug JSON, visualise,
# and write the Minecraft datapack.
if __name__ == "__main__":
    import json

    scene = np.load(r'D:\ecnu\2025\gt_visualizations\scene0006_02\full_scene.npy')
    # Other inputs used during development:
    # scene = np.load(r'realsim\res\scene0006_02_label.npy')
    # scene = np.load(r'D:\ecnu\2025\realsim\predict.npy')
    scene[scene == 12] = 0  # label 12 is discarded (treated as empty)
    # scene = preprocess(scene, window_size=5, threshold=4)

    # Keep labels strictly inside (0, 255) and swap the first two axes.
    # The previous get_grid_coords() + per-voxel scatter did exactly this only
    # when scene.shape[0] == scene.shape[1]. For non-square grids the flattened
    # coordinates and values were enumerated in different orders, so voxels
    # were scrambled.
    visible = (scene > 0) & (scene < 255)
    new_occupied_voxels = np.ascontiguousarray(
        np.where(visible, scene, 0).astype(scene.dtype).transpose(1, 0, 2)
    )

    # Dump every non-empty voxel.
    occ_points = []
    for x, y, z in np.argwhere(new_occupied_voxels != 0).tolist():
        c = int(new_occupied_voxels[x, y, z])
        if c != 0:  # 0 means background/empty
            occ_points.append({"x": x, "y": y, "z": z, "category": c})
    with open(r"D:\ecnu\2025\realsim\occ.json", "w") as f:
        json.dump(occ_points, f)

    objects = cluster_objects(new_occupied_voxels, min_voxels=50, eps=5.0, min_samples=50)

    # Only objects with at least `filter_num` voxels go to point.json; the
    # unfiltered list is still passed to the visualiser and the datapack writer.
    filter_num = 50
    kept = [obj for obj in objects if obj['size'] >= filter_num]
    obj_points = [
        {
            "x": int(obj['centroid'][0]),
            "y": int(obj['centroid'][1]),
            "z": int(obj['centroid'][2]),
            "category": int(obj['category']),
        }
        for obj in kept
    ]
    with open(r"D:\ecnu\2025\realsim\point.json", "w") as f:
        json.dump(obj_points, f)
    print(f"总物体数: {len(kept)}")

    # visualize_clusters(new_occupied_voxels, objects)
    visualize2(new_occupied_voxels, objects)
    get_instruction(objects)



# Local Minecraft directories used while testing the generated datapack (notes only).
# Kept as comments: as a plain string literal, "\m", "\." and "\s" were invalid
# escape sequences (warnings on recent Python) and "\t" turned into a tab.
# D:\minecraft\tmeo\.minecraft\versions\tmeo7.1\saves\【地图】TMEOv7.1 forge版本通用测试地图
# D:\minecraft\tmeo\.minecraft\versions\tmeo7.1


# import cv2
# import os
# import glob

# # ===== 配置部分 =====
# image_folder = r'D:\ecnu\2025\gt_visualizations\scene0024_02'       # 图片文件夹路径
# output_video = "output.mp4"   # 输出视频文件名
# fps = 3                      # 帧率，可根据需要修改

# # ===== 获取图片列表 =====
# temp = glob.glob(os.path.join(image_folder, "*.png"))
# temp = [i for i in temp if 'gt.' not in i]
# print(len(temp))

# images = sorted(temp)  

# print(len(images))
# # 读取第一张图片获取尺寸
# frame = cv2.imread(images[0])
# height, width, _ = frame.shape

# # 定义视频写入器
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # 保存为 mp4 格式
# out = cv2.VideoWriter(output_video, fourcc, fps, (width, height))

# # ===== 遍历图片写入视频 =====
# for img_path in images:
#     img = cv2.imread(img_path)
    
#     if img is None:
#         continue
#     # 水平镜像
#     flipped = cv2.flip(img, 1)
#     cv2.imwrite(r'D:\ecnu\2025\gt_visualizations\img\\' + os.path.basename(img_path), flipped)
#     out.write(flipped)

# out.release()
# print(f"视频已保存到 {output_video}")


