"""
从 Waymo 原始 TFRecord 直接读取图像、LiDAR 和相机位姿，按统一下采样倍率同步采样后做深度补全推理。

用法示例：
  python demo_waymo_from_tfrecord.py

注意：参数在代码开头硬编码配置：
  - INPUT_TFRECORD_DIR: TFRecord 文件目录路径，将遍历其中所有 .tfrecord 文件
  - DOWNSAMPLE: 统一下采样倍率（30Hz -> 10Hz）。仅处理 frame_idx % DOWNSAMPLE == 0 的帧；若根据时间戳估算的源帧率不高于 15Hz（即已是 10Hz），则不做下采样
  - OUTPUT_ROOT: 输出目录

相机：固定处理 pinhole 下的 5 个相机：front、front_left、front_right、side_left、side_right
LiDAR：读取并融合所有雷达（TOP/FRONT/SIDE_LEFT/SIDE_RIGHT/REAR）的第一、二次回波点云，剔除距车体 3 m 以内的点，并投影到各相机图像上以构建稀疏深度（原先仅使用 TOP 雷达以与 convert_waymo_to_rds_hq.py 保持一致，该过滤目前已被注释掉）
"""

import numpy as np
import cv2
import torch
import hydra
import tensorflow as tf
from pathlib import Path
from tqdm import tqdm
import gzip
import pickle
import ray
from concurrent.futures import ThreadPoolExecutor, as_completed

# Waymo API
from waymo_open_dataset import dataset_pb2

# Waymo v2 与 v2 工具
import transforms3d
from copy import deepcopy
from waymo_open_dataset.v2 import column_types
from waymo_open_dataset.v2.perception import (
    lidar as _v2_lidar,
    context as _v2_context,
    pose as _v2_pose,
    base as _v2_base,
)
from waymo_open_dataset.v2.perception.utils.lidar_utils import (
    convert_range_image_to_cartesian as v2_convert_range_image_to_cartesian,
    parse_range_image as v2_parse_range_image,
)

# ========== Configuration ==========
# Frame-rate downsampling factor (30 Hz -> 10 Hz). It is only applied when the
# source rate estimated from timestamps is above 15 Hz; otherwise every frame is used.
DOWNSAMPLE = 3
PARTITION = "training"
INPUT_TFRECORD_DIR = f"/data/dataset/waymo/{PARTITION}/"
# Point clouds are written to OUTPUT_ROOT/<clip_id>/<frame:06d>_merged<ext>.
OUTPUT_ROOT = f"/data/dataset/waymo_processed_my/pointcloud/{PARTITION}/"
# Sky masks are read from SKY_MASK_DIR/<clip_id>/<frame:06d>_<camera index>.png.
SKY_MASK_DIR = f"/data/dataset/waymo_processed_my/sky_masks/{PARTITION}/"
SKY_DEPTH_VALUE = 85.0  # depth (meters) blended into sky pixels; predictions are also clipped to this
MASK_BLUR_SIGMA = 7.0   # Gaussian sigma used to soften the sky-mask edges
MASK_ERODE_KERNEL = 3  # erosion kernel size used to slightly shrink the sky region before blurring
VISUALIZE = False  # False: save no images, only the point clouds; True: also write per-camera RGB/sparse/pred/overlay JPEGs

# Point-cloud storage format. Options:
#   "npz"      - compressed numpy, float32
#   "npz_fp16" - compressed numpy, float16 (roughly half the size of "npz")
#   "npz_bf16" - alias of "npz_fp16": numpy has no bfloat16 dtype, so it is ALSO stored as float16
#   "bfloat16" - gzip-compressed pickle of torch bfloat16 tensors (".pkl.gz")
#   "ply"      - standard PLY via open3d; camera ids go to a "<file>.meta.npz" sidecar
POINT_CLOUD_STORAGE_FORMAT = "npz_fp16"
POINT_CLOUD_COMPRESSION = True  # whether to compress; NOTE(review): not read by the code above — confirm it is used
REMOVE_OVERLAP = False  # False keeps points from all overlapping camera regions; True "punches holes" in lower-priority cameras where a higher-priority camera overlaps.
# ============================

# Parallelism: thread-pool size for per-camera preprocessing (only the 5 target cameras are submitted per frame)
CPU_WORKERS = 48

def camera_name_from_enum(name_int: int) -> str:
    """Return the lower-cased Waymo ``CameraName`` enum label (e.g. ``"front"``) for an enum value."""
    enum_label = dataset_pb2.CameraName.Name.Name(name_int)
    return enum_label.lower()


def lidar_name_from_enum(name_int: int) -> str:
    """Return the Waymo ``LaserName`` enum label, unchanged in case (e.g. ``"TOP"``), for an enum value."""
    laser_enum = dataset_pb2.LaserName.Name
    return laser_enum.Name(name_int)


def build_camera_intrinsics_from_calib(camera_calib) -> np.ndarray:
    """
    Build the 3x3 pinhole intrinsic matrix from a Waymo camera calibration.

    Only the first four entries of ``camera_calib.intrinsic`` (fx, fy, cx, cy)
    are used; any distortion coefficients after them are ignored here.

    Returns:
        float32 array [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    """
    focal_x, focal_y, center_x, center_y = camera_calib.intrinsic[:4]
    intrinsics = np.eye(3, dtype=np.float32)
    intrinsics[0, 0] = focal_x
    intrinsics[1, 1] = focal_y
    intrinsics[0, 2] = center_x
    intrinsics[1, 2] = center_y
    return intrinsics


def compute_camera_to_world_opencv(vehicle_to_world: np.ndarray, camera_to_vehicle: np.ndarray) -> np.ndarray:
    """
    Compose the camera-to-world pose and re-express the camera axes in OpenCV convention.

    The composed pose ``vehicle_to_world @ camera_to_vehicle`` has FLU camera axes
    (x forward, y left, z up). The OpenCV camera frame is obtained by taking the
    columns as (-Y, -Z, X) and keeping the translation column unchanged.
    """
    flu_pose = vehicle_to_world @ camera_to_vehicle
    # Column order (-Y, -Z, X, T): pick columns 1, 2, 0, 3 and negate the first two.
    axis_signs = np.array([-1, -1, 1, 1], dtype=flu_pose.dtype)
    return flu_pose[:, [1, 2, 0, 3]] * axis_signs


def soften_sky_mask(mask: np.ndarray, blur_sigma: float = 5.0, erode_kernel: int = 3) -> np.ndarray:
    """
    Feather the edges of a sky mask and shrink the sky region slightly.

    Args:
        mask: binary mask with values 0 or 255.
        blur_sigma: Gaussian sigma; larger values give smoother edges.
        erode_kernel: side length of the square erosion kernel applied before
            blurring; values <= 0 skip the erosion.

    Returns:
        float32 mask; for a 0/255 input its values lie in [0, 1].
    """
    soft = mask.astype(np.float32) / 255.0

    # Erode first so the blurred transition eats into the sky rather than the foreground.
    if erode_kernel > 0:
        structuring = np.ones((erode_kernel, erode_kernel), np.uint8)
        soft = cv2.erode(soft, structuring, iterations=1)

    # ksize (0, 0) lets OpenCV derive the kernel size from sigma.
    return cv2.GaussianBlur(soft, (0, 0), blur_sigma)


