import numpy as np
from numpy.linalg import inv
def ray_voxel_intersection(
    ray_start,
    ray_end,
    splits=(128, 128, 16),
    ranges=((0, 51.2), (-25.6, 25.6), (-2, 4.4)),
):
    """Find where a ray segment crosses the planes of an axis-aligned voxel grid.

    The grid spans ``ranges`` and is cut into ``splits`` cells per axis, so
    axis ``k`` has ``splits[k] + 1`` evenly spaced boundary planes. For every
    plane lying between the segment's end points (with a 1e-10 tolerance), the
    ray parameter ``t`` is computed such that ``ray_start + t * ray_dir`` lies
    on the plane, where ``ray_dir = ray_end - ray_start``. A crossing is kept
    only if its two other coordinates fall strictly inside the grid range.

    Args:
        ray_start: Sequence of three coordinates (x, y, z) of the segment start.
        ray_end: Sequence of three coordinates (x, y, z) of the segment end.
        splits: Number of cells along x, y and z.
        ranges: ``(low, high)`` extent of the grid along x, y and z.

    Returns:
        A list whose first six entries are the start point (x, y, z) and the
        direction (dx, dy, dz), followed by the kept ``t`` values in ascending
        order.
    """
    origin = np.array(ray_start)
    target = np.array(ray_end)
    direction = target - origin
    epsilon = 1e-10

    crossings = []
    for axis in range(3):
        delta = direction[axis]
        # A ray parallel to this family of planes never crosses them. Skipping
        # here avoids the division by zero (RuntimeWarning plus NaN/inf values)
        # that would otherwise occur; such candidates were never kept anyway.
        if delta == 0:
            continue

        low, high = ranges[axis]
        # linspace yields exactly splits + 1 planes with exact end points;
        # a float-step arange may produce an extra plane beyond the grid.
        planes = np.linspace(low, high, splits[axis] + 1)

        # Keep only planes located between the two end points of the segment.
        seg_low = min(origin[axis], target[axis]) - epsilon
        seg_high = max(origin[axis], target[axis]) + epsilon
        planes = planes[(planes > seg_low) & (planes < seg_high)]

        t = (planes - origin[axis]) / delta

        # The crossing must lie strictly inside the grid on the other two axes.
        inside = np.ones(t.shape, dtype=bool)
        for other in range(3):
            if other == axis:
                continue
            coord = origin[other] + direction[other] * t
            inside &= (coord > ranges[other][0]) & (coord < ranges[other][1])

        crossings.extend(t[inside])

    header = [origin[0], origin[1], origin[2],
              direction[0], direction[1], direction[2]]
    return header + sorted(crossings)



# Smoke run: trace the segment joining two opposite corners of the default grid.
ray_start = (51.2, 25.6, 4.4)
ray_end = (0, -25.6, -2.0)

# Exact crossings of that segment with the voxel boundary planes.
sorted_points = ray_voxel_intersection(ray_start, ray_end)
sorted_points[:10]  # peek at the first ten entries (no effect when run as a script)


# Placeholder assignment kept from the original script.
a = 0






































import numpy as np
import cv2
import os


# Placeholder assignment kept from the original script.
a = 0


import sys

# Exactly one command-line argument (the sequence number) must be supplied.
if len(sys.argv) != 2:
    raise NotImplementedError("Please provide a parameter")
param = sys.argv[1]

# Integer divisor applied to the image width and height further down.
down_rate = 1

# Location of the raw KITTI-360 2D data on the processing machine.
root = "/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/kitti360/data_2d_raw"
seq = int(param)

# First rectified frame of camera 00 for the requested drive.
image_path = os.path.join(
    root,
    f"2013_05_28_drive_00{seq:02d}_sync",
    "image_00",
    "data_rect",
    "000000.png",
)

# No calibration file is read; read_calib() ignores its path argument.
calib_path = None

mean_tran = None

cnt = 0

def read_calib(calib_path):
    """Return the calibration matrices embedded in this script.

    :param calib_path: Accepted for interface compatibility only; it is never
        read because the values below are hard-coded constants.
    :return: dict with two entries: "Tr", the 4x4 inverse of the embedded
        transform (labelled velo2cam by the original author), and "P2", the
        3x4 projection matrix.
    """
    # Embedded 4x4 homogeneous transform; it is inverted before being returned.
    embedded_transform = np.array([
        [0.04307104361, -0.08829286498, 0.995162929, 0.8043914418],
        [-0.999004371, 0.007784614041, 0.04392796942, 0.2993489574],
        [-0.01162548558, -0.9960641394, -0.08786966659, -0.1770225824],
        [0, 0, 0, 1],
    ])

    # 3x4 projection matrix.
    projection = np.array([
        [552.554261, 0.000000, 682.049453, 0.000000],
        [0.000000, 552.554261, 238.769549, 0.000000],
        [0.000000, 0.000000, 1.000000, 0.000000],
    ])

    return {"Tr": inv(embedded_transform), "P2": projection}

