import open3d as o3d
import numpy as np

def get_grid_coords(dims, resolution):
    """
    Compute the center coordinates of every voxel in a regular grid.

    Rows are ordered like a C-order flatten of an array of shape ``dims``
    (x index slowest, z index fastest). So for a label grid ``v``,
    ``get_grid_coords(v.shape, res)[i]`` is the center of the voxel that holds
    ``v.reshape(-1)[i]``.

    :param dims: the dimensions of the grid [x, y, z] (i.e. [256, 256, 32])
    :param resolution: voxel edge length along [x, y, z]; a single scalar is
        applied to all three axes
    :return coords_grid: float32 array of shape (dims[0]*dims[1]*dims[2], 3)
        holding the center coords of voxels in the grid, measured from the
        grid's minimum corner (voxel (0, 0, 0) is centered at resolution / 2)
    """
    g_xx = np.arange(0, dims[0])
    g_yy = np.arange(0, dims[1])
    g_zz = np.arange(0, dims[2])

    # indexing="ij" keeps the axis order (x, y, z). numpy's default "xy"
    # indexing swaps the first two axes, which would pair each center with the
    # wrong entry of a C-order flattened label grid.
    xx, yy, zz = np.meshgrid(g_xx, g_yy, g_zz, indexing="ij")
    coords_grid = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
    coords_grid = coords_grid.astype(np.float32)

    resolution = np.asarray(resolution, dtype=np.float32).reshape(-1)
    if resolution.size == 1:
        # Isotropic voxels: reuse the single edge length for x, y and z.
        resolution = np.repeat(resolution, 3)
    resolution = resolution.reshape([1, 3])

    coords_grid = (coords_grid * resolution) + resolution / 2

    return coords_grid


def vis_voxels(
    input_file="labels.npy",
    output_file="voxels.ply",
    vox_origin=(-40, -40, -1),
    voxel_size=(0.4, 0.4, 0.4),
    show=True,
):
    """
    Turn a semantic voxel label grid into a coloured point cloud and save it.

    Every voxel whose label lies strictly between 0 and 17 becomes one point
    at the voxel's center, coloured by its semantic class.

    :param input_file: path of a ``.npy`` file holding a 3-D label grid,
        presumably indexed [x, y, z] (e.g. [200, 200, 16] with labels 0-17)
    :param output_file: path the point cloud is written to with
        ``o3d.io.write_point_cloud``
    :param vox_origin: world coordinates [x, y, z] of the grid's minimum corner
    :param voxel_size: voxel edge length along [x, y, z]; a scalar is applied
        to all three axes
    :param show: if True, open an Open3D viewer on the result after saving
    :return pcd: the coloured ``o3d.geometry.PointCloud``
    :raises ValueError: if the loaded array is not 3-dimensional
    """
    voxels = np.load(input_file)
    if voxels.ndim != 3:
        raise ValueError(f"expected a 3-D label grid, got shape {voxels.shape}")

    voxel_size = np.asarray(voxel_size, dtype=np.float32).reshape(-1)
    if voxel_size.size == 1:
        voxel_size = np.repeat(voxel_size, 3)
    origin = np.asarray(vox_origin, dtype=np.float32).reshape([1, 3])

    # World-space center of every voxel. get_grid_coords is expected to return
    # one row per voxel in the same (C) order as voxels.reshape(-1).
    grid_coords = get_grid_coords(voxels.shape, voxel_size) + origin
    labels = voxels.reshape(-1)

    # Labels 0 and 17 are skipped. Neither has a colour entry below; they are
    # presumably "others"/"free" in the Occ3D-nuScenes convention (TODO confirm).
    keep = (labels > 0) & (labels < 17)
    points = np.asarray(grid_coords[keep], dtype=np.float64)
    kept_labels = labels[keep].astype(np.int64)

    # Map label to RGB color (0-255)
    label_to_color = {
        1: [255, 120, 50],     # barrier
        2: [255, 192, 203],    # bicycle
        3: [255, 255, 0],      # bus
        4: [0, 150, 245],      # car
        5: [0, 255, 255],      # construction_vehicle
        6: [255, 127, 0],      # motorcycle
        7: [255, 0, 0],        # pedestrian
        8: [255, 240, 150],    # traffic_cone
        9: [135, 60, 0],       # trailer
        10: [160, 32, 240],    # truck
        11: [255, 0, 255],     # driveable_surface
        12: [139, 137, 137],   # other_flat
        13: [75, 0, 75],       # sidewalk
        14: [150, 240, 80],    # terrain
        15: [230, 230, 250],   # manmade
        16: [0, 175, 0],       # vegetation
    }

    # Lookup table indexed by label. Rows without an entry (label 0) stay
    # black, so one vectorised gather replaces the per-point loop.
    palette = np.zeros((17, 3), dtype=np.float64)
    for label, rgb in label_to_color.items():
        palette[label] = rgb
    colors = palette[kept_labels] / 255.0

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    o3d.io.write_point_cloud(output_file, pcd)

    if show:
        # Opens an Open3D viewer window (not inline rendering).
        o3d.visualization.draw_geometries([pcd])

    return pcd