def save_point_cloud_efficient(points: np.ndarray, colors: np.ndarray, output_path: Path, format_type: str = "npz", camera_ids: np.ndarray | None = None):
    """
    高效存储点云数据
    
    Args:
        points: 点坐标 (N, 3)
        colors: 点颜色 (N, 3)
        output_path: 输出路径
        format_type: 存储格式 ("npz", "npz_fp16", "npz_bf16", "bfloat16", "ply")
    """
    if format_type == "npz":
        # 使用压缩的numpy格式，保持float32精度
        np.savez_compressed(
            str(output_path),
            points=points.astype(np.float32),
            colors=colors.astype(np.float32),
            camera_ids=camera_ids.astype(np.uint8),
        )
        
    elif format_type == "npz_fp16":
        # numpy float16精度 + npz压缩（文件约小50%，适合点云）
        np.savez_compressed(
            str(output_path),
            points=points.astype(np.float16),
            colors=colors.astype(np.float16),
            camera_ids=camera_ids.astype(np.uint8),
        )
        
    elif format_type == "npz_bf16":
        # bfloat16精度 + 压缩存储（文件最小）
        np.savez_compressed(
            str(output_path),
            points=points.astype(np.float16),
            colors=colors.astype(np.float16),
            camera_ids=camera_ids.astype(np.uint8),
        )
        
    elif format_type == "bfloat16":
        # 使用bfloat16精度（需要torch支持）
        if points.dtype == np.float32:
            torch_points = torch.from_numpy(points).to(torch.bfloat16)
            torch_colors = torch.from_numpy(colors).to(torch.bfloat16)
            
            data = {
                'points': torch_points,
                'colors': torch_colors,
                'camera_ids': torch.from_numpy(camera_ids.astype(np.uint8)),
            }
            # 保存为压缩的pickle文件
            with gzip.open(str(output_path), 'wb') as f:
                pickle.dump(data, f)
        else:
            # 如果已经是其他类型，直接保存
            np.savez_compressed(
                str(output_path),
                points=points,
                colors=colors
            )
            
    elif format_type == "ply":
        import open3d as o3d  # type: ignore
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(colors)
        o3d.io.write_point_cloud(str(output_path), pcd)
        np.savez_compressed(str(output_path) + ".meta.npz", camera_ids=camera_ids.astype(np.uint8))
    else:
        raise ValueError(f"不支持的存储格式: {format_type}")
    
    return output_path


def load_point_cloud_efficient(input_path: Path, format_type: str = "npz", return_labels: bool = False) -> tuple:
    """
    Load a point cloud written by ``save_point_cloud_efficient``.

    Args:
        input_path: file to read.
        format_type: "npz", "npz_fp16", "npz_bf16", "bfloat16" or "ply". Use the
            same value that was used when saving.
        return_labels: also return the per-point camera ids.

    Returns:
        With return_labels=False: (points, colors).
        With return_labels=True: (points, colors, camera_ids or None).
        points/colors are float32 except for "ply", where open3d returns float64.
        camera_ids are uint8.

    Raises:
        ValueError: if format_type is not supported.
    """
    camera_ids = None

    if format_type in ("npz", "npz_fp16", "npz_bf16"):
        # npz_fp16 and npz_bf16 are both stored as float16.
        # Use the NpzFile as a context manager so the underlying zip handle is closed.
        with np.load(str(input_path)) as data:
            points = data["points"].astype(np.float32)
            colors = data["colors"].astype(np.float32)
            if "camera_ids" in data.files:
                camera_ids = data["camera_ids"].astype(np.uint8)

    elif format_type == "bfloat16":
        # Gzip-compressed pickle of torch tensors.
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with gzip.open(str(input_path), "rb") as f:
            data = pickle.load(f)
        raw_points, raw_colors = data["points"], data["colors"]
        if isinstance(raw_points, torch.Tensor):
            points = raw_points.to(torch.float32).cpu().numpy()
            colors = raw_colors.to(torch.float32).cpu().numpy()
        else:
            points = raw_points.astype(np.float32)
            colors = raw_colors.astype(np.float32)
        raw_ids = data.get("camera_ids") if isinstance(data, dict) else None
        if raw_ids is not None:
            if isinstance(raw_ids, torch.Tensor):
                camera_ids = raw_ids.to(torch.uint8).cpu().numpy()
            else:
                camera_ids = np.asarray(raw_ids).astype(np.uint8)

    elif format_type == "ply":
        import open3d as o3d  # type: ignore
        pcd = o3d.io.read_point_cloud(str(input_path))
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors)
        meta_path = Path(str(input_path) + ".meta.npz")
        if meta_path.exists():
            try:
                with np.load(str(meta_path)) as meta:
                    if "camera_ids" in meta.files:
                        camera_ids = meta["camera_ids"].astype(np.uint8)
            except Exception:
                # Best-effort: an unreadable sidecar only means "no camera ids".
                camera_ids = None

    else:
        raise ValueError(f"不支持的存储格式: {format_type}")

    return (points, colors, camera_ids) if return_labels else (points, colors)


