import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

root = "/path/VoxFormer-UQ/kitti/dataset/sequences"
output_root = "/path/VoxFormer/lidar_depth"
seq = "00"
calib_file = os.path.join(root, seq, 'calib.txt')
pointcloud_file = os.path.join(root, seq, 'velodyne', '000000.bin')
image_file = os.path.join(root, seq, 'image_2', '000000.png')

def read_calib_file(filepath):
    """读取并解析校准文件"""
    data = {}
    with open(filepath, 'r') as f:
        for line in f.readlines():
            key, value = line.split(':', 1)
            data[key] = np.array([float(x) for x in value.split()])
    return data

def depth_to_color(depth, min_depth, max_depth):
    """将深度值映射到颜色"""
    normalized_depth = (depth - min_depth) / (max_depth - min_depth)
    color = np.zeros((3,), dtype=np.uint8)
    color[0] = normalized_depth * 255  # R
    color[2] = (1 - normalized_depth) * 255  # B
    return color

# def project_velo_to_img(points, calib_data, img_shape):
#     Tr = np.eye(4)
#     Tr[:3, :] = np.reshape(calib_data['Tr'], (3, 4))
#     P = np.reshape(calib_data['P2'], (3, 4))
#     # P[:, 3] = 0
#     points_hom = np.hstack([points, np.ones((points.shape[0], 1))])
#     points_cam = np.dot(Tr, points_hom.T)
#     points_img = np.dot(P, points_cam)
#     points_img /= points_img[2, :]
#     # 筛选出在图像范围内的点，并返回深度值
#     valid = (points_img[0, :] >= 0) & (points_img[0, :] < img_shape[1]) & \
#             (points_img[1, :] >= 0) & (points_img[1, :] < img_shape[0]) & \
#             (points_cam[2, :] > 0)
#     points_img = points_img[:, valid]
#     depths = points_cam[2, valid]  # 深度值
#     return points_img[:2].T, depths
import numpy as np

def project_velo_to_img(points, calib_data, img_shape):
    Tr = np.eye(4)
    Tr[:3, :] = np.reshape(calib_data['Tr'], (3, 4))
    P = np.reshape(calib_data['P2'], (3, 4))
    
    points_hom = np.hstack([points, np.ones((points.shape[0], 1))])
    points_cam = np.dot(Tr, points_hom.T)
    
    # 计算到相机光心的距离，而不仅仅是z轴上的距离
    distances = np.sqrt(np.sum(points_cam[:3, :] ** 2, axis=0))
    
    points_img = np.dot(P, points_cam)
    points_img /= points_img[2, :]
    
    valid = (points_img[0, :] >= 0) & (points_img[0, :] < img_shape[1]) & \
            (points_img[1, :] >= 0) & (points_img[1, :] < img_shape[0]) & \
            (distances > 0)  & (points_cam[2, :] >0)# 使用距离过滤点
    
    points_img = points_img[:, valid]
    metric_depths = distances[valid]  # 过滤后的距离作为metric depth
    
    return points_img[:2].T, metric_depths


import numpy as np
import cv2
import os

# def interpolate_color(val, min_val, max_val):
#     """
#     根据深度值在彩虹色谱之间进行插值。
#     使用HSV色彩空间实现从红色开始到紫色的渐变效果。
#     返回值格式为(R, G, B)，其中R, G, B为整数，范围在0到255之间。
#     """
#     # 将深度值归一化到0到1之间
#     val_normalized = np.clip((val - min_val) / (max_val - min_val), 0, 1)
    
#     # 彩虹色谱在HSV空间中的色调从红色(0或360度)过渡到紫色(~270度)
#     # 我们可以将归一化的深度值映射到这个范围
#     # 注意：OpenCV的HSV色调范围是0到180，所以我们将值缩放到0到180
#     hue = (1.0 - val_normalized) * 170  # 从红色到蓝色的逆向映射

#     # 饱和度和亮度保持不变（最大）
#     saturation, value = 255, 255

#     # 将HSV颜色转换为RGB颜色
#     hsv_color = np.uint8([[[hue, saturation, value]]])
#     rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0][0]

#     # 确保返回值格式为(R, G, B)，其中R, G, B为整数，范围在0到255之间
#     return (int(rgb_color[0]), int(rgb_color[1]), int(rgb_color[2]))

