from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Any

import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
import cv2

from .camera import CameraRigSpec, CameraTransform
from .projection import ProjectionConfig, PointCloudProjector
from .utils import ImageSaver
from .gaussian_model import Gaussian

if TYPE_CHECKING:
    from .object_info import ObjectInfoProcessor
    from .clip_config import ClipConfig

# Z-axis 180° rotation correction, kept consistent with the reference rendering logic.
ROTATION_FIX_Z_180 = R.from_euler('z', 180, degrees=True).as_matrix()
# 4x4 float32 homogeneous form of the same correction: identity with the
# rotation written into the upper-left 3x3 block.
TRANSFORM_FIX_Z_180 = np.eye(4, dtype=np.float32)
TRANSFORM_FIX_Z_180[:3, :3] = ROTATION_FIX_Z_180


class PointCloudProcessor:
    def __init__(self, config: ProjectionConfig, projector: Optional[PointCloudProjector] = None):
        """Wire up the projection helpers and the 3DGS rendering state.

        Args:
            config: Projection configuration; stored as ``self.config`` and used
                to build the default projector.
            projector: Projector to reuse. When omitted (or falsy) a new
                ``PointCloudProjector(config)`` is created.
        """
        self.config = config
        self.projector = projector if projector else PointCloudProjector(config)
        self.transform = CameraTransform()
        self.saver = ImageSaver()
        # 3DGS rendering: run on CUDA when available, otherwise on the CPU,
        # and keep the renderer switched on by default.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device_name)
        self._gs_renderer = True
        
    def _apply_hole_carving(
            self,
            points_world: np.ndarray,
            colors: np.ndarray,
            camera_ids: np.ndarray,
            K: np.ndarray,
            camera_to_world: np.ndarray,
            H: int,
            W: int,
            object_processor: Optional['ObjectInfoProcessor'],
            frame_idx_int: Optional[int],
            object_ids_to_remove: Optional[List[str]],
            scale_factor: float = 1.1,
            r_px: float = 3.0,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        根据给定相机姿态与目标 OBB，在投影平面上计算阴影并删除相应点（挖洞）。
        返回过滤后的 points/colors/camera_ids。
        """
        do_removal = (
            object_processor is not None and
            object_ids_to_remove is not None and len(object_ids_to_remove) > 0 and
            frame_idx_int is not None
        )
        if not do_removal or points_world.shape[0] == 0:
            return points_world, colors, camera_ids

        w2c = np.linalg.inv(camera_to_world).astype(np.float32)
        num_points = points_world.shape[0]
        points_world_h = np.hstack([points_world, np.ones((num_points, 1), dtype=np.float32)])
        points_cam = (w2c @ points_world_h.T).T[:, :3]
        points_depths = points_cam[:, 2]
        valid_depth_mask = points_depths > 1e-6
        points_2d_all = np.zeros((num_points, 2), dtype=np.float32)
        if np.any(valid_depth_mask):
            points_cam_valid = points_cam[valid_depth_mask]
            points_depths_valid = points_depths[valid_depth_mask]
            points_normalized = points_cam_valid / points_depths_valid[:, np.newaxis]
            points_proj_h = (K @ points_normalized.T).T
            points_2d_all[valid_depth_mask] = points_proj_h[:, :2]
        points_2d_int = np.floor(points_2d_all).astype(np.int32)
        final_remove_mask_for_cam = np.zeros(num_points, dtype=bool)

        frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
        all_frame_objects = object_processor.get_object_info().get(frame_key, {})  # type: ignore
        if not (isinstance(all_frame_objects, dict) and len(all_frame_objects) > 0):
            return points_world, colors, camera_ids

        for obj_id in object_ids_to_remove:
            obj_info = all_frame_objects.get(str(obj_id))
            if obj_info is None:
                continue

            object_to_world = np.array(obj_info['object_to_world'], dtype=np.float32)
            lwh = np.array(obj_info['object_lwh'], dtype=np.float32)
            vertices_local_h = self._get_obb_vertices(lwh, scale_factor)
            vertices_world_h = (object_to_world @ vertices_local_h.T).T
            vertices_cam_h = (w2c @ vertices_world_h.T).T
            vertices_cam = vertices_cam_h[:, :3]
            vertices_depths = vertices_cam[:, 2]
            front_mask = vertices_depths > 1e-6
            if not np.any(front_mask):
                continue

            min_depth = np.min(vertices_depths[front_mask])
            vertices_cam_front = vertices_cam[front_mask]
            vertices_depths_front = vertices_depths[front_mask]

            vert_normalized = vertices_cam_front / vertices_depths_front[:, np.newaxis]
            vert_proj_h = (K @ vert_normalized.T).T
            polygon_vertices_2d = vert_proj_h[:, :2]
            if polygon_vertices_2d.shape[0] < 3:
                continue

            shadow_polygon = cv2.convexHull(polygon_vertices_2d, returnPoints=True)
            if shadow_polygon is None or shadow_polygon.shape[0] < 3:
                continue

            remove_mask_2d = np.zeros((H, W), dtype=np.uint8)
            cv2.drawContours(remove_mask_2d, [np.round(shadow_polygon).astype(np.int32)], -1, 255, cv2.FILLED)

            mask_depth_test = (valid_depth_mask) & (points_depths > min_depth - 0.2)
            if not np.any(mask_depth_test):
                continue

            mask_bounds_x_f = (points_2d_all[:, 0] >= (-0.5 - r_px)) & (points_2d_all[:, 0] < ((W - 0.5) + r_px))
            mask_bounds_y_f = (points_2d_all[:, 1] >= (-0.5 - r_px)) & (points_2d_all[:, 1] < ((H - 0.5) + r_px))
            mask_in_bounds_f = mask_bounds_x_f & mask_bounds_y_f
            mask_to_check = mask_depth_test & mask_in_bounds_f
            if not np.any(mask_to_check):
                continue

            poly_min_x = float(np.min(polygon_vertices_2d[:, 0]))
            poly_max_x = float(np.max(polygon_vertices_2d[:, 0]))
            poly_min_y = float(np.min(polygon_vertices_2d[:, 1]))
            poly_max_y = float(np.max(polygon_vertices_2d[:, 1]))

            in_rect = (
                (points_2d_all[:, 0] >= (poly_min_x - r_px)) & (points_2d_all[:, 0] <= (poly_max_x + r_px)) &
                (points_2d_all[:, 1] >= (poly_min_y - r_px)) & (points_2d_all[:, 1] <= (poly_max_y + r_px))
            )

            mask_rect = mask_to_check & in_rect
            if not np.any(mask_rect):
                continue

            indices_to_check = np.where(mask_rect)[0]
            coords_to_check = points_2d_int[indices_to_check]
            x_idx = np.clip(coords_to_check[:, 0], 0, W - 1)
            y_idx = np.clip(coords_to_check[:, 1], 0, H - 1)
            mask_values = remove_mask_2d[y_idx, x_idx]
            inside_by_mask = (mask_values == 255)

            not_inside = ~inside_by_mask
            if np.any(not_inside):
                coords_f = points_2d_all[indices_to_check][not_inside]
                hull2d = shadow_polygon.reshape(-1, 2)
                dist_ok = np.fromiter(
                    (cv2.pointPolygonTest(hull2d, (float(x), float(y)), True) >= -r_px for x, y in coords_f),
                    dtype=bool,
                    count=coords_f.shape[0]
                )
                inside_by_mask[not_inside] = dist_ok

            remove_indices = indices_to_check[inside_by_mask]
            final_remove_mask_for_cam[remove_indices] = True

        keep_mask = ~final_remove_mask_for_cam
        return points_world[keep_mask], colors[keep_mask], camera_ids[keep_mask]

    @staticmethod
    def _get_obb_vertices(
            lwh: np.ndarray,
            scale_factor: float,
    ) -> np.ndarray:
        """
        返回OBB的8个顶点 (齐次坐标) 在其局部坐标系中的位置。
        """
        # l, w, h -> x, y, z
        half_dims_lwh = (lwh * float(scale_factor)) / 2.0

        l_half = half_dims_lwh[0]
        w_half = half_dims_lwh[1]
        z_half = half_dims_lwh[2]

        vertices = np.array([
            [l_half, w_half, -z_half],
            [l_half, -w_half, -z_half],
            [-l_half, -w_half, -z_half],
            [-l_half, w_half, -z_half],
            [l_half, w_half, z_half],
            [l_half, -w_half, z_half],
            [-l_half, -w_half, z_half],
            [-l_half, w_half, z_half],
        ], dtype=np.float32)

        # (8, 4) 齐次坐标
        return np.hstack([vertices, np.ones((8, 1), dtype=np.float32)])

    def _compute_object_2d_bboxes(
            self,
            rig: 'CameraRigSpec',
            camera_to_world: Dict[str, np.ndarray],
            object_processor: Optional['ObjectInfoProcessor'],
            frame_idx_int: Optional[int],
            object_ids_to_remove: Optional[List[str]],
            objects_to_add_info: Optional[List[Tuple[str, Path]]],
            scale_factor: float = 1.2,
    ) -> Dict[str, np.ndarray]:
        """
        计算通过3DGS添加的新对象在2D图像上的投影矩形区域（扩大scale_factor倍）。
        
        Args:
            rig: 相机配置
            camera_to_world: 每个相机的camera_to_world变换
            object_processor: 对象信息处理器
            frame_idx_int: 帧索引
            object_ids_to_remove: 要删除的对象ID列表（未使用，保留用于兼容性）
            objects_to_add_info: 要添加的对象ID和路径列表（3DGS添加的新对象）
            scale_factor: 扩大倍数（默认1.2）
            
        Returns:
            Dict[str, np.ndarray]: 每个相机对应的area图像，形状为(H, W, 3)，uint8
        """
        import cv2
        
        if object_processor is None or frame_idx_int is None:
            return {}
        
        # 只使用3DGS添加的新对象（被删除和添加的对象是完全一样的）
        target_object_ids = set()
        if objects_to_add_info is not None:
            for obj_id, _ in objects_to_add_info:
                target_object_ids.add(str(obj_id))
        
        if len(target_object_ids) == 0:
            return {}
        
        frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
        all_frame_objects = object_processor.get_object_info().get(frame_key, {})
        if not isinstance(all_frame_objects, dict) or len(all_frame_objects) == 0:
            return {}
        
        result: Dict[str, np.ndarray] = {}
        
        for cam_name, c2w in camera_to_world.items():
            # 获取相机参数
            K_orig = rig.get_K(cam_name)
            H_orig, W_orig = rig.get_size(cam_name)
            K = self.config.scale_intrinsics(K_orig, (H_orig, W_orig))
            H, W = self.config.get_target_resolution((H_orig, W_orig))
            
            # 创建空白图像
            area_image = np.zeros((H, W, 3), dtype=np.uint8)
            
            # 世界到相机变换
            w2c = np.linalg.inv(c2w).astype(np.float32)
            
            # 只遍历被删除和添加的对象
            for obj_id in target_object_ids:
                obj_info = all_frame_objects.get(str(obj_id))
                if obj_info is None:
                    continue
                object_to_world = np.array(obj_info['object_to_world'], dtype=np.float32)
                lwh = np.array(obj_info['object_lwh'], dtype=np.float32)
                
                # 获取OBB的8个顶点（使用scale_factor=1.0，因为我们后面会单独扩大）
                vertices_local_h = self._get_obb_vertices(lwh, scale_factor=1.0)
                vertices_world_h = (object_to_world @ vertices_local_h.T).T
                vertices_cam_h = (w2c @ vertices_world_h.T).T
                vertices_cam = vertices_cam_h[:, :3]
                vertices_depths = vertices_cam[:, 2]
                
                # 只考虑相机前方的点
                front_mask = vertices_depths > 1e-6
                if not np.any(front_mask):
                    continue
                
                vertices_cam_front = vertices_cam[front_mask]
                vertices_depths_front = vertices_depths[front_mask]
                
                # 投影到2D
                vert_normalized = vertices_cam_front / vertices_depths_front[:, np.newaxis]
                vert_proj_h = (K @ vert_normalized.T).T
                polygon_vertices_2d = vert_proj_h[:, :2]
                
                if polygon_vertices_2d.shape[0] < 3:
                    continue
                
                polygon_vertices_2d = np.clip(
                    polygon_vertices_2d, 
                    np.iinfo(np.int32).min, 
                    np.iinfo(np.int32).max
                )
                
                u_round = np.round(polygon_vertices_2d[:, 0]).astype(np.int32)
                v_round = np.round(polygon_vertices_2d[:, 1]).astype(np.int32)
                valid_uv_mask = (u_round >= 0) & (u_round < W) & (v_round >= 0) & (v_round < H)
                
                if (~valid_uv_mask).all():
                    continue
                
                try:
                    hull = cv2.convexHull(polygon_vertices_2d.astype(np.float32), returnPoints=True)
                    if hull is None or hull.shape[0] < 3:
                        continue
                    hull_points = hull.reshape(-1, 2)
                except Exception:
                    continue
                
                u_min = float(np.min(hull_points[:, 0]))
                u_max = float(np.max(hull_points[:, 0]))
                v_min = float(np.min(hull_points[:, 1]))
                v_max = float(np.max(hull_points[:, 1]))
                
                overlap_u_min = max(u_min, 0.0)
                overlap_u_max = min(u_max, float(W))
                overlap_v_min = max(v_min, 0.0)
                overlap_v_max = min(v_max, float(H))
                
                if overlap_u_max <= overlap_u_min or overlap_v_max <= overlap_v_min:
                    continue
                
                center_u = (overlap_u_min + overlap_u_max) / 2.0
                center_v = (overlap_v_min + overlap_v_max) / 2.0
                width = overlap_u_max - overlap_u_min
                height = overlap_v_max - overlap_v_min
                
                u_min_scaled = center_u - width * scale_factor / 2.0
                u_max_scaled = center_u + width * scale_factor / 2.0
                v_min_scaled = center_v - height * scale_factor / 2.0
                v_max_scaled = center_v + height * scale_factor / 2.0
                
                u_min_clipped = max(0, int(np.floor(u_min_scaled)))
                u_max_clipped = min(W - 1, int(np.ceil(u_max_scaled)))
                v_min_clipped = max(0, int(np.floor(v_min_scaled)))
                v_max_clipped = min(H - 1, int(np.ceil(v_max_scaled)))
                
                if (u_min_clipped >= 0 and u_max_clipped < W and 
                    v_min_clipped >= 0 and v_max_clipped < H and
                    u_min_clipped < u_max_clipped and 
                    v_min_clipped < v_max_clipped):
                    area_image[v_min_clipped:v_max_clipped+1, u_min_clipped:u_max_clipped+1] = 255
            
            result[cam_name] = area_image
        
        return result
    
    @staticmethod
    def _compute_density_from_depth(
            depth_image: np.ndarray,
            sigma_px: float = 2.0,
            occupancy_thresh: float = 0.05,
            sigmoid_alpha: float = 10.0,
    ) -> np.ndarray:
        """
        根据深度图快速近似计算投影点密度图（Projected Point Density）。
        
        近似思路：
        1. 先将深度图转为占用图（有深度记为 1），可视为 2D 平面上的投影点脉冲。
        2. 对占用图做高斯卷积，相当于对所有投影点施加高斯核的 KDE。
        3. 对结果做归一化和 Sigmoid 非线性映射，增强对比度，并抑制过低密度区域。
        """
        if depth_image is None or depth_image.size == 0:
            return np.zeros_like(depth_image, dtype=np.uint8)
        
        # 占用图：仅关心是否有点命中
        occ = (depth_image > 0).astype(np.float32)
        if occ.sum() <= 0:
            return np.zeros_like(depth_image, dtype=np.uint8)
        
        # 高斯模糊近似 2D KDE，ksize=(0,0) 让 OpenCV 根据 sigma 自动选择核大小
        density = cv2.GaussianBlur(
            occ,
            ksize=(0, 0),
            sigmaX=sigma_px,
            sigmaY=sigma_px,
            borderType=cv2.BORDER_REPLICATE,
        )
        
        max_val = float(density.max())
        if max_val <= 1e-6:
            return np.zeros_like(depth_image, dtype=np.uint8)
        
        # 归一化到 [0, 1]
        density_norm = density / max_val
        
        # 对极稀疏区域做截断
        low_mask = density_norm < occupancy_thresh
        density_norm[low_mask] = 0.0
        
        # Sigmoid 非线性增强对比度
        shifted = density_norm - occupancy_thresh
        density_sigmoid = 1.0 / (1.0 + np.exp(-sigmoid_alpha * shifted))
        density_sigmoid[low_mask] = 0.0
        
        density_u8 = (np.clip(density_sigmoid, 0.0, 1.0) * 255).astype(np.uint8)
        return density_u8

    def _render_layered(
            self,
            points_world: np.ndarray,
            colors: np.ndarray,
            point_labels: np.ndarray,
            target_cam_name: str,
            K: np.ndarray,
            camera_to_world: np.ndarray,
            H: int,
            W: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        if points_world.shape[0] == 0:
            return (
                np.zeros((H, W, 3), dtype=np.uint8),
                np.zeros((H, W), dtype=np.float32),
                np.zeros((H, W), dtype=np.uint8)
            )

        name_to_idx = CameraRigSpec.default_name_to_idx()
        primary_idx = int(name_to_idx.get(target_cam_name))
        cam_ids_np = np.asarray(point_labels)
        primary_mask = (cam_ids_np == primary_idx)

        adjacent_cameras = CameraRigSpec.get_adjacent_cameras().get(target_cam_name, [])
        adjacent_indices = [name_to_idx.get(adj_cam) for adj_cam in adjacent_cameras]
        secondary_mask = np.isin(cam_ids_np, adjacent_indices)

        pts_primary = points_world[primary_mask]
        cols_primary = colors[primary_mask]
        pts_secondary = points_world[secondary_mask]
        cols_secondary = colors[secondary_mask]

        rgb_primary, depth_primary = self.projector.project_to_image(
            pts_primary, cols_primary, K, camera_to_world, H, W
        )

        if pts_secondary.shape[0] > 0:
            rgb_secondary, depth_secondary = self.projector.project_to_image(
                pts_secondary, cols_secondary, K, camera_to_world, H, W
            )
        else:
            rgb_secondary = np.zeros_like(rgb_primary)
            depth_secondary = np.zeros_like(depth_primary)

        final_rgb = rgb_primary.copy()
        final_depth = depth_primary.copy()
        primary_valid = depth_primary > 0
        secondary_valid = depth_secondary > 0
        fill_mask = (~primary_valid) & secondary_valid
        if np.any(fill_mask):
            final_rgb[fill_mask] = rgb_secondary[fill_mask]
            final_depth[fill_mask] = depth_secondary[fill_mask]

        camera_source_mask = np.zeros((H, W), dtype=np.uint8)
        camera_source_mask[primary_valid] = 255  # 来自主要相机的区域
        camera_source_mask[fill_mask] = 160  # 来自其他相机的区域

        return final_rgb, final_depth, camera_source_mask

    @dataclass
    class NovelViewData:
        """Rendered result for one camera view, passed between the rendering passes."""
        # Name of the camera this view was rendered for.
        cam_name: str
        # Rendered color image of the view.
        rgb_image: np.ndarray
        # Rendered depth image; values > 0 mark pixels that received a point.
        depth_image: np.ndarray
        # Camera intrinsics used for the rendering.
        K: np.ndarray
        # Camera pose (camera -> world) used for the rendering.
        camera_to_world: np.ndarray
        # Optional per-pixel source label written by the render passes:
        # 255 = target camera's points, 160 = adjacent cameras' points,
        # 80 = 3DGS foreground object, 0 = nothing rendered.
        camera_source_mask: Optional[np.ndarray] = None

    def _render_novel_views(
            self,
            points_world: np.ndarray,
            colors: np.ndarray,
            camera_ids: np.ndarray,
            rig: 'CameraRigSpec',
            override_cam_c2w: Dict[str, np.ndarray],
            target_cams: List[str],
            object_processor: Optional['ObjectInfoProcessor'],
            frame_idx_int: Optional[int],
            object_ids_to_remove: Optional[List[str]],
            scale_factor: float = 1.1,
            r_px: float = 3.0,  # 渲染器使用的最大像素半径
    ) -> List['PointCloudProcessor.NovelViewData']:
        results: List[PointCloudProcessor.NovelViewData] = []

        points_world_h = np.hstack([points_world, np.ones((points_world.shape[0], 1), dtype=np.float32)])
        num_points = points_world.shape[0]

        for cam_name in target_cams:
            K_orig = rig.get_K(cam_name)
            H_orig, W_orig = rig.get_size(cam_name)
            K = self.config.scale_intrinsics(K_orig, (H_orig, W_orig))
            H, W = self.config.get_target_resolution((H_orig, W_orig))
            new_c2w = override_cam_c2w.get(cam_name)
            new_w2c = np.linalg.inv(new_c2w).astype(np.float32)
            keep_mask = np.ones(num_points, dtype=bool)

            do_removal = (
                    object_processor is not None and
                    object_ids_to_remove is not None and len(object_ids_to_remove) > 0 and
                    frame_idx_int is not None
            )

            if do_removal:
                frame_key = f"{frame_idx_int * 3:06d}.all_object_info.json"
                all_frame_objects = object_processor.get_object_info().get(frame_key, {})
                if not isinstance(all_frame_objects, dict) or len(all_frame_objects) == 0:
                    do_removal = False

            if do_removal:
                pts_world_filtered, cols_filtered, cam_ids_filtered = self._apply_hole_carving(
                    points_world=points_world,
                    colors=colors,
                    camera_ids=np.asarray(camera_ids),
                    K=K,
                    camera_to_world=new_c2w,
                    H=H,
                    W=W,
                    object_processor=object_processor,
                    frame_idx_int=frame_idx_int,
                    object_ids_to_remove=object_ids_to_remove,
                    scale_factor=scale_factor,
                    r_px=r_px,
                )
            else:
                pts_world_filtered = points_world
                cols_filtered = colors
                cam_ids_filtered = camera_ids

            final_rgb, final_depth, camera_source_mask = self._render_layered(
                points_world=pts_world_filtered,
                colors=cols_filtered,
                point_labels=cam_ids_filtered,
                target_cam_name=cam_name,
                K=K,
                camera_to_world=new_c2w,
                H=H,
                W=W,
            )
            results.append(
                PointCloudProcessor.NovelViewData(
                    cam_name=cam_name,
                    rgb_image=final_rgb,
                    depth_image=final_depth,
                    K=K,
                    camera_to_world=new_c2w,
                    camera_source_mask=camera_source_mask,
                )
            )
        return results

    def _render_back_projections(
            self,
            novel_view_data: List['PointCloudProcessor.NovelViewData'],
            rig: 'CameraRigSpec',
            base_cam_c2w: Dict[str, np.ndarray],
            target_cams: List[str],
            object_processor: Optional['ObjectInfoProcessor'] = None,
            frame_idx_int: Optional[int] = None,
            object_ids_to_remove: Optional[List[str]] = None,
            scale_factor: float = 1.1,
            r_px: float = 3.0,
    ) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
        """
        Lift the novel views back to 3D and re-render them from the base camera poses.

        Every novel view is turned into a point cloud (``depth_to_pointcloud``),
        moved to world space and labelled with its camera index. The merged
        cloud is then rendered with ``_render_layered`` from each target
        camera's base pose, optionally after hole carving.

        Args:
            novel_view_data: Views to back-project.
            rig: Camera rig; queried for intrinsics and image size.
            base_cam_c2w: Mapping camera name -> base camera_to_world pose;
                cameras without an entry are skipped.
            target_cams: Names of the cameras to render.
            object_processor: Provider of per-frame object annotations, or ``None``.
            frame_idx_int: Frame index, or ``None``.
            object_ids_to_remove: IDs of the objects to carve out, or ``None``.
            scale_factor: OBB enlargement forwarded to the hole carving.
            r_px: Pixel tolerance forwarded to the hole carving.

        Returns:
            Mapping camera name -> ``(rgb, depth, valid_mask, camera_source_mask)``
            where ``valid_mask`` is ``depth > 0``.
        """
        cam_index_of = CameraRigSpec.default_name_to_idx()

        # Lift each view to a labelled world-space point cloud.
        world_pts: List[np.ndarray] = []
        world_cols: List[np.ndarray] = []
        world_labels: List[np.ndarray] = []
        for view in novel_view_data:
            pts_cam, cols_cam = self.projector.depth_to_pointcloud(
                view.rgb_image, view.depth_image, view.K
            )
            pts_world = self.transform.points_to_world(pts_cam, view.camera_to_world)
            world_pts.append(pts_world)
            world_cols.append(cols_cam)
            cam_label = int(cam_index_of.get(view.cam_name))
            world_labels.append(np.full((pts_world.shape[0],), cam_label, dtype=np.int32))

        if world_pts:
            merged_points = np.concatenate(world_pts, axis=0)
        else:
            merged_points = np.zeros((0, 3), dtype=np.float32)
        if world_cols:
            merged_colors = np.concatenate(world_cols, axis=0)
        else:
            merged_colors = np.zeros((0, 3), dtype=np.float32)
        if world_labels:
            merged_labels = np.concatenate(world_labels, axis=0)
        else:
            merged_labels = np.zeros((0,), dtype=np.int32)

        # Hole carving is requested only when all three inputs are available.
        carve_holes = (
            object_processor is not None and
            object_ids_to_remove is not None and len(object_ids_to_remove) > 0 and
            frame_idx_int is not None
        )

        rendered: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]] = {}
        for cam_name in target_cams:
            K_orig = rig.get_K(cam_name)
            H_orig, W_orig = rig.get_size(cam_name)
            K = self.config.scale_intrinsics(K_orig, (H_orig, W_orig))
            H, W = self.config.get_target_resolution((H_orig, W_orig))

            base_pose = base_cam_c2w.get(cam_name)
            if base_pose is None:
                continue

            if carve_holes:
                # Carve with the shared helper, from this camera's base pose.
                pts_in, cols_in, labels_in = self._apply_hole_carving(
                    points_world=merged_points,
                    colors=merged_colors,
                    camera_ids=merged_labels,
                    K=K,
                    camera_to_world=base_pose,
                    H=H,
                    W=W,
                    object_processor=object_processor,
                    frame_idx_int=frame_idx_int,
                    object_ids_to_remove=object_ids_to_remove,
                    scale_factor=scale_factor,
                    r_px=r_px,
                )
            else:
                pts_in, cols_in, labels_in = merged_points, merged_colors, merged_labels

            rgb_img, depth_img, source_mask = self._render_layered(
                points_world=pts_in,
                colors=cols_in,
                point_labels=labels_in,
                target_cam_name=cam_name,
                K=K,
                camera_to_world=base_pose,
                H=H,
                W=W,
            )
            rendered[cam_name] = (rgb_img, depth_img, depth_img > 0, source_mask)
        return rendered

    def _composite_hybrid_views(
            self,
            lidar_views: List['PointCloudProcessor.NovelViewData'],
            rig: 'CameraRigSpec',
            object_processor: 'ObjectInfoProcessor',
            frame_idx_int: int,
            objects_to_add_info: List[Tuple[str, Path]],
    ) -> List['PointCloudProcessor.NovelViewData']:
        """
        执行 Pass 2 (3DGS) 和 Pass 3 (合成)。
        """

        frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
        all_frame_objects: Dict[str, Any] = object_processor.get_object_info().get(frame_key, {})  # type: ignore
        if not isinstance(all_frame_objects, dict):
            return lidar_views

        # 组装 3DGS 渲染输入
        objects_to_render: List[Dict[str, Any]] = []
        for obj_id, ply_path in objects_to_add_info:
            info = all_frame_objects.get(str(obj_id))
            if info is None:
                continue
            objects_to_render.append({
                'id': str(obj_id),
                'ply_path': Path(ply_path),
                'object_lwh': np.asarray(info['object_lwh'], dtype=np.float32),
                'object_to_world': np.asarray(info['object_to_world'], dtype=np.float32),
            })

        if len(objects_to_render) == 0:
            return lidar_views

        final_composited_views: List[PointCloudProcessor.NovelViewData] = []
        for lidar_view in lidar_views:
            cam_name = lidar_view.cam_name
            K = lidar_view.K
            H, W = lidar_view.rgb_image.shape[:2]
            c2w = lidar_view.camera_to_world
            w2c = np.linalg.inv(c2w).astype(np.float32)

            try:
                dist_coeffs = rig.get_dist_coeffs(cam_name)
            except Exception:
                dist_coeffs = np.zeros(5, dtype=np.float32)

            # Pass 2: 3DGS 渲染
            fg_rgb_u8, fg_alpha_u8, fg_depth_f32 = Gaussian.render_objects_for_camera(
                device=self._device,
                camera_pose_w2c=w2c,
                K=K,
                dist_coeffs=dist_coeffs,
                W=W,
                H=H,
                objects_to_render=objects_to_render,
            )

            # Pass 3: 合成（深度测试 + alpha 混合）
            bg_rgb_f32 = lidar_view.rgb_image.astype(np.float32) / 255.0
            bg_depth_f32 = lidar_view.depth_image.astype(np.float32)

            fg_rgb_f32 = fg_rgb_u8.astype(np.float32) / 255.0
            fg_alpha_f32 = (fg_alpha_u8.astype(np.float32) / 255.0)[..., None]

            fg_has_depth = fg_depth_f32 > 0
            bg_has_depth = bg_depth_f32 > 0
            hole_mask = ~bg_has_depth
            draw_fg_mask = fg_has_depth & hole_mask

            composited_rgb_f32 = bg_rgb_f32.copy()
            composited_depth_f32 = bg_depth_f32.copy()

            if np.any(draw_fg_mask):
                a = fg_alpha_f32[draw_fg_mask]
                fg_c = fg_rgb_f32[draw_fg_mask]
                bg_c = bg_rgb_f32[draw_fg_mask]
                composited_rgb_f32[draw_fg_mask] = fg_c * a + bg_c * (1.0 - a)
                composited_depth_f32[draw_fg_mask] = fg_depth_f32[draw_fg_mask]

            composited_rgb_u8 = (np.clip(composited_rgb_f32, 0.0, 1.0) * 255).astype(np.uint8)

            if lidar_view.camera_source_mask is not None:
                camera_source_mask = lidar_view.camera_source_mask.copy()
            else:
                H, W = lidar_view.rgb_image.shape[:2]
                camera_source_mask = np.zeros((H, W), dtype=np.uint8)
                camera_source_mask[bg_depth_f32 > 0] = 255

            if np.any(draw_fg_mask):
                camera_source_mask[draw_fg_mask] = 80  # 80 for the foreground object
            
            final_composited_views.append(
                PointCloudProcessor.NovelViewData(
                    cam_name=cam_name,
                    rgb_image=composited_rgb_u8,
                    depth_image=composited_depth_f32,
                    K=K,
                    camera_to_world=c2w,
                    camera_source_mask=camera_source_mask,
                )
            )

        return final_composited_views

    def process_file(
            self,
            pointcloud_path: Path,
            output_dir: Path,
            rig: 'CameraRigSpec',
            base_cam_c2w: Dict[str, np.ndarray],
            override_cam_c2w: Dict[str, np.ndarray],
            cams: Optional[List[str]] = None,
            front_cam_name: str = 'front',
            frame_idx_int: Optional[int] = None,
            object_processor: Optional['ObjectInfoProcessor'] = None,
            clip_config: Optional['ClipConfig'] = None,
            visualize: bool = True,
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Render one point-cloud frame into per-camera RGB / mask / area / density maps.

        The point cloud is loaded, moved to world coordinates through the front
        camera's baseline pose, and rendered at the ``override_cam_c2w`` poses.
        Depending on ``clip_config.do_back_project`` the result is either used
        directly (optionally composited with 3DGS objects), or back-projected to
        the ``base_cam_c2w`` poses and composited with 3DGS objects there.

        Args:
            pointcloud_path: Point-cloud file to load.
            output_dir: Root directory; images go to ``output_dir / <cam_name>``
                when ``visualize`` is true.
            rig: Camera rig providing names, intrinsics, sizes and cam-to-ego.
            base_cam_c2w: Baseline camera-to-world matrices keyed by camera name;
                must contain ``front_cam_name``.
            override_cam_c2w: Camera-to-world matrices for the novel views; must
                be non-empty.
            cams: Cameras to process; ``None`` or ``['all']`` selects every rig camera.
            front_cam_name: Camera used to derive the ego-to-world transform.
            frame_idx_int: Frame index. Used for output file names and object
                lookup. When ``None``, the point-cloud file stem is used for file
                names and 3DGS rendering is skipped.
            object_processor: Source of per-frame object info (optional).
            clip_config: Edit configuration. When ``None``, no back-projection,
                no object removal and no object insertion are performed.
            visualize: Whether to write RGB and mask images to disk.

        Returns:
            ``(rgb_frames, mask_frames, area_frames, density_frames)``; the rgb,
            mask and density dicts are keyed by camera name, and ``area_frames``
            is whatever ``_compute_object_2d_bboxes`` returns. Masks are
            inverted (``255 - mask``) before being stored.

        Raises:
            RuntimeError: If the front camera baseline pose or the override poses
                are missing, or a target camera has no baseline pose when 3DGS
                rendering needs it.
            ValueError: If the loaded point cloud is not a
                (points, colors, camera_ids) triple.
        """
        from .utils import load_point_cloud_efficient, POINT_CLOUD_STORAGE_FORMAT

        # A missing clip config means "plain render": nothing removed, nothing added.
        if clip_config is None:
            do_back_project = False
            object_ids_to_remove = None
            objects_to_add_info = None
        else:
            do_back_project = clip_config.do_back_project
            object_ids_to_remove = clip_config.object_ids_to_remove
            objects_to_add_info = clip_config.objects_to_add_info

        # Validate the cheap preconditions before touching the point-cloud file.
        if front_cam_name not in base_cam_c2w:
            raise RuntimeError('缺少 front 相机基线外参与点云坐标系转换。')
        if not override_cam_c2w:
            raise RuntimeError('未提供 override_cam_c2w，配置驱动模式需要为各相机提供外参覆写。')

        target_cams = rig.camera_names if (cams is None or cams == ['all']) else cams
        # Formatting None with ':06d' raises, so fall back to the file stem.
        if frame_idx_int is not None:
            frame_id_str = f"{frame_idx_int:06d}"
        else:
            frame_id_str = Path(pointcloud_path).stem

        # 3DGS only runs when there is something to add and a frame to look it up in.
        have_3dgs = bool(
            self._gs_renderer and
            object_processor is not None and
            objects_to_add_info is not None and len(objects_to_add_info) > 0 and
            frame_idx_int is not None
        )

        # LiDAR rendering path.
        loaded = load_point_cloud_efficient(
            pointcloud_path, format_type=POINT_CLOUD_STORAGE_FORMAT, return_labels=True
        )
        if len(loaded) != 3:
            raise ValueError("Point cloud should contain points, colors, and camera IDs.")
        points, colors, camera_ids = loaded
        # Colors above 1.0 are treated as 0..255 and normalised; skip empty clouds,
        # on which max() would raise.
        if len(colors) > 0 and colors.max() > 1.0:
            colors = np.clip(colors / 255.0, 0.0, 1.0)

        front_c2w_base = base_cam_c2w[front_cam_name]
        cam_to_ego_front = rig.get_cam_to_ego(front_cam_name)
        ego_to_world = (front_c2w_base @ np.linalg.inv(cam_to_ego_front)).astype(np.float32)

        # --- Pass 1: LiDAR background rendering at the override poses ---
        # When back-projecting, removal is deferred to the back-projection step.
        points_world = self.transform.points_to_world(points, ego_to_world)
        novel_views_lidar = self._render_novel_views(
            points_world=points_world,
            colors=colors,
            camera_ids=np.asarray(camera_ids),
            rig=rig,
            override_cam_c2w=override_cam_c2w,
            target_cams=target_cams,
            object_processor=object_processor,
            frame_idx_int=frame_idx_int,
            object_ids_to_remove=None if do_back_project else object_ids_to_remove,
        )

        rgb_frames_dict: Dict[str, np.ndarray] = {}
        mask_frames_dict: Dict[str, np.ndarray] = {}
        density_frames_dict: Dict[str, np.ndarray] = {}

        if not do_back_project:
            # --- [Flow 1: no back-projection] ---
            # The LiDAR views are final unless 3DGS objects must be composited in.
            final_views = novel_views_lidar
            if have_3dgs:
                final_views = self._composite_hybrid_views(
                    lidar_views=novel_views_lidar,
                    rig=rig,
                    object_processor=object_processor,  # type: ignore
                    frame_idx_int=frame_idx_int,  # type: ignore
                    objects_to_add_info=objects_to_add_info,  # type: ignore
                )

            for data in final_views:
                depth = data.depth_image
                if data.camera_source_mask is not None:
                    mask_img = data.camera_source_mask.copy()
                    mask_img[depth == 0] = 0
                else:
                    mask_img = (depth > 0).astype(np.uint8) * 255

                # Inverted mask: 255 - mask.
                mask_img = 255 - mask_img

                if visualize:
                    cam_dir = output_dir / data.cam_name
                    self.saver.save_rgb_jpg(cam_dir, data.rgb_image, frame_id_str)
                    self.saver.save_mask_png(cam_dir, mask_img, frame_id_str)

                rgb_frames_dict[data.cam_name] = data.rgb_image.copy()
                mask_frames_dict[data.cam_name] = mask_img.copy()
                # Density map is derived from the final depth image.
                density_frames_dict[data.cam_name] = self._compute_density_from_depth(depth)

            # 2D boxes are projected with the poses the images were rendered at.
            area_c2w = override_cam_c2w

        else:
            # --- [Flow 2: back-projection] ---
            # Back-project the LiDAR novel views to the baseline poses to obtain
            # the background at base_cam_c2w.
            back_projected_map = self._render_back_projections(
                novel_view_data=novel_views_lidar,
                rig=rig,
                base_cam_c2w=base_cam_c2w,
                target_cams=target_cams,
                object_processor=object_processor,
                frame_idx_int=frame_idx_int,
                object_ids_to_remove=object_ids_to_remove,
            )

            # Collect the 3DGS objects that have info for this frame.
            objects_to_render: List[Dict[str, Any]] = []
            if have_3dgs:
                # NOTE(review): object-info keys use 3x the frame index — confirm
                # against the object-info producer.
                frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"  # type: ignore
                all_frame_objects: Dict[str, Any] = (
                    object_processor.get_object_info().get(frame_key, {})  # type: ignore
                )
                if isinstance(all_frame_objects, dict):
                    for obj_id, ply_path in objects_to_add_info:  # type: ignore
                        info = all_frame_objects.get(str(obj_id))
                        if info is None:
                            continue
                        objects_to_render.append({
                            'id': str(obj_id),
                            'ply_path': Path(ply_path),
                            'object_lwh': np.asarray(info['object_lwh'], dtype=np.float32),
                            'object_to_world': np.asarray(info['object_to_world'], dtype=np.float32),
                        })

            # objects_to_render is only populated when have_3dgs is true.
            render_foreground = len(objects_to_render) > 0

            for cam_name in target_cams:
                bg_data = back_projected_map.get(cam_name)
                if bg_data is None:
                    continue

                bg_rgb_u8, bg_depth_f32, _bg_mask, bg_camera_source_mask = bg_data

                final_rgb_u8 = bg_rgb_u8
                final_depth_f32 = bg_depth_f32
                draw_fg_mask: Optional[np.ndarray] = None

                if render_foreground:
                    # Use the original (baseline) camera parameters.
                    K_orig = rig.get_K(cam_name)
                    H_orig, W_orig = rig.get_size(cam_name)
                    K = self.config.scale_intrinsics(K_orig, (H_orig, W_orig))
                    H, W = self.config.get_target_resolution((H_orig, W_orig))
                    c2w = base_cam_c2w.get(cam_name)
                    if c2w is None:
                        raise RuntimeError(
                            f"Missing baseline camera-to-world pose for camera '{cam_name}'."
                        )
                    w2c = np.linalg.inv(c2w).astype(np.float32)

                    # Best effort: a rig without distortion data falls back to
                    # zero distortion rather than aborting the frame.
                    try:
                        dist_coeffs = rig.get_dist_coeffs(cam_name)
                    except Exception:
                        dist_coeffs = np.zeros(5, dtype=np.float32)

                    # Render the 3DGS foreground.
                    fg_rgb_u8, fg_alpha_u8, fg_depth_f32 = Gaussian.render_objects_for_camera(
                        device=self._device,
                        camera_pose_w2c=w2c,
                        K=K,
                        dist_coeffs=dist_coeffs,
                        W=W,
                        H=H,
                        objects_to_render=objects_to_render,
                    )

                    # Composite: the foreground is drawn only where it has depth
                    # and the background has none (i.e. inside holes).
                    bg_rgb_f32 = bg_rgb_u8.astype(np.float32) / 255.0
                    fg_rgb_f32 = fg_rgb_u8.astype(np.float32) / 255.0
                    fg_alpha_f32 = (fg_alpha_u8.astype(np.float32) / 255.0)[..., None]

                    draw_fg_mask = (fg_depth_f32 > 0) & ~(bg_depth_f32 > 0)

                    composited_rgb_f32 = bg_rgb_f32.copy()
                    composited_depth_f32 = bg_depth_f32.copy()

                    if np.any(draw_fg_mask):
                        a = fg_alpha_f32[draw_fg_mask]
                        fg_c = fg_rgb_f32[draw_fg_mask]
                        bg_c = bg_rgb_f32[draw_fg_mask]
                        composited_rgb_f32[draw_fg_mask] = fg_c * a + bg_c * (1.0 - a)
                        composited_depth_f32[draw_fg_mask] = fg_depth_f32[draw_fg_mask]

                    final_rgb_u8 = (np.clip(composited_rgb_f32, 0.0, 1.0) * 255).astype(np.uint8)
                    final_depth_f32 = composited_depth_f32

                # Build the final mask from the composited depth.
                if bg_camera_source_mask is not None:
                    final_mask = bg_camera_source_mask.copy()
                    final_mask[final_depth_f32 == 0] = 0
                else:
                    final_mask = (final_depth_f32 > 0).astype(np.uint8) * 255

                # Pixels filled by the 3DGS foreground get the marker value 80.
                if draw_fg_mask is not None:
                    final_mask[draw_fg_mask] = 80

                # Inverted mask: 255 - mask.
                final_mask = 255 - final_mask

                if visualize:
                    cam_dir = output_dir / cam_name
                    self.saver.save_rgb_jpg(cam_dir, final_rgb_u8, frame_id_str)
                    self.saver.save_mask_png(cam_dir, final_mask, frame_id_str)

                rgb_frames_dict[cam_name] = final_rgb_u8.copy()
                mask_frames_dict[cam_name] = final_mask.copy()
                # Density map is derived from the final depth image.
                density_frames_dict[cam_name] = self._compute_density_from_depth(final_depth_f32)

            # 2D boxes are projected with the poses the images were rendered at.
            area_c2w = base_cam_c2w

        # Compute the 2D projected regions of the 3D boxes (after point removal).
        area_frames_dict = self._compute_object_2d_bboxes(
            rig=rig,
            camera_to_world=area_c2w,
            object_processor=object_processor,
            frame_idx_int=frame_idx_int,
            object_ids_to_remove=object_ids_to_remove,
            objects_to_add_info=objects_to_add_info,
            scale_factor=1.2,
        )

        return rgb_frames_dict, mask_frames_dict, area_frames_dict, density_frames_dict