def create_point_cloud_from_depth(color_image: np.ndarray, depth: np.ndarray, K: np.ndarray, dist: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
    """
    Back-project a depth map into a colored point cloud in the camera frame.

    Every pixel is undistorted with ``cv2.undistortPoints``, giving normalized
    coordinates (x, y). It is then lifted to (x * z, y * z, z), where z is the
    pixel's depth.

    Args:
        color_image: RGB image (H, W, 3) aligned with ``depth``.
        depth: depth map (H, W).
        K: 3x3 camera intrinsic matrix.
        dist: distortion coefficients (k1, k2, p1, p2, k3), or None.

    Returns:
        (points, colors): (M, 3) points and (M, 3) colors scaled to [0, 1] by
        dividing by 255, keeping only pixels with depth > 0.
    """
    rows, cols = depth.shape

    grid_u, grid_v = np.meshgrid(np.arange(cols, dtype=np.float32), np.arange(rows, dtype=np.float32))
    pixel_uv = np.column_stack((grid_u.ravel(), grid_v.ravel())).reshape(-1, 1, 2)

    # With P=None the output is already on the normalized image plane, shape (N, 1, 2).
    undistorted = cv2.undistortPoints(pixel_uv, K, dist, P=None)
    norm_x = undistorted[:, 0, 0].reshape(rows, cols)
    norm_y = undistorted[:, 0, 1].reshape(rows, cols)

    xyz = np.stack((norm_x * depth, norm_y * depth, depth), axis=-1).reshape(-1, 3)
    rgb = color_image.reshape(-1, 3) / 255.0

    # Drop pixels without a positive depth.
    keep = xyz[:, 2] > 0
    return xyz[keep], rgb[keep]

# Placeholder range (meters) written into no-return LiDAR pixels. Anything at or
# above half of it is treated as "missing" when masks are built.
DUMMY_DISTANCE_VALUE = 2000.0


def convert_range_image_to_point_cloud(
    range_image: _v2_lidar.RangeImage,
    calibration: _v2_context.LiDARCalibrationComponent,
    pixel_pose: _v2_lidar.PoseRangeImage | None = None,
    frame_pose: _v2_pose.VehiclePoseComponent | None = None,
    keep_polar_features: bool = False,
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """
    Convert a LiDAR range image to a Cartesian point cloud and split it into
    valid points and placeholder points for missing returns.

    Side effects (intentional):
      * ``range_image.values`` is overwritten via ``object.__setattr__`` with a
        copy in which pixels whose channel 0 equals -1 (no return) get channel 0
        set to ``DUMMY_DISTANCE_VALUE``. The validity mask below is computed from
        this patched ``range_image.tensor``. The caller in this file later passes
        the same ``range_image`` to ``extract_pointwise_camera_projection``, which
        recomputes the same mask and therefore depends on this patch.
      * If ``pixel_pose`` is given, ``pixel_pose.values`` is overwritten the same
        way. Pixels whose channel 0 is exactly 0 receive the frame's
        ``world_from_vehicle`` pose decomposed as [roll, pitch, yaw, x, y, z].

    Args:
        range_image: v2 range image; channel 0 is used as the range value.
        calibration: LiDAR calibration (extrinsic + beam inclinations).
        pixel_pose: optional per-pixel pose range image. Requires ``frame_pose``.
        frame_pose: vehicle pose for this frame.
        keep_polar_features: forwarded to the v2 cartesian conversion.

    Returns:
        points_tensor: points of valid pixels, [N, 3] when
            keep_polar_features=False (presumably wider when True — confirm
            against v2 lidar_utils).
        missing_points_tensor: points computed for missing-return pixels (placeholders).
        range_image_mask: boolean mask over the range image, True where channel 0
            is below DUMMY_DISTANCE_VALUE / 2 (i.e. a valid return).
    """
    # Replace no-return ranges with a large placeholder so a simple threshold can mask them downstream.
    val_clone = deepcopy(range_image.tensor.numpy())  # type: ignore
    no_return = val_clone[..., 0] == -1
    val_clone[..., 0][no_return] = DUMMY_DISTANCE_VALUE
    # RangeImage is a frozen dataclass, hence object.__setattr__.
    object.__setattr__(range_image, "values", val_clone.flatten())

    # If a per-pixel pose is provided, fill pixels lacking a pose with the frame's vehicle pose.
    if pixel_pose is not None:
        assert frame_pose is not None
        pixel_pose_clone = deepcopy(pixel_pose.tensor.numpy())  # type: ignore
        # A zero in channel 0 marks a pixel without its own pose.
        pixel_pose_mask = pixel_pose_clone[..., 0] == 0
        tr_orig = tf.reshape(tf.convert_to_tensor(frame_pose.world_from_vehicle.transform), (4, 4)).numpy()  # type: ignore
        rot = tr_orig[:3, :3]
        x, y, z = tr_orig[:3, 3]
        # "szyx" static axes returns the z, y, x angles in that order.
        yaw, pitch, roll = transforms3d.euler.mat2euler(rot, "szyx")
        # Pose channels are laid out as [roll, pitch, yaw, x, y, z]
        pixel_pose_clone[..., 0][pixel_pose_mask] = roll
        pixel_pose_clone[..., 1][pixel_pose_mask] = pitch
        pixel_pose_clone[..., 2][pixel_pose_mask] = yaw
        pixel_pose_clone[..., 3][pixel_pose_mask] = x
        pixel_pose_clone[..., 4][pixel_pose_mask] = y
        pixel_pose_clone[..., 5][pixel_pose_mask] = z
        object.__setattr__(pixel_pose, "values", pixel_pose_clone.flatten())

    # Polar -> Cartesian
    range_image_cartesian = v2_convert_range_image_to_cartesian(
        range_image=range_image,
        calibration=calibration,
        pixel_pose=pixel_pose,
        frame_pose=frame_pose,
        keep_polar_features=keep_polar_features,
    )

    # Reads the patched values written above.
    range_image_tensor = range_image.tensor
    range_image_mask = DUMMY_DISTANCE_VALUE / 2 > range_image_tensor[..., 0]  # type: ignore
    points_tensor = tf.gather_nd(range_image_cartesian, tf.compat.v1.where(range_image_mask))
    missing_points_tensor = tf.gather_nd(range_image_cartesian, tf.compat.v1.where(~range_image_mask))

    return points_tensor, missing_points_tensor, range_image_mask


def extract_pointwise_camera_projection(
    range_image: _v2_lidar.RangeImage,
    camera_projection: _v2_lidar.CameraProjectionRangeImage,
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """
    Gather the camera-projection channels of each range-image pixel, split by return validity.

    A pixel is valid when its channel-0 value is below ``DUMMY_DISTANCE_VALUE / 2``.
    That is the same test ``convert_range_image_to_point_cloud`` uses, and the
    rows are gathered in the same ``where`` order, so they line up with that
    function's points.

    Returns:
        (projections of valid pixels, projections of missing pixels, validity mask)
    """
    is_valid = range_image.tensor[..., 0] < DUMMY_DISTANCE_VALUE / 2  # type: ignore
    projection = camera_projection.tensor
    valid_rows = tf.gather_nd(projection, tf.compat.v1.where(is_valid))
    missing_rows = tf.gather_nd(projection, tf.compat.v1.where(~is_valid))
    return valid_rows, missing_rows, is_valid


def compute_sparse_depth_for_camera(
    image: np.ndarray,
    camera_index: int,
    camera_projections: np.ndarray,
    xyzs_world: np.ndarray,
    camera_extrinsic_c2w: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Rasterize LiDAR points into a sparse depth map for one camera.

    Args:
        image: camera image; only its height/width (first two dims) are used.
        camera_index: 0-based camera index. Projection rows whose camera id equals
            ``camera_index + 1`` are selected.
        camera_projections: per-point projections, read as rows of
            (camera_id, u, v). Each point owns two consecutive rows, so the array
            holds 6 values per point.
        xyzs_world: (N, 3) points in world coordinates.
        camera_extrinsic_c2w: 4x4 camera-to-world transform. Depth is the z of
            the point in this camera frame.

    Returns:
        (mask, values, depth_map):
        mask: (H, W) bool, True where at least one point landed.
        values: 1-D float32 depths of those pixels, in row-major order.
        depth_map: (H, W) float32, 0 where no point landed.
        Depths are clipped to [0.1, 100]. If several points hit one pixel, the
        smallest depth wins. u/v are clipped into the image bounds.
    """
    height, width = image.shape[:2]

    sentinel = np.finfo(np.float32).max
    flat_depth = np.full(height * width, sentinel, dtype=np.float64)

    proj_rows = camera_projections.reshape(-1, 3)
    # Point i owns projection rows 2i and 2i + 1.
    row_owner = np.repeat(np.arange(xyzs_world.shape[0]), 2)
    hits = proj_rows[:, 0] == camera_index + 1
    point_ids = row_owner[hits]

    if point_ids.size == 0:
        empty_mask = np.zeros((height, width), dtype=np.bool_)
        return empty_mask, np.zeros((0,), dtype=np.float32), np.zeros((height, width), dtype=np.float32)

    selected = xyzs_world[point_ids]
    selected_h = np.concatenate([selected, np.ones_like(selected[..., :1])], axis=-1)

    world_to_cam = np.linalg.inv(camera_extrinsic_c2w)
    cam_coords = selected_h @ world_to_cam.T
    z_cam = np.clip(cam_coords[..., 2], a_min=1e-1, a_max=1e2)

    u = np.clip(proj_rows[hits, 1], 0, width - 1).astype(np.int32)
    v = np.clip(proj_rows[hits, 2], 0, height - 1).astype(np.int32)

    # Unbuffered minimum so duplicate pixel indices keep the nearest point.
    np.minimum.at(flat_depth, v * width + u, z_cam)
    flat_depth[flat_depth >= sentinel - 1e-5] = 0

    depth_map = flat_depth.reshape(height, width).astype(np.float32)
    return depth_map != 0, flat_depth[flat_depth != 0].astype(np.float32), depth_map


@ray.remote(num_cpus=1)
def save_point_cloud_task(points_np, colors_np, output_path_str, format_type, camera_ids_np):
    """
    Ray task that writes one point cloud with ``save_point_cloud_efficient``.

    Args:
        points_np: array-like (N, 3) point coordinates.
        colors_np: array-like (N, 3) colors.
        output_path_str: destination path as a string.
        format_type: storage format forwarded to ``save_point_cloud_efficient``.
        camera_ids_np: array-like per-point camera ids, or None.

    Returns:
        output_path_str, so the driver can track which file finished.
    """
    points = np.asarray(points_np)
    colors = np.asarray(colors_np)
    # np.asarray(None) would yield a 0-d object array that cannot be cast to uint8;
    # forward None unchanged so the saver can treat camera ids as absent.
    camera_ids = None if camera_ids_np is None else np.asarray(camera_ids_np)
    save_point_cloud_efficient(points, colors, Path(output_path_str), format_type=format_type, camera_ids=camera_ids)
    return output_path_str


@ray.remote(num_gpus=1)
class ClipRunner:
    def __init__(self, cfg_remote):
        """
        Build the depth-completion network inside this Ray actor.

        ``cfg_remote`` is passed unchanged to ``utils_infer.Trainer``. The network
        used for inference is ``trainer.net_ema.module``, moved to CUDA and put in
        eval mode.
        """
        # Hide every GPU from TensorFlow in this actor — presumably so TF does not
        # reserve GPU memory that the torch network needs.
        import tensorflow as tf_actor
        tf_actor.config.set_visible_devices([], 'GPU')

        from utils_infer import Trainer as TrainerCls
        import torch as torch_actor

        self._torch = torch_actor

        # Per-channel RGB normalization statistics, shaped (1, 3, 1, 1) to broadcast over NCHW batches.
        rgb_mean = [90.9950, 96.2278, 94.3213]
        rgb_std = [79.2382, 80.5267, 82.1483]
        self.image_mean = torch_actor.tensor(rgb_mean, dtype=torch_actor.float32).reshape(1, 3, 1, 1)
        self.image_std = torch_actor.tensor(rgb_std, dtype=torch_actor.float32).reshape(1, 3, 1, 1)

        self.trainer = TrainerCls(cfg_remote)
        model = self.trainer.net_ema.module.cuda()
        model.eval()
        self.net = model

    def run_subset(self, tfrecord_files_subset: list[str]) -> int:
        output_root = Path(OUTPUT_ROOT)
        output_root.mkdir(parents=True, exist_ok=True)
        camera_names_target = ['front', 'front_left', 'front_right', 'side_left', 'side_right']
        processed = 0
        for tfrecord_path_str in tfrecord_files_subset:
            tfrecord_path = Path(tfrecord_path_str)
            # 解析 clip_id
            clip_id = tfrecord_path.stem.lstrip('segment-').rstrip('_with_camera_labels')

            # 载入 TFRecord（用于标定）
            dataset = tf.data.TFRecordDataset(str(tfrecord_path), compression_type='')

            # 先从第一帧取相机标定（外参+内参+畸变）
            camera_name_to_calib = {}
            for data in dataset.take(1):
                frame = dataset_pb2.Frame()
                frame.ParseFromString(bytearray(data.numpy()))
                for calib in frame.context.camera_calibrations:
                    cam_name = camera_name_from_enum(calib.name)
                    if cam_name not in camera_names_target:
                        continue
                    # Waymo 提供的 intrinsic: [fx, fy, cx, cy, k1, k2, p1, p2, k3]
                    intrinsic = np.array(calib.intrinsic, dtype=np.float32)
                    K = build_camera_intrinsics_from_calib(calib)
                    # 保存畸变参数（k1, k2, p1, p2, k3）
                    dist = intrinsic[4:9] if intrinsic.shape[0] >= 9 else np.zeros((5,), dtype=np.float32)
                    camera_to_vehicle = np.array(calib.extrinsic.transform, dtype=np.float32).reshape(4, 4)
                    camera_name_to_calib[cam_name] = {
                        'K': K,
                        'camera_to_vehicle': camera_to_vehicle,
                        'size': (int(calib.width), int(calib.height)),
                        'camera_id': int(calib.name),
                        'dist': dist,
                    }

            if len(camera_name_to_calib) == 0:
                print(f'  警告: 未找到目标相机的标定信息，跳过此文件')
                continue

            # 估算源帧率（基于前若干帧时间戳）并统计总帧数
            probe = tf.data.TFRecordDataset(str(tfrecord_path), compression_type='')
            timestamps = []
            for i, d in enumerate(probe):
                if i >= 50:
                    break
                fr = dataset_pb2.Frame()
                fr.ParseFromString(bytearray(d.numpy()))
                timestamps.append(int(fr.timestamp_micros))
            fps_est = None
            if len(timestamps) >= 2:
                diffs = np.diff(np.array(timestamps, dtype=np.int64))
                diffs = diffs[diffs > 0]
                if diffs.size > 0:
                    median_dt = float(np.median(diffs))  # microseconds
                    fps_est = 1e6 / median_dt
            # 决定有效下采样倍率：若源≈30Hz，使用 DOWNSAMPLE；若源≈10Hz，则置为1
            if fps_est is not None and fps_est > 15:
                effective_downsample = max(1, int(DOWNSAMPLE))
            else:
                effective_downsample = 1

            # 确定输出文件扩展名（用于已存在文件的快速跳过）
            if POINT_CLOUD_STORAGE_FORMAT in ["npz", "npz_fp16", "npz_bf16"]:
                file_ext = ".npz"
            elif POINT_CLOUD_STORAGE_FORMAT == "bfloat16":
                file_ext = ".pkl.gz"
            else:
                file_ext = ".ply"

            # 遍历所有帧，按有效下采样倍率选择帧
            dataset = tf.data.TFRecordDataset(str(tfrecord_path), compression_type='')
            real_frame_idx = 0  # 实际处理的帧索引（不计跳过的帧）
            for frame_idx, data in enumerate(tqdm(dataset, desc=f'处理帧 ({clip_id})')):
                if frame_idx % effective_downsample != 0:
                    continue

                # 若对应输出点云已存在则跳过该处理帧
                ply_dir = Path(OUTPUT_ROOT) / clip_id
                expected_out = ply_dir / f"{real_frame_idx:06d}_merged{file_ext}"
                if expected_out.exists():
                    print(f'[skip] {expected_out} already exists')
                    real_frame_idx += 1
                    continue

                frame = dataset_pb2.Frame()
                frame.ParseFromString(bytearray(data.numpy()))

                # 车辆位姿（FLU）
                vehicle_to_world = np.array(frame.pose.transform, dtype=np.float32).reshape(4, 4)

                # 使用 v2 组件从 v1 Frame 构造 range image & camera projection，并调用 v2 转换
                # 车辆位姿 -> v2 VehiclePoseComponent
                vehicle_pose_comp = _v2_pose.VehiclePoseComponent(
                    key=_v2_base.FrameKey(
                        segment_context_name=frame.context.name,
                        frame_timestamp_micros=frame.timestamp_micros,
                    ),
                    world_from_vehicle=column_types.Transform(
                        transform=list(frame.pose.transform)
                    ),
                )

                # 构建 LiDAR 名称 -> 校准映射
                lidar_calib_map: dict[int, _v2_context.LiDARCalibrationComponent] = {}
                for c in frame.context.laser_calibrations:
                    if len(c.beam_inclinations) > 0:
                        beam_incl = _v2_context.BeamInclination(
                            min=c.beam_inclination_min,
                            max=c.beam_inclination_max,
                            values=list(c.beam_inclinations),
                        )
                    else:
                        beam_incl = _v2_context.BeamInclination(
                            min=c.beam_inclination_min,
                            max=c.beam_inclination_max,
                        )
                    lidar_calib_map[c.name] = _v2_context.LiDARCalibrationComponent(
                        key=_v2_base.SegmentLaserKey(
                            segment_context_name=frame.context.name,
                            laser_name=c.name,
                        ),
                        extrinsic=column_types.Transform(
                            transform=list(c.extrinsic.transform)
                        ),
                        beam_inclination=beam_incl,
                    )

                # 逐 LiDAR 读取 range image、camera projection 与 pixel pose
                lidar_points_per_sensor = []
                lidar_cp_per_sensor = []
                for laser in frame.lasers:
                    calib_comp = lidar_calib_map[laser.name]
                    # return1
                    ri1 = v2_parse_range_image(
                        laser.ri_return1.range_image_compressed,
                        _v2_lidar.RangeImage,
                    )
                    cp1 = v2_parse_range_image(
                        laser.ri_return1.camera_projection_compressed,
                        _v2_lidar.CameraProjectionRangeImage,
                    )
                    pose1 = None
                    if laser.name == dataset_pb2.LaserName.TOP and len(laser.ri_return1.range_image_pose_compressed) > 0:
                        pose_bytes = tf.io.decode_compressed(
                            laser.ri_return1.range_image_pose_compressed, 'ZLIB'
                        )
                        mat = dataset_pb2.MatrixFloat()
                        mat.ParseFromString(bytearray(pose_bytes.numpy()))
                        pose1 = _v2_lidar.PoseRangeImage(
                            values=list(mat.data),
                            shape=list(mat.shape.dims),
                        )

                    points_list = []
                    cp_list = []
                    if ri1 is not None and cp1 is not None:
                        pts1, _, _ = convert_range_image_to_point_cloud(
                            ri1, calib_comp, pixel_pose=pose1, frame_pose=vehicle_pose_comp, keep_polar_features=False
                        )
                        cp_pts1, _, _ = extract_pointwise_camera_projection(ri1, cp1)
                        points_list.append(pts1.numpy())
                        cp_list.append(cp_pts1.numpy())

                    # return2
                    if len(laser.ri_return2.range_image_compressed) > 0:
                        ri2 = v2_parse_range_image(
                            laser.ri_return2.range_image_compressed,
                            _v2_lidar.RangeImage,
                        )
                        cp2 = v2_parse_range_image(
                            laser.ri_return2.camera_projection_compressed,
                            _v2_lidar.CameraProjectionRangeImage,
                        )
                        if ri2 is not None and cp2 is not None:
                            pts2, _, _ = convert_range_image_to_point_cloud(
                                ri2, calib_comp, pixel_pose=None, frame_pose=vehicle_pose_comp, keep_polar_features=False
                            )
                            cp_pts2, _, _ = extract_pointwise_camera_projection(ri2, cp2)
                            points_list.append(pts2.numpy())
                            cp_list.append(cp_pts2.numpy())

                    if points_list:
                        lidar_points_per_sensor.append(np.concatenate(points_list, axis=0))
                        lidar_cp_per_sensor.append(np.concatenate(cp_list, axis=0))

                # 合并所有 LiDAR
                points = lidar_points_per_sensor
                cps = lidar_cp_per_sensor

                # 融合所有 LiDAR（TOP/FRONT/SIDE/REAR 等）
                lidar_ids = [calib.name for calib in frame.context.laser_calibrations]
                lidar_ids.sort()
                all_points_vehicle_list = []
                all_cp_points_list = []
                for lidar_id, lidar_points in zip(lidar_ids, points):
                    lidar_name = lidar_name_from_enum(lidar_id)
                    # if lidar_name != 'TOP':
                    #     continue
                    all_points_vehicle_list.append(lidar_points.astype(np.float32))
                # cp_points 与 points 顺序一致
                for _lidar_id, lidar_cps in zip(lidar_ids, cps):
                    all_cp_points_list.append(lidar_cps.astype(np.float32))
                if len(all_points_vehicle_list) == 0:
                    raise RuntimeError('未从任意 LiDAR 读取到点云')
                all_points_vehicle_raw = np.concatenate(all_points_vehicle_list, axis=0)
                all_cp_points_raw = np.concatenate(all_cp_points_list, axis=0)

                # 过滤近距离
                dists = np.linalg.norm(all_points_vehicle_raw, axis=1)
                keep_mask = dists >= 3.0
                all_points_vehicle = all_points_vehicle_raw[keep_mask]
                all_cp_points = all_cp_points_raw[keep_mask]

                # 车辆 -> 世界（一次性变换所有点）
                homo = np.concatenate([all_points_vehicle, np.ones((all_points_vehicle.shape[0], 1), dtype=np.float32)], axis=1)
                all_points_world = (vehicle_to_world @ homo.T).T[:, :3]

                # 收集本帧图像（所有目标相机）
                cam_images = {}
                for image_data in frame.images:
                    cam_name = camera_name_from_enum(image_data.name)
                    if cam_name not in camera_name_to_calib:
                        continue
                    img_bytes = image_data.image
                    img = tf.image.decode_jpeg(img_bytes).numpy()  # RGB HxWx3
                    cam_images[cam_name] = img

                # 遍历相机：并行生成稀疏深度与准备批处理输入
                per_cam_results = {}
                def _process_camera(cam_idx, cam_name):
                    if cam_name not in cam_images:
                        return None
                    cam_number = cam_idx
                    calib = camera_name_to_calib[cam_name]
                    K = calib['K']
                    dist_coeffs = calib.get('dist', np.zeros((5,), dtype=np.float32))
                    camera_to_vehicle = calib['camera_to_vehicle']
                    camera_to_world = compute_camera_to_world_opencv(vehicle_to_world, camera_to_vehicle)
                    img = cam_images[cam_name]
                    H, W = int(img.shape[0]), int(img.shape[1])
                    # 稀疏深度
                    camera_index = cam_idx
                    cp_all = all_cp_points
                    _, _, depth_map_hw = compute_sparse_depth_for_camera(
                        img, camera_index, cp_all, all_points_world, camera_to_world
                    )
                    # 到网络步长（32）填充
                    stride = 32
                    pad_h = (int(np.ceil(H / stride)) * stride) - H
                    pad_w = (int(np.ceil(W / stride)) * stride) - W
                    padded_image = np.zeros((H + pad_h, W + pad_w, 3), dtype=np.float32)
                    padded_image[:H, :W] = img.astype(np.float32)
                    padded_depth = np.zeros((H + pad_h, W + pad_w), dtype=np.float32)
                    padded_depth[:H, :W] = depth_map_hw
                    return {
                        'cam_name': cam_name,
                        'cam_number': cam_number,
                        'H': H,
                        'W': W,
                        'pad_h': pad_h,
                        'pad_w': pad_w,
                        'padded_image': padded_image,
                        'padded_depth': padded_depth,
                        'K': K.astype(np.float32),
                        'camera_to_world': camera_to_world.astype(np.float32),
                        'dist': dist_coeffs,
                    }

                with ThreadPoolExecutor(max_workers=CPU_WORKERS) as pool:
                    futures = {pool.submit(_process_camera, idx, nm): nm for idx, nm in enumerate(camera_names_target)}
                    for fut in as_completed(futures):
                        res = fut.result()
                        if res is not None:
                            per_cam_results[res['cam_name']] = res

                if not per_cam_results:
                    # 无可用相机
                    real_frame_idx += 1
                    continue

                # 逐相机（batch=1）推理，避免 OOM，不做批次拼接
                cam_order = list(per_cam_results.keys())
                frame_depths = {}
                frame_colors = {nm: cam_images[nm] for nm in cam_order}
                frame_intrinsics = {nm: per_cam_results[nm]['K'] for nm in cam_order}
                frame_camera_to_world = {nm: per_cam_results[nm]['camera_to_world'] for nm in cam_order}
                frame_dists = {nm: per_cam_results[nm]['dist'] for nm in cam_order}

                for cam_name in cam_order:
                    it = per_cam_results[cam_name]
                    H, W = it['H'], it['W']
                    img_pad = it['padded_image']
                    dep_pad = it['padded_depth']
                    K_np = it['K']

                    imgs = self._torch.from_numpy(img_pad.astype('float32')).permute(2, 0, 1)[None, ...]
                    deps = self._torch.from_numpy(dep_pad.astype('float32'))[None, None, ...]
                    Ks = self._torch.from_numpy(K_np.astype('float32'))[None, ...]
                    imgs = (imgs - self.image_mean) / self.image_std
                    with self._torch.no_grad():
                        out = self.net(imgs.cuda(), None, deps.cuda(), Ks.cuda())
                        if isinstance(out, (list, tuple)):
                            out = out[-1]
                    pred_full = out.detach().cpu().numpy()[0]
                    pred = pred_full[0, :H, :W]

                    # 天空 mask（逐相机应用）
                    sky_mask_path = Path(SKY_MASK_DIR) / clip_id / f"{real_frame_idx:06d}_{it['cam_number']}.png"
                    if sky_mask_path.exists():
                        sky_mask = cv2.imread(str(sky_mask_path), cv2.IMREAD_GRAYSCALE)
                        if sky_mask is not None:
                            if sky_mask.shape != (H, W):
                                sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_NEAREST)
                            sky_mask_softened = soften_sky_mask(sky_mask, blur_sigma=MASK_BLUR_SIGMA, erode_kernel=MASK_ERODE_KERNEL)
                            pred = (1 - sky_mask_softened) * pred + sky_mask_softened * SKY_DEPTH_VALUE
                            pred = np.clip(pred, 0, SKY_DEPTH_VALUE)

                    frame_depths[cam_name] = pred

                    # 可视化与保存（按相机与帧组织）
                    if VISUALIZE:
                        vis_dir = Path(OUTPUT_ROOT) / clip_id / cam_name
                        vis_dir.mkdir(parents=True, exist_ok=True)

                        # 原图（RGB）
                        rgb_img = cam_images[cam_name]
                        cv2.imwrite(str(vis_dir / f"{real_frame_idx:06d}_rgb.jpg"), rgb_img[:, :, ::-1])

                        # 稀疏深度（来自 dep_pad，裁剪到原尺寸）
                        sparse_depth = dep_pad[:H, :W]
                        # 归一化到 [0, 255] 后伪彩
                        sd_max = max(1e-6, float(np.percentile(sparse_depth[sparse_depth > 0], 95)) if np.any(sparse_depth > 0) else 1.0)
                        sd_norm = np.clip((sparse_depth / sd_max) * 255.0, 0, 255).astype(np.uint8)
                        sd_color = cv2.applyColorMap(sd_norm, cv2.COLORMAP_JET)
                        cv2.imwrite(str(vis_dir / f"{real_frame_idx:06d}_sparse_depth.jpg"), sd_color)

                        # 预测深度伪彩
                        pd_max = max(1e-6, float(np.percentile(pred[pred > 0], 99)) if np.any(pred > 0) else 1.0)
                        pd_norm = np.clip((pred / pd_max) * 255.0, 0, 255).astype(np.uint8)
                        pd_color = cv2.applyColorMap(pd_norm, cv2.COLORMAP_JET)
                        cv2.imwrite(str(vis_dir / f"{real_frame_idx:06d}_pred_depth.jpg"), pd_color)

                        # 叠加显示（伪彩深度覆盖到原图上，带透明度）
                        overlay = (0.6 * pd_color + 0.4 * rgb_img[:, :, ::-1]).astype(np.uint8)
                        cv2.imwrite(str(vis_dir / f"{real_frame_idx:06d}_overlay.jpg"), overlay)
                
                if frame_depths:
                    if REMOVE_OVERLAP:
                        # 0. 定义相机优先级 (0 = 最高)
                        camera_priority = {
                            'front': 0,
                            'front_left': 1,
                            'front_right': 1,
                            'side_left': 2,
                            'side_right': 2
                        }
                        
                        # 1. 预计算所有相机的参数和相机空间点云（用于重叠区域裁切）
                        cam_params_for_masking = {}
                        cam_grids_cam_space = {}
                        
                        # 计算世界到ego的变换（后续点云生成时会用到）
                        world_to_vehicle = np.linalg.inv(vehicle_to_world)

                        for cam_name in cam_order:
                            if cam_name not in frame_depths:
                                continue
                            
                            it = per_cam_results[cam_name]
                            K = it['K']
                            dist = it.get('dist', np.zeros((5,), dtype=np.float32))
                            H, W = it['H'], it['W']
                            depth = frame_depths[cam_name]  # H, W
                            x_pix, y_pix = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32))
                            
                            # 通过 undistortPoints 得到去畸变的归一化坐标
                            pts = np.stack([x_pix.reshape(-1), y_pix.reshape(-1)], axis=1).reshape(-1, 1, 2)
                            undist_norm = cv2.undistortPoints(pts, K, dist, P=None)
                            x_cam_norm = undist_norm[:, 0, 0].reshape(H, W)
                            y_cam_norm = undist_norm[:, 0, 1].reshape(H, W)
                                
                            # (H, W, 3) in camera space
                            points_cam = np.stack([
                                np.multiply(x_cam_norm, depth),
                                np.multiply(y_cam_norm, depth),
                                depth
                            ], axis=-1)
                            
                            cam_grids_cam_space[cam_name] = points_cam
                            
                            # 计算FOV范围
                            fx, fy = K[0, 0], K[1, 1]
                            hfov_half = np.arctan(W / 2.0 / fx)
                            vfov_half = np.arctan(H / 2.0 / fy)
                            # 预计算tan值，避免在循环中重复计算
                            tan_hfov_half = np.tan(hfov_half)
                            tan_vfov_half = np.tan(vfov_half)
                            
                            cam_params_for_masking[cam_name] = {
                                'c2w': frame_camera_to_world[cam_name],
                                'w2c': np.linalg.inv(frame_camera_to_world[cam_name]),
                                'K': K,
                                'dist': dist,
                                'H': H,
                                'W': W,
                                'tan_hfov_half': tan_hfov_half,
                                'tan_vfov_half': tan_vfov_half,
                            }

                        # 2. 执行交叉检查和掩码
                        # 浅拷贝字典，深度复制时按需进行
                        masked_frame_depths = {k: v.copy() for k, v in frame_depths.items()}
                        
                        # 预计算相机优先级映射和高优先级相机列表
                        cam_priorities = {name: camera_priority.get(name, 99) for name in cam_order}
                        
                        # 遍历我们定义好的相机列表
                        for source_cam_name in cam_order:
                            if source_cam_name not in cam_params_for_masking:
                                continue
                                
                            source_priority = cam_priorities[source_cam_name]
                            
                            # 如果是最高优先级 (front)，它不需要被裁切，跳过
                            if source_priority == 0:
                                continue
                                
                            source_params = cam_params_for_masking[source_cam_name]
                            source_points_cam = cam_grids_cam_space[source_cam_name]  # (H, W, 3)
                            H, W = source_params['H'], source_params['W']
                            source_depth_flat = source_points_cam[..., 2].reshape(H * W)  # (N,)
                            
                            # 转换为世界坐标系（优化矩阵乘法：避免转置）
                            source_points_cam_flat = source_points_cam.reshape(-1, 3)
                            c2w = source_params['c2w']
                            # 直接计算 (c2w @ points.T).T 更高效
                            source_points_world = (c2w[:3, :3] @ source_points_cam_flat.T + c2w[:3, 3:4]).T  # (N, 3)
                            source_points_world_homo = np.column_stack([source_points_world, np.ones((source_points_world.shape[0], 1))])  # (N, 4)
                            
                            # 排除无效点 (原始深度为0)
                            valid_source_mask = source_depth_flat > 1e-3
                            if not np.any(valid_source_mask):
                                continue

                            # 初始化一个掩码，标记源相机中需要丢弃的像素
                            discard_mask_flat = np.zeros_like(valid_source_mask, dtype=bool)
                            
                            # 预计算高优先级相机列表（只检查可能裁切源相机的相机）
                            higher_priority_cams = [name for name in cam_order 
                                                  if name != source_cam_name 
                                                  and name in cam_params_for_masking 
                                                  and cam_priorities[name] < source_priority]

                            # 遍历高优先级相机，检查它们是否可以裁切源相机
                            for other_cam_name in higher_priority_cams:
                                other_params = cam_params_for_masking[other_cam_name]

                                # 将源点云转换到其他(高优先级)相机坐标系（优化矩阵乘法）
                                w2c = other_params['w2c']
                                other_points_cam = (w2c[:3, :3] @ source_points_world.T + w2c[:3, 3:4]).T  # (N, 3)
                                other_points_cam_z = other_points_cam[:, 2]  # (N,)

                                # 找出在其他相机前方的点
                                valid_z_mask = (other_points_cam_z > 1e-3)

                                # 只处理那些在源相机中有效且在目标相机前方的点
                                check_mask = valid_source_mask & valid_z_mask
                                if not np.any(check_mask):
                                    continue

                                # 获取在相机前方的点（相机坐标系）
                                points_cam_valid = other_points_cam[check_mask]  # (M, 3)
                                
                                # 归一化到相机归一化平面（x/z, y/z）
                                z_valid = points_cam_valid[:, 2]  # (M,)
                                x_norm = points_cam_valid[:, 0] / z_valid  # (M,)
                                y_norm = points_cam_valid[:, 1] / z_valid  # (M,)
                                
                                # 检查FOV范围（使用预计算的tan值）
                                tan_hfov_half = other_params['tan_hfov_half']
                                tan_vfov_half = other_params['tan_vfov_half']
                                
                                # 判断是否在FOV范围内
                                in_fov_mask = (
                                    (np.abs(x_norm) <= tan_hfov_half) & 
                                    (np.abs(y_norm) <= tan_vfov_half)
                                )
                                
                                if not np.any(in_fov_mask):
                                    continue
                                
                                # 找出全局索引（在N个点中）
                                global_check_indices = np.where(check_mask)[0]
                                global_indices_in_fov = global_check_indices[in_fov_mask]

                                # 标记这些点为丢弃（只有在FOV范围内才裁切）
                                discard_mask_flat[global_indices_in_fov] = True

                            # 3. 应用掩码
                            if np.any(discard_mask_flat):
                                masked_frame_depths[source_cam_name][discard_mask_flat.reshape(H, W)] = 0.0  # 挖洞
                    else:
                        # 不删除重叠区域：直接使用原始深度图
                        masked_frame_depths = frame_depths
                        # 仍需计算世界到ego的变换以生成点云
                        world_to_vehicle = np.linalg.inv(vehicle_to_world)

                    # 4. 使用处理后的深度图生成点云
                    all_points_ego = []
                    all_colors = []
                    all_cam_ids = []
                    
                    for cam_name in masked_frame_depths.keys():
                        depth = masked_frame_depths[cam_name]
                        
                        color = frame_colors[cam_name]
                        K = frame_intrinsics[cam_name]
                        camera_to_world = frame_camera_to_world[cam_name]
                        cur_dist = frame_dists[cam_name]
                        
                        # create_point_cloud_from_depth 会自动过滤 Z <= 0 的点
                        points_cam, colors = create_point_cloud_from_depth(color, depth, K, cur_dist)
                        
                        # 如果这个相机的点云被完全 "挖空" 了，则跳过
                        if points_cam.shape[0] == 0:
                            continue
                            
                        # 将点云从相机坐标系变换到世界坐标系，再变换到ego坐标系
                        # 相机坐标系 -> 世界坐标系（优化矩阵乘法）
                        points_world = (camera_to_world[:3, :3] @ points_cam.T + camera_to_world[:3, 3:4]).T  # (N, 3)
                        
                        # 世界坐标系 -> ego坐标系（优化矩阵乘法）
                        points_ego = (world_to_vehicle[:3, :3] @ points_world.T + world_to_vehicle[:3, 3:4]).T  # (N, 3)
                        
                        all_points_ego.append(points_ego)
                        all_colors.append(colors)
                        cam_idx_this = int(camera_names_target.index(cam_name))
                        all_cam_ids.append(np.full((points_ego.shape[0],), cam_idx_this, dtype=np.uint8))

                    if not all_points_ego:
                        real_frame_idx += 1
                        continue  # 如果所有点都被过滤了，跳到下一帧

                    all_points = np.concatenate(all_points_ego, axis=0)
                    all_colors_merged = np.concatenate(all_colors, axis=0)
                    all_cam_ids_merged = np.concatenate(all_cam_ids, axis=0)

                    # 保存点云（使用高效存储格式，Ray 异步任务）
                    ply_dir = output_root / clip_id
                    ply_dir.mkdir(parents=True, exist_ok=True)
                    if POINT_CLOUD_STORAGE_FORMAT in ["npz", "npz_fp16", "npz_bf16"]:
                        file_ext = ".npz"
                    elif POINT_CLOUD_STORAGE_FORMAT == "bfloat16":
                        file_ext = ".pkl.gz"
                    else:
                        file_ext = ".ply"

                    _ = save_point_cloud_task.remote(
                        all_points.astype(np.float32),
                        all_colors_merged.astype(np.float32),
                        str(ply_dir / f"{real_frame_idx:06d}_merged{file_ext}"),
                        POINT_CLOUD_STORAGE_FORMAT,
                        all_cam_ids_merged.astype(np.uint8),
                    )
                
                # 增加实际处理的帧索引
                real_frame_idx += 1
            processed += 1
        return processed


@hydra.main(config_path='configs', config_name='config', version_base='1.2')
def main(cfg):
    """Run depth-completion inference over every TFRecord in INPUT_TFRECORD_DIR.

    The sorted list of ``*.tfrecord`` files is split round-robin across at
    most ``max_runners`` ``ClipRunner`` Ray actors. The shards are processed
    in parallel, and this function blocks until every runner returns.

    Args:
        cfg: Hydra config composed from ``configs/config``; forwarded
            unchanged to every ``ClipRunner``.

    Raises:
        ValueError: if INPUT_TFRECORD_DIR contains no ``*.tfrecord`` file.
    """
    # Collect the input files first (sorted for a deterministic shard
    # assignment). If there are none, fail before spinning up Ray.
    tfrecord_files = sorted(Path(INPUT_TFRECORD_DIR).glob("*.tfrecord"))
    if not tfrecord_files:
        raise ValueError(f'在目录 {INPUT_TFRECORD_DIR} 中未找到 TFRecord 文件')

    # Initialize Ray (GPU/CPU parallelism); tolerate an already-running session.
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True, include_dashboard=False)

    print('=' * 80)
    print('Waymo TFRecord 深度补全推理')
    print(f'- 输入目录: {INPUT_TFRECORD_DIR}')
    print(f'- 找到 {len(tfrecord_files)} 个 TFRecord 文件')
    print(f'- 下采样倍率: {DOWNSAMPLE}')
    print(f'- 输出目录: {OUTPUT_ROOT}')
    print('=' * 80)

    # Split the files round-robin into at most `max_runners` shards
    # (shard k gets files k, k + n, k + 2n, ...). Never start more runners
    # than there are files: a runner with an empty shard would only hold
    # its resources idle.
    max_runners = 4
    num_workers = min(max_runners, len(tfrecord_files))
    shards = [
        [str(p) for p in tfrecord_files[k::num_workers]]
        for k in range(num_workers)
    ]

    # One ClipRunner actor per shard.
    runners = [ClipRunner.remote(cfg) for _ in range(num_workers)]
    print(f'   ✓ 启动 {num_workers} 个 GPU Runner')

    # Run all shards in parallel and wait for every runner to finish.
    result_refs = [
        runner.run_subset.remote(shard)
        for runner, shard in zip(runners, shards)
    ]
    # run_subset returns its `processed` counter. Presumably this is the
    # number of TFRecord files handled by that runner; confirm in run_subset.
    processed_counts = ray.get(result_refs)
    print(f'   ✓ 全部完成，处理计数合计: {sum(processed_counts)}')
    # NOTE(review): run_subset submits save_point_cloud_task.remote(...)
    # without keeping or awaiting the returned refs. Some of those saves may
    # still be pending when ray.get above returns, and exiting the driver
    # tears down Ray. Confirm that every point cloud is actually written.


if __name__ == '__main__':
    # main() is called without arguments: the @hydra.main decorator composes
    # `cfg` (config_path='configs', config_name='config') and passes it in.
    main()