# def draw_colored_points(image, points, depths, min_depth, max_depth):
#     """
#     在图像上根据深度值绘制彩色点。
#     """
#     for (x, y), depth in zip(points, depths):
#         color = interpolate_color(depth, min_depth, max_depth)
#         cv2.circle(image, (int(x), int(y)), 1, color, -1)


def interpolate_color(index, total):
    """
    根据点的相对位置（而不是绝对深度）在彩虹色谱之间进行插值。
    """
    # 将点的索引归一化到0到1之间
    val_normalized = index / total
    
    # 使用HSV色彩空间进行颜色映射
    hue = (1.0 - val_normalized) * 170  # 从红色到蓝色的逆向映射
    saturation, value = 255, 255
    hsv_color = np.uint8([[[hue, saturation, value]]])
    rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0][0]
    
    return (int(rgb_color[0]), int(rgb_color[1]), int(rgb_color[2]))

def draw_colored_points(image, points, depths, min_depth, max_depth):
    """
    根据点的数量分配颜色并在图像上绘制。
    """
    # 根据深度排序点（并获取排序后的索引）
    sorted_indices = np.argsort(depths)
    
    for index in range(len(points)):
        # 获取排序后的点位置
        sorted_index = sorted_indices[index]
        x, y = points[sorted_index]
        # 根据排序位置决定颜色
        color = interpolate_color(index, len(points))
        cv2.circle(image, (int(x), int(y)), 1, color, -1)



def depth_to_grayscale(depth, min_depth, max_depth):
    """将深度值转换为灰度值。"""
    depth_normalized = (depth - min_depth) / (max_depth - min_depth)
    grayscale = np.clip(depth_normalized * 255, 0, 255).astype(np.uint8)
    return 255 - grayscale  # 深的地方暗，浅的地方亮

def create_depth_image(points, depths, img_shape, min_depth, max_depth):
    """创建深度图像，未扫描到的点为255。"""
    depth_image = np.full((img_shape[0], img_shape[1]), -1, dtype=np.int8)
    
    for point, depth in zip(points, depths):
        x, y = int(point[0]), int(point[1])
        if 0 <= x < img_shape[1] and 0 <= y < img_shape[0]:
            # depth_image[y, x] = depth_to_grayscale(depth, min_depth, max_depth)
            depth_image[y, x] = depth
    
    return depth_image


def main():
    # 读取激光雷达点云数据、校准文件和图像
    root = "/path/VoxFormer-UQ/kitti/dataset/sequences"
    # seq = "00"
    # frame = 110

    for seq in [str(s_i).zfill(2) for s_i in range(11)]:
        frames=os.listdir(os.path.join(root, seq, 'velodyne'))
        #排序
        frames.sort()
        for frame in frames:
            frame=int(frame.split('.')[0])
            if frame!=90:
                continue



            calib_file = os.path.join(root, seq, 'calib.txt')
            pointcloud_file = os.path.join(root, seq, 'velodyne', str(frame).zfill(6)+'.bin')
            image_file = os.path.join(root, seq, 'image_2', str(frame).zfill(6)+'.png')

            pointcloud = np.fromfile(pointcloud_file, dtype=np.float32).reshape((-1, 4))
            image = cv2.imread(image_file)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            img_shape = image.shape

            # 假设project_velo_to_img和read_calib_file函数已经定义
            calib_data = read_calib_file(calib_file)
            img_points, depths = project_velo_to_img(pointcloud[:, :3], calib_data, img_shape)  # 使用你的函数

            min_depth, max_depth = np.min(depths), np.max(depths)
            draw_colored_points(image, img_points, depths, min_depth, max_depth)

            depth_image = create_depth_image(img_points, depths, img_shape, np.min(depths), np.max(depths))
            os.makedirs(os.path.join(output_root, seq, 'lidar_depth'), exist_ok=True)
            # np.save(os.path.join(output_root, seq, 'lidar_depth', str(frame).zfill(6)+'.npy'), depth_image)
            # 存压缩的npy文件
            # np.savez_compressed(os.path.join(output_root, seq, 'lidar_depth', str(frame).zfill(6)+'.npz'), depth_image=depth_image)
            print('Saved depth image to', os.path.join(output_root, seq, 'depth', str(frame).zfill(6)+'.png'))
            
            # 使用cv2.imwrite或类似方法保存图像
            # cv2.imshow('depth_image', depth_image)
            # cv2.waitKey(0)

    # cv2.imwrite('depth_image.png', depth_image)

    # # 使用matplotlib显示图像
            plt.imshow(image)
            plt.show()
            a=0

if __name__ == '__main__':
    main()