# Load the calibration once and derive both matrices from the same result.
_calib = read_calib(calib_path)

# 4x4 camera-to-image matrix: identity with the 3x4 projection in its top rows.
cam2img = np.eye(4)
cam2img[:3, :] = _calib['P2']
velo2cam = _calib['Tr']

# cv2.imread signals failure by returning None rather than raising, so check
# explicitly instead of failing later with an AttributeError on `.shape`.
img_tmp = cv2.imread(image_path)
if img_tmp is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Shrink the image by the integer factor `down_rate` and record its size.
img_tmp = cv2.resize(
    img_tmp,
    (img_tmp.shape[1] // down_rate, img_tmp.shape[0] // down_rate),
)
h = img_tmp.shape[0]
w = img_tmp.shape[1]

# NOTE: a per-pixel table (`img_vertices`) used to be built here by pushing
# every [u, v, 1, 1] pixel through inv(cam2img) and scaling by 60. It was
# removed because `img_vertices` is reassigned from `pts_3d` further down
# before anything reads it, so the h*w-element Python loop was dead work.
##########################################################################################################
def project_image_to_rect(self_P, uv_depth):
    """Back-project pixel coordinates with depth into 3D points.

    Input:  self_P   - projection matrix; entries [0, 0], [1, 1], [0, 2],
                       [1, 2], [0, 3] and [1, 3] are read.
            uv_depth - nx3 array; columns are u, v and depth.
    Output: nx3 float array with columns x, y and the unchanged depth.
    """
    # Focal lengths and principal point taken from the projection matrix.
    focal_u, focal_v = self_P[0, 0], self_P[1, 1]
    center_u, center_v = self_P[0, 2], self_P[1, 2]

    # Offsets derived from the fourth column of the projection matrix.
    offset_x = self_P[0, 3] / (-focal_u)
    offset_y = self_P[1, 3] / (-focal_v)

    u, v, depth = uv_depth[:, 0], uv_depth[:, 1], uv_depth[:, 2]

    # Every column is assigned below, so the buffer needs no initialisation.
    rect = np.empty((uv_depth.shape[0], 3))
    rect[:, 0] = (u - center_u) * depth / focal_u + offset_x
    rect[:, 1] = (v - center_v) * depth / focal_v + offset_y
    rect[:, 2] = depth
    return rect

# Build one (u, v, depth) triple per pixel in row-major order, using a
# constant depth of 60 for every pixel.
rows, cols = h, w
c, r = np.meshgrid(np.arange(cols), np.arange(rows))
depth = np.ones_like(c) * 60
points = np.stack([c, r, depth]).reshape((3, -1)).T

# Back-project every triple with the projection matrix.
pts_3d_rect = project_image_to_rect(read_calib(calib_path)['P2'], points)
# The original multiplied by inv(np.eye(3)) here; that is the identity and
# leaves the points unchanged, so the product is skipped.
pts_3d = pts_3d_rect


##########################################################################################################
# Interleave an all-zero origin row before each point, in homogeneous form, so
# consecutive row pairs are (ray start, ray end).
ray_rows = np.zeros((2 * pts_3d.shape[0], 4), dtype=pts_3d.dtype)
ray_rows[1::2, :3] = pts_3d
ray_rows[:, 3] = 1

# Apply the inverse of velo2cam to both end points of every ray
# (presumably camera frame -> LiDAR frame; confirm against the calibration).
img_vertices = (np.linalg.inv(velo2cam) @ ray_rows.T).T[:, :3]

# One direct float32 conversion replaces two round trips through Python lists.
# The image is not re-read here: h and w are already set from the same file
# and the same down_rate earlier in the script.
img_vertices = np.ascontiguousarray(img_vertices, dtype=np.float32)

##########################################################################################################
# Compute the intersections: one (start, end) pair per pixel.
img_vertices = img_vertices.reshape(-1, 2, 3)
save_root = r"/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/ray_voxel_intersection/kitti360"
os.makedirs(save_root, exist_ok=True)
from tqdm import tqdm

otpt = [
    ray_voxel_intersection(ray_pair[0], ray_pair[1])
    for ray_pair in tqdm(img_vertices)
]

import pickle
# Save the list of per-pixel results to a single pickle file.
with open(os.path.join(save_root, 'array_list.pkl'), 'wb') as f:
    pickle.dump(otpt, f)


##########################################################################################################
a = 0
# 466616*128*128*16=

# 128*128*16*4=