from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Any

import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
import cv2

from .camera import CameraRigSpec, CameraTransform
from .projection import ProjectionConfig, PointCloudProjector
from .utils import ImageSaver
from .gaussian_model import Gaussian
from .edit_processor import EditProcessor, EditConfig

if TYPE_CHECKING:
    from .object_info import ObjectInfoProcessor
    from .clip_config import ClipConfig

# Z-axis 180° rotation fix, kept consistent with the reference rendering logic.
# ROTATION_FIX_Z_180 is the 3x3 rotation matrix; TRANSFORM_FIX_Z_180 embeds it
# in the upper-left block of a 4x4 float32 homogeneous transform.
ROTATION_FIX_Z_180 = R.from_euler('z', 180, degrees=True).as_matrix()
TRANSFORM_FIX_Z_180 = np.eye(4, dtype=np.float32)
TRANSFORM_FIX_Z_180[:3, :3] = ROTATION_FIX_Z_180


class PointCloudProcessor:
    def __init__(self, config: ProjectionConfig, projector: Optional[PointCloudProjector] = None):
        """
        Wire up the helpers used by the processor.

        Args:
            config: Projection configuration, stored as-is and used to build
                the default projector.
            projector: Projector to use; when omitted (or falsy) a
                ``PointCloudProjector`` is created from ``config``.
        """
        self.config = config
        if projector:
            self.projector = projector
        else:
            self.projector = PointCloudProjector(config)
        self.transform = CameraTransform()
        self.saver = ImageSaver()

        # 3DGS rendering: run on CUDA when available, otherwise on the CPU,
        # and keep the 3DGS pass switched on.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device_name)
        self._gs_renderer = True


    def _prepare_objects_to_render(
            self,
            clip_config: 'ClipConfig',
            all_frame_objects: Dict[str, Any],
            objects_to_add_info: List[Tuple[str, Path]],
            insertion_info: Optional[Dict[str, Any]] = None,
            frame_idx_int: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Assemble the ``objects_to_render`` list for the 3DGS pass, applying edits.

        Each returned entry is a dict with the keys ``id``, ``ply_path``,
        ``object_lwh`` and ``object_to_world``.

        * The inserted object (id equal to ``insertion_new_id``) takes its pose
          and size from the ``"insertion_0"`` record of ``insertion_info`` for
          this frame; it is skipped when that record or the PLY file is missing.
        * Every other object takes pose and size from ``all_frame_objects``;
          objects absent from that dict are skipped with a debug print.
        * The object whose id equals ``reposition_origin_id`` has its pose
          shifted via ``EditProcessor.shift_object_tfm_by_action``.
        """
        edit_config = EditProcessor.extract_edit_config(clip_config)

        # Per-frame insertion records; only looked up when both the insertion
        # table and the frame index are available.
        insertion_frame_objs: Dict[str, Any] = {}
        if insertion_info is not None and frame_idx_int is not None:
            insertion_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
            candidate = insertion_info.get(insertion_key, {})
            if isinstance(candidate, dict):
                insertion_frame_objs = candidate

        collected: List[Dict[str, Any]] = []
        for obj_id, ply_path in objects_to_add_info:
            obj_id_str = str(obj_id)

            is_inserted_object = (
                edit_config.insertion_new_id is not None
                and obj_id_str == edit_config.insertion_new_id
                and len(insertion_frame_objs) > 0
            )
            if is_inserted_object:
                # Insertion: pose comes from the "insertion_0" record.
                record = insertion_frame_objs.get("insertion_0")
                if record is None or not Path(ply_path).exists():
                    continue
                pose = np.asarray(record['object_to_world'], dtype=np.float32)
                size = np.asarray(record['object_lwh'], dtype=np.float32)
                collected.append({
                    'id': obj_id_str,
                    'ply_path': Path(ply_path),
                    'object_lwh': size,
                    'object_to_world': pose,
                })
                continue

            # Replacement / repositioning: pose comes from the frame's object table.
            info = all_frame_objects.get(obj_id_str)
            if info is None:
                print(f"[DEBUG objects_to_render] frame={frame_idx_int}, id={obj_id_str} 在 all_frame_objects 中不存在，跳过。")
                continue

            pose = np.asarray(info['object_to_world'], dtype=np.float32)

            is_repositioned = (
                edit_config.reposition_origin_id is not None
                and obj_id_str == edit_config.reposition_origin_id
            )
            if is_repositioned:
                # Repositioning: translate according to action_for_reposition.
                pose = EditProcessor.shift_object_tfm_by_action(
                    pose,
                    edit_config.action_for_reposition,
                    distance_m=EditProcessor.REPOSITION_DISTANCE_M,
                )

            collected.append({
                'id': obj_id_str,
                'ply_path': Path(ply_path),
                'object_lwh': np.asarray(info['object_lwh'], dtype=np.float32),
                'object_to_world': pose,
            })

        return collected
    def _apply_hole_carving(
            self,
            points_world: np.ndarray,
            colors: np.ndarray,
            camera_ids: np.ndarray,
            K: np.ndarray,
            camera_to_world: np.ndarray,
            H: int,
            W: int,
            object_processor: Optional['ObjectInfoProcessor'],
            frame_idx_int: Optional[int],
            object_ids_to_remove: Optional[List[str]],
            scale_factor: float = 1.1,
            r_px: float = 3.0,
            clip_config: Optional['ClipConfig'] = None,
            insertion_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        根据给定相机姿态与目标 OBB，在投影平面上计算阴影并删除相应点（挖洞）。
        返回过滤后的 points/colors/camera_ids。
        """
        if frame_idx_int is None or points_world.shape[0] == 0:
            return points_world, colors, camera_ids

        edit_config = EditProcessor.extract_edit_config(clip_config) if clip_config else None
        if edit_config is None:
            edit_config = EditConfig()

        # 判断是否需要挖洞
        has_ids_to_remove = object_ids_to_remove is not None and len(object_ids_to_remove) > 0
        has_insertion = edit_config.insertion_do_hole and edit_config.insertion_new_id is not None and insertion_info is not None

        if not (object_processor is not None and (has_ids_to_remove or has_insertion)):
            return points_world, colors, camera_ids

        w2c = np.linalg.inv(camera_to_world).astype(np.float32)
        num_points = points_world.shape[0]
        points_world_h = np.hstack([points_world, np.ones((num_points, 1), dtype=np.float32)])
        points_cam = (w2c @ points_world_h.T).T[:, :3]
        points_depths = points_cam[:, 2]
        valid_depth_mask = points_depths > 1e-6
        points_2d_all = np.zeros((num_points, 2), dtype=np.float32)
        if np.any(valid_depth_mask):
            points_cam_valid = points_cam[valid_depth_mask]
            points_depths_valid = points_depths[valid_depth_mask]
            points_normalized = points_cam_valid / points_depths_valid[:, np.newaxis]
            points_proj_h = (K @ points_normalized.T).T
            points_2d_all[valid_depth_mask] = points_proj_h[:, :2]
        points_2d_int = np.floor(points_2d_all).astype(np.int32)
        final_remove_mask_for_cam = np.zeros(num_points, dtype=bool)

        frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
        all_frame_objects = object_processor.get_object_info().get(frame_key, {})  # type: ignore
        if not isinstance(all_frame_objects, dict):
            all_frame_objects = {}

        # 统一的挖洞处理函数
        def _carve_hole_for_object(object_to_world: np.ndarray, lwh: np.ndarray) -> None:
            """为单个对象挖洞"""
            remove_mask_2d, min_depth, polygon_vertices_2d = EditProcessor.compute_hole_mask_for_object(
                object_to_world, lwh, w2c, K, H, W, scale_factor
            )
            if remove_mask_2d is None or polygon_vertices_2d is None:
                return

            mask_depth_test = (valid_depth_mask) & (points_depths > min_depth - 0.2)
            if not np.any(mask_depth_test):
                return

            mask_bounds_x_f = (points_2d_all[:, 0] >= (-0.5 - r_px)) & (points_2d_all[:, 0] < ((W - 0.5) + r_px))
            mask_bounds_y_f = (points_2d_all[:, 1] >= (-0.5 - r_px)) & (points_2d_all[:, 1] < ((H - 0.5) + r_px))
            mask_in_bounds_f = mask_bounds_x_f & mask_bounds_y_f
            mask_to_check = mask_depth_test & mask_in_bounds_f
            if not np.any(mask_to_check):
                return

            # 计算多边形边界框
            poly_min_x = float(np.min(polygon_vertices_2d[:, 0]))
            poly_max_x = float(np.max(polygon_vertices_2d[:, 0]))
            poly_min_y = float(np.min(polygon_vertices_2d[:, 1]))
            poly_max_y = float(np.max(polygon_vertices_2d[:, 1]))

            in_rect = (
                (points_2d_all[:, 0] >= (poly_min_x - r_px)) & (points_2d_all[:, 0] <= (poly_max_x + r_px)) &
                (points_2d_all[:, 1] >= (poly_min_y - r_px)) & (points_2d_all[:, 1] <= (poly_max_y + r_px))
            )

            mask_rect = mask_to_check & in_rect
            if not np.any(mask_rect):
                return

            indices_to_check = np.where(mask_rect)[0]
            coords_to_check = points_2d_int[indices_to_check]
            x_idx = np.clip(coords_to_check[:, 0], 0, W - 1)
            y_idx = np.clip(coords_to_check[:, 1], 0, H - 1)
            mask_values = remove_mask_2d[y_idx, x_idx]
            inside_by_mask = (mask_values == 255)

            not_inside = ~inside_by_mask
            if np.any(not_inside):
                coords_f = points_2d_all[indices_to_check][not_inside]
                shadow_polygon = cv2.convexHull(polygon_vertices_2d, returnPoints=True)
                if shadow_polygon is not None and shadow_polygon.shape[0] >= 3:
                    hull2d = shadow_polygon.reshape(-1, 2)
                    dist_ok = np.fromiter(
                        (cv2.pointPolygonTest(hull2d, (float(x), float(y)), True) >= -r_px for x, y in coords_f),
                        dtype=bool,
                        count=coords_f.shape[0]
                    )
                    inside_by_mask[not_inside] = dist_ok

            remove_indices = indices_to_check[inside_by_mask]
            final_remove_mask_for_cam[remove_indices] = True

        # Deletion / Replacement / Repositioning 目标挖洞
        if has_ids_to_remove and len(all_frame_objects) > 0:
            ids_for_removal: List[str] = list(object_ids_to_remove)
            if edit_config.reposition_origin_id is not None:
                ids_for_removal.append(edit_config.reposition_origin_id)

            for obj_id in ids_for_removal:
                obj_info = all_frame_objects.get(str(obj_id))
                if obj_info is None:
                    continue

                base_object_to_world = np.array(obj_info['object_to_world'], dtype=np.float32)
                lwh = np.array(obj_info['object_lwh'], dtype=np.float32)

                # 挖掉原始位置
                _carve_hole_for_object(base_object_to_world, lwh)

                # Repositioning: 在新位置也挖一次
                if edit_config.reposition_origin_id is not None and str(obj_id) == edit_config.reposition_origin_id:
                    shifted_tfm = EditProcessor.shift_object_tfm_by_action(
                        base_object_to_world,
                        edit_config.action_for_reposition,
                        distance_m=EditProcessor.REPOSITION_DISTANCE_M,
                    )
                    _carve_hole_for_object(shifted_tfm, lwh)

        # Insertion 目标挖洞
        if has_insertion:
            frame_key_ins = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
            insertion_frame_objs = insertion_info.get(frame_key_ins, {})
            if isinstance(insertion_frame_objs, dict) and "insertion_0" in insertion_frame_objs:
                insertion_obj_info = insertion_frame_objs["insertion_0"]
                try:
                    object_to_world_ins = np.asarray(insertion_obj_info['object_to_world'], dtype=np.float32)
                    lwh_ins = np.asarray(insertion_obj_info['object_lwh'], dtype=np.float32)
                except Exception:
                    object_to_world_ins = None
                    lwh_ins = None

                if object_to_world_ins is not None and lwh_ins is not None:
                    _carve_hole_for_object(object_to_world_ins, lwh_ins)

        keep_mask = ~final_remove_mask_for_cam
        return points_world[keep_mask], colors[keep_mask], camera_ids[keep_mask]


    def _compute_object_2d_bboxes(
            self,
            rig: 'CameraRigSpec',
            camera_to_world: Dict[str, np.ndarray],
            object_processor: Optional['ObjectInfoProcessor'],
            frame_idx_int: Optional[int],
            object_ids_to_remove: Optional[List[str]],
            objects_to_add_info: Optional[List[Tuple[str, Path]]],
            scale_factor: float = 1.2,
    ) -> Dict[str, np.ndarray]:
        """
        计算通过3DGS添加的新对象在2D图像上的投影矩形区域（扩大scale_factor倍）。
        
        Args:
            rig: 相机配置
            camera_to_world: 每个相机的camera_to_world变换
            object_processor: 对象信息处理器
            frame_idx_int: 帧索引
            object_ids_to_remove: 要删除的对象ID列表（未使用，保留用于兼容性）
            objects_to_add_info: 要添加的对象ID和路径列表（3DGS添加的新对象）
            scale_factor: 扩大倍数（默认1.2）
            
        Returns:
            Dict[str, np.ndarray]: 每个相机对应的area图像，形状为(H, W, 3)，uint8
        """
        import cv2
        
        if object_processor is None or frame_idx_int is None:
            return {}
        
        # 只使用3DGS添加的新对象（被删除和添加的对象是完全一样的）
        target_object_ids = set()
        if objects_to_add_info is not None:
            for obj_id, _ in objects_to_add_info:
                target_object_ids.add(str(obj_id))
        
        if len(target_object_ids) == 0:
            return {}
        
        frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
        all_frame_objects = object_processor.get_object_info().get(frame_key, {})
        if not isinstance(all_frame_objects, dict) or len(all_frame_objects) == 0:
            return {}
        
        result: Dict[str, np.ndarray] = {}
        
        for cam_name, c2w in camera_to_world.items():
            # 获取相机参数
            K_orig = rig.get_K(cam_name)
            H_orig, W_orig = rig.get_size(cam_name)
            K = self.config.scale_intrinsics(K_orig, (H_orig, W_orig))
            H, W = self.config.get_target_resolution((H_orig, W_orig))
            
            # 创建空白图像
            area_image = np.zeros((H, W, 3), dtype=np.uint8)
            
            # 世界到相机变换
            w2c = np.linalg.inv(c2w).astype(np.float32)
            
            # 只遍历被删除和添加的对象
            for obj_id in target_object_ids:
                obj_info = all_frame_objects.get(str(obj_id))
                if obj_info is None:
                    continue
                object_to_world = np.array(obj_info['object_to_world'], dtype=np.float32)
                lwh = np.array(obj_info['object_lwh'], dtype=np.float32)
                
                # 获取OBB的8个顶点（使用scale_factor=1.0，因为我们后面会单独扩大）
                vertices_local_h = EditProcessor.get_obb_vertices(lwh, scale_factor=1.0)
                vertices_world_h = (object_to_world @ vertices_local_h.T).T
                vertices_cam_h = (w2c @ vertices_world_h.T).T
                vertices_cam = vertices_cam_h[:, :3]
                vertices_depths = vertices_cam[:, 2]
                
                # 只考虑相机前方的点
                front_mask = vertices_depths > 1e-6
                if not np.any(front_mask):
                    continue
                
                vertices_cam_front = vertices_cam[front_mask]
                vertices_depths_front = vertices_depths[front_mask]
                
                # 投影到2D
                vert_normalized = vertices_cam_front / vertices_depths_front[:, np.newaxis]
                vert_proj_h = (K @ vert_normalized.T).T
                polygon_vertices_2d = vert_proj_h[:, :2]
                
                if polygon_vertices_2d.shape[0] < 3:
                    continue
                
                polygon_vertices_2d = np.clip(
                    polygon_vertices_2d, 
                    np.iinfo(np.int32).min, 
                    np.iinfo(np.int32).max
                )
                
                u_round = np.round(polygon_vertices_2d[:, 0]).astype(np.int32)
                v_round = np.round(polygon_vertices_2d[:, 1]).astype(np.int32)
                valid_uv_mask = (u_round >= 0) & (u_round < W) & (v_round >= 0) & (v_round < H)
                
                if (~valid_uv_mask).all():
                    continue
                
                try:
                    hull = cv2.convexHull(polygon_vertices_2d.astype(np.float32), returnPoints=True)
                    if hull is None or hull.shape[0] < 3:
                        continue
                    hull_points = hull.reshape(-1, 2)
                except Exception:
                    continue
                
                u_min = float(np.min(hull_points[:, 0]))
                u_max = float(np.max(hull_points[:, 0]))
                v_min = float(np.min(hull_points[:, 1]))
                v_max = float(np.max(hull_points[:, 1]))
                
                overlap_u_min = max(u_min, 0.0)
                overlap_u_max = min(u_max, float(W))
                overlap_v_min = max(v_min, 0.0)
                overlap_v_max = min(v_max, float(H))
                
                if overlap_u_max <= overlap_u_min or overlap_v_max <= overlap_v_min:
                    continue
                
                center_u = (overlap_u_min + overlap_u_max) / 2.0
                center_v = (overlap_v_min + overlap_v_max) / 2.0
                width = overlap_u_max - overlap_u_min
                height = overlap_v_max - overlap_v_min
                
                u_min_scaled = center_u - width * scale_factor / 2.0
                u_max_scaled = center_u + width * scale_factor / 2.0
                v_min_scaled = center_v - height * scale_factor / 2.0
                v_max_scaled = center_v + height * scale_factor / 2.0
                
                u_min_clipped = max(0, int(np.floor(u_min_scaled)))
                u_max_clipped = min(W - 1, int(np.ceil(u_max_scaled)))
                v_min_clipped = max(0, int(np.floor(v_min_scaled)))
                v_max_clipped = min(H - 1, int(np.ceil(v_max_scaled)))
                
                if (u_min_clipped >= 0 and u_max_clipped < W and 
                    v_min_clipped >= 0 and v_max_clipped < H and
                    u_min_clipped < u_max_clipped and 
                    v_min_clipped < v_max_clipped):
                    area_image[v_min_clipped:v_max_clipped+1, u_min_clipped:u_max_clipped+1] = 255
            
            result[cam_name] = area_image
        
        return result

    def _render_layered(
            self,
            points_world: np.ndarray,
            colors: np.ndarray,
            point_labels: np.ndarray,
            target_cam_name: str,
            K: np.ndarray,
            camera_to_world: np.ndarray,
            H: int,
            W: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        if points_world.shape[0] == 0:
            return (
                np.zeros((H, W, 3), dtype=np.uint8),
                np.zeros((H, W), dtype=np.float32),
                np.zeros((H, W), dtype=np.uint8)
            )

        name_to_idx = CameraRigSpec.default_name_to_idx()
        primary_idx = int(name_to_idx.get(target_cam_name))
        cam_ids_np = np.asarray(point_labels)
        primary_mask = (cam_ids_np == primary_idx)

        adjacent_cameras = CameraRigSpec.get_adjacent_cameras().get(target_cam_name, [])
        adjacent_indices = [name_to_idx.get(adj_cam) for adj_cam in adjacent_cameras]
        secondary_mask = np.isin(cam_ids_np, adjacent_indices)

        pts_primary = points_world[primary_mask]
        cols_primary = colors[primary_mask]
        pts_secondary = points_world[secondary_mask]
        cols_secondary = colors[secondary_mask]

        rgb_primary, depth_primary = self.projector.project_to_image(
            pts_primary, cols_primary, K, camera_to_world, H, W
        )

        if pts_secondary.shape[0] > 0:
            rgb_secondary, depth_secondary = self.projector.project_to_image(
                pts_secondary, cols_secondary, K, camera_to_world, H, W
            )
        else:
            rgb_secondary = np.zeros_like(rgb_primary)
            depth_secondary = np.zeros_like(depth_primary)

        final_rgb = rgb_primary.copy()
        final_depth = depth_primary.copy()
        primary_valid = depth_primary > 0
        secondary_valid = depth_secondary > 0
        fill_mask = (~primary_valid) & secondary_valid
        if np.any(fill_mask):
            final_rgb[fill_mask] = rgb_secondary[fill_mask]
            final_depth[fill_mask] = depth_secondary[fill_mask]

        camera_source_mask = np.zeros((H, W), dtype=np.uint8)
        camera_source_mask[primary_valid] = 255  # 来自主要相机的区域
        camera_source_mask[fill_mask] = 160  # 来自其他相机的区域

        return final_rgb, final_depth, camera_source_mask

    @dataclass
    class NovelViewData:
        """Rendered result of a single camera view, passed between the render passes."""
        # Name of the camera the view was rendered for.
        cam_name: str
        # Rendered color image; its first two dimensions are (H, W).
        rgb_image: np.ndarray
        # Rendered depth image; a value of 0 is treated as "no data" by the callers.
        depth_image: np.ndarray
        # Camera intrinsics used for rendering this view.
        K: np.ndarray
        # Camera pose (camera_to_world) used for rendering this view.
        camera_to_world: np.ndarray
        # Optional per-pixel source label: 255 = primary camera points,
        # 160 = points from adjacent cameras, 80 = 3DGS foreground object.
        camera_source_mask: Optional[np.ndarray] = None

    def _render_novel_views(
            self,
            points_world: np.ndarray,
            colors: np.ndarray,
            camera_ids: np.ndarray,
            rig: 'CameraRigSpec',
            override_cam_c2w: Dict[str, np.ndarray],
            target_cams: List[str],
            object_processor: Optional['ObjectInfoProcessor'],
            frame_idx_int: Optional[int],
            object_ids_to_remove: Optional[List[str]],
            scale_factor: float = 1.1,
            r_px: float = 3.0,  # 渲染器使用的最大像素半径
            clip_config: Optional['ClipConfig'] = None,
            insertion_info: Optional[Dict[str, Any]] = None,
    ) -> List['PointCloudProcessor.NovelViewData']:
        results: List[PointCloudProcessor.NovelViewData] = []

        points_world_h = np.hstack([points_world, np.ones((points_world.shape[0], 1), dtype=np.float32)])
        num_points = points_world.shape[0]

        for cam_name in target_cams:
            K_orig = rig.get_K(cam_name)
            H_orig, W_orig = rig.get_size(cam_name)
            K = self.config.scale_intrinsics(K_orig, (H_orig, W_orig))
            H, W = self.config.get_target_resolution((H_orig, W_orig))
            new_c2w = override_cam_c2w.get(cam_name)
            new_w2c = np.linalg.inv(new_c2w).astype(np.float32)
            keep_mask = np.ones(num_points, dtype=bool)

            do_removal = (
                    object_processor is not None and
                    object_ids_to_remove is not None and len(object_ids_to_remove) > 0 and
                    frame_idx_int is not None
            )

            if do_removal:
                frame_key = f"{frame_idx_int * 3:06d}.all_object_info.json"
                all_frame_objects = object_processor.get_object_info().get(frame_key, {})
                if not isinstance(all_frame_objects, dict) or len(all_frame_objects) == 0:
                    do_removal = False

            if do_removal:
                pts_world_filtered, cols_filtered, cam_ids_filtered = self._apply_hole_carving(
                    points_world=points_world,
                    colors=colors,
                    camera_ids=np.asarray(camera_ids),
                    K=K,
                    camera_to_world=new_c2w,
                    H=H,
                    W=W,
                    object_processor=object_processor,
                    frame_idx_int=frame_idx_int,
                    object_ids_to_remove=object_ids_to_remove,
                    scale_factor=scale_factor,
                    r_px=r_px,
                    clip_config=clip_config,
                    insertion_info=insertion_info,
                )
            else:
                pts_world_filtered = points_world
                cols_filtered = colors
                cam_ids_filtered = camera_ids

            final_rgb, final_depth, camera_source_mask = self._render_layered(
                points_world=pts_world_filtered,
                colors=cols_filtered,
                point_labels=cam_ids_filtered,
                target_cam_name=cam_name,
                K=K,
                camera_to_world=new_c2w,
                H=H,
                W=W,
            )
            results.append(
                PointCloudProcessor.NovelViewData(
                    cam_name=cam_name,
                    rgb_image=final_rgb,
                    depth_image=final_depth,
                    K=K,
                    camera_to_world=new_c2w,
                    camera_source_mask=camera_source_mask,
                )
            )
        return results

    def _composite_hybrid_views(
            self,
            lidar_views: List['PointCloudProcessor.NovelViewData'],
            rig: 'CameraRigSpec',
            object_processor: 'ObjectInfoProcessor',
            frame_idx_int: int,
            objects_to_add_info: List[Tuple[str, Path]],
            clip_config: 'ClipConfig',
            insertion_info: Optional[Dict[str, Any]] = None,
    ) -> List['PointCloudProcessor.NovelViewData']:
        """
        Run pass 2 (3DGS object rendering) and pass 3 (compositing).

        The objects to render are assembled by ``_prepare_objects_to_render``.
        For every lidar view they are rendered with
        ``Gaussian.render_objects_for_camera`` and alpha-blended over the lidar
        image:

          * by default only where the lidar depth is 0 (carved holes / empty
            pixels) and the foreground has depth;
          * when an insertion is configured with ``do_hole`` set to false,
            wherever the foreground has depth.

        Blended pixels take the foreground depth and are labelled 80 in the
        camera source mask.

        Returns:
            New ``NovelViewData`` objects, one per input view; the input list
            itself is returned when the frame's object info is not a dict or
            there is nothing to render.
        """
        frame_key = f"{int(frame_idx_int) * 3:06d}.all_object_info.json"
        all_frame_objects: Dict[str, Any] = object_processor.get_object_info().get(frame_key, {})  # type: ignore
        if not isinstance(all_frame_objects, dict):
            return lidar_views

        # Assemble the 3DGS inputs; insertion / repositioning edits are applied there.
        objects_to_render: List[Dict[str, Any]] = self._prepare_objects_to_render(
            clip_config=clip_config,
            all_frame_objects=all_frame_objects,
            objects_to_add_info=objects_to_add_info,
            insertion_info=insertion_info,
            frame_idx_int=frame_idx_int,
        )
        if len(objects_to_render) == 0:
            return lidar_views

        # Blend strategy: overlay mode is used when an insertion exists and its
        # "do_hole" setting (default true) is switched off.
        edit_cfg = EditProcessor.extract_edit_config(clip_config)
        insertion_configured = edit_cfg.insertion_new_id is not None
        carve_for_insertion: bool = True
        edit_info = getattr(clip_config, "validation_edit_info", {}) or {}
        if isinstance(edit_info, dict):
            insertion_section = edit_info.get("insertion", {}) or {}
            if isinstance(insertion_section, dict):
                carve_for_insertion = bool(insertion_section.get("do_hole", True))
        overlay_insertion = insertion_configured and (not carve_for_insertion)

        composited: List[PointCloudProcessor.NovelViewData] = []
        for view in lidar_views:
            view_name = view.cam_name
            intrinsics = view.K
            height, width = view.rgb_image.shape[:2]
            cam_to_world = view.camera_to_world
            world_to_cam = np.linalg.inv(cam_to_world).astype(np.float32)

            # Fall back to zero distortion when the rig cannot provide coefficients.
            try:
                distortion = rig.get_dist_coeffs(view_name)
            except Exception:
                distortion = np.zeros(5, dtype=np.float32)

            # Pass 2: 3DGS rendering of the foreground objects.
            fg_rgb_u8, fg_alpha_u8, fg_depth_f32 = Gaussian.render_objects_for_camera(
                device=self._device,
                camera_pose_w2c=world_to_cam,
                K=intrinsics,
                dist_coeffs=distortion,
                W=width,
                H=height,
                objects_to_render=objects_to_render,
            )

            # Pass 3: compositing in normalized float space.
            background_rgb = view.rgb_image.astype(np.float32) / 255.0
            background_depth = view.depth_image.astype(np.float32)
            foreground_rgb = fg_rgb_u8.astype(np.float32) / 255.0
            foreground_alpha = (fg_alpha_u8.astype(np.float32) / 255.0)[..., None]

            foreground_present = fg_depth_f32 > 0
            background_present = background_depth > 0
            if overlay_insertion:
                paint_mask = foreground_present
            else:
                # Default: paint only into holes of the lidar rendering.
                paint_mask = foreground_present & (~background_present)
            anything_painted = np.any(paint_mask)

            blended_rgb = background_rgb.copy()
            blended_depth = background_depth.copy()
            if anything_painted:
                alpha = foreground_alpha[paint_mask]
                blended_rgb[paint_mask] = (
                    foreground_rgb[paint_mask] * alpha + background_rgb[paint_mask] * (1.0 - alpha)
                )
                blended_depth[paint_mask] = fg_depth_f32[paint_mask]

            blended_rgb_u8 = (np.clip(blended_rgb, 0.0, 1.0) * 255).astype(np.uint8)

            # Source mask: reuse the lidar one, or derive it from the lidar depth.
            if view.camera_source_mask is None:
                source_mask = np.zeros((height, width), dtype=np.uint8)
                source_mask[background_depth > 0] = 255
            else:
                source_mask = view.camera_source_mask.copy()
            if anything_painted:
                source_mask[paint_mask] = 80  # 80 marks the foreground object

            composited.append(
                PointCloudProcessor.NovelViewData(
                    cam_name=view_name,
                    rgb_image=blended_rgb_u8,
                    depth_image=blended_depth,
                    K=intrinsics,
                    camera_to_world=cam_to_world,
                    camera_source_mask=source_mask,
                )
            )

        return composited

    def process_file(
            self,
            pointcloud_path: Path,
            output_dir: Path,
            rig: 'CameraRigSpec',
            base_cam_c2w: Dict[str, np.ndarray],
            override_cam_c2w: Dict[str, np.ndarray],
            cams: Optional[List[str]] = None,
            front_cam_name: str = 'front',
            frame_idx_int: Optional[int] = None,
            object_processor: Optional['ObjectInfoProcessor'] = None,
            clip_config: Optional['ClipConfig'] = None,
            insertion_info: Optional[Dict[str, Any]] = None,
            visualize: bool = True,
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Render one point-cloud frame into the target cameras.

        Steps:
          1. Load the point cloud and move it to world coordinates using the
             front camera's baseline pose and its cam-to-ego extrinsics.
          2. Pass 1: lidar background rendering (with hole carving).
          3. Pass 2: 3DGS rendering and compositing, only when an object
             processor, a frame index and objects to add are all available.
          4. Pass 3: optionally save the images and collect RGB / mask frames.
          5. Pass 4: compute the 2D area images of the added objects.

        Args:
            pointcloud_path: Point-cloud file to load.
            output_dir: Root directory for saved images (one sub-directory per camera).
            rig: Camera rig specification.
            base_cam_c2w: Baseline camera_to_world poses; must contain ``front_cam_name``.
            override_cam_c2w: Poses to render from; must not be empty.
            cams: Cameras to render; ``None`` or ``['all']`` selects every rig camera.
            front_cam_name: Name of the reference (front) camera.
            frame_idx_int: Frame index. When ``None`` no object edit is applied
                and the point-cloud file stem is used to name saved files.
            object_processor: Source of per-frame object info.
            clip_config: Clip configuration. When ``None`` nothing is removed or added.
            insertion_info: Per-frame insertion records.
            visualize: Whether to write RGB / mask images to ``output_dir``.

        Returns:
            ``(rgb_frames, mask_frames, area_frames)`` keyed by camera name.
            Masks are inverted source masks (``255 - source``), so pixels
            without data are 255.

        Raises:
            ValueError: If the loaded point cloud does not have three components.
            RuntimeError: If the front camera baseline pose or the pose
                overrides are missing.
        """
        from .utils import load_point_cloud_efficient, POINT_CLOUD_STORAGE_FORMAT

        # clip_config is optional: without it there is nothing to remove or add.
        if clip_config is not None:
            object_ids_to_remove = clip_config.object_ids_to_remove
            objects_to_add_info = clip_config.objects_to_add_info
        else:
            object_ids_to_remove = None
            objects_to_add_info = None

        target_cams = rig.camera_names if (cams is None or cams == ['all']) else cams
        # frame_idx_int is optional as well; fall back to the file stem so the
        # lidar-only path can still name its outputs.
        if frame_idx_int is not None:
            frame_id_str = f"{frame_idx_int:06d}"
        else:
            frame_id_str = Path(pointcloud_path).stem

        # Whether the 3DGS pass has to run.
        have_3dgs = (
                self._gs_renderer and
                object_processor is not None and
                objects_to_add_info is not None and len(objects_to_add_info) > 0 and
                frame_idx_int is not None
        )

        # Lidar rendering path.
        pts_colors = load_point_cloud_efficient(
            pointcloud_path, format_type=POINT_CLOUD_STORAGE_FORMAT, return_labels=True
        )
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if len(pts_colors) != 3:
            raise ValueError("Point cloud should contain points, colors, and camera IDs.")
        points, colors, camera_ids = pts_colors
        # Normalize 0..255 colors to 0..1; the size guard keeps an empty cloud
        # from raising in ``max()``.
        if colors.size > 0 and colors.max() > 1.0:
            colors = np.clip(colors / 255.0, 0.0, 1.0)

        if front_cam_name not in base_cam_c2w:
            raise RuntimeError('缺少 front 相机基线外参与点云坐标系转换。')
        front_c2w_base = base_cam_c2w[front_cam_name]
        cam_to_ego_front = rig.get_cam_to_ego(front_cam_name)
        ego_to_world = (front_c2w_base @ np.linalg.inv(cam_to_ego_front)).astype(np.float32)

        if override_cam_c2w is None or len(override_cam_c2w) == 0:
            raise RuntimeError('未提供 override_cam_c2w，配置驱动模式需要为各相机提供外参覆写。')

        # --- Pass 1: lidar background rendering ---
        points_world = self.transform.points_to_world(points, ego_to_world)
        novel_views_lidar = self._render_novel_views(
            points_world=points_world,
            colors=colors,
            camera_ids=np.asarray(camera_ids),
            rig=rig,
            override_cam_c2w=override_cam_c2w,
            target_cams=target_cams,
            object_processor=object_processor,
            frame_idx_int=frame_idx_int,
            object_ids_to_remove=object_ids_to_remove,
            clip_config=clip_config,
            insertion_info=insertion_info,
        )

        # --- Pass 2: 3DGS rendering and compositing ---
        final_views = novel_views_lidar
        if have_3dgs:
            final_views = self._composite_hybrid_views(
                lidar_views=novel_views_lidar,
                rig=rig,
                object_processor=object_processor,  # type: ignore
                frame_idx_int=frame_idx_int,  # type: ignore
                objects_to_add_info=objects_to_add_info,  # type: ignore
                clip_config=clip_config,  # type: ignore
                insertion_info=insertion_info,
            )

        # --- Pass 3: save / collect the final composited results ---
        rgb_frames_dict: Dict[str, np.ndarray] = {}
        mask_frames_dict: Dict[str, np.ndarray] = {}
        for data in final_views:
            if visualize:
                self.saver.save_rgb_jpg(output_dir / data.cam_name, data.rgb_image, frame_id_str)

            if data.camera_source_mask is not None:
                mask_img = data.camera_source_mask.copy()
                mask_img[data.depth_image == 0] = 0
            else:
                mask_img = (data.depth_image > 0).astype(np.uint8) * 255

            # Inverted mask: 255 - mask, so pixels without data become 255.
            mask_img = 255 - mask_img

            if visualize:
                self.saver.save_mask_png(output_dir / data.cam_name, mask_img, frame_id_str)
            rgb_frames_dict[data.cam_name] = data.rgb_image.copy()
            mask_frames_dict[data.cam_name] = mask_img.copy()

        # --- Pass 4: 2D projection areas of the 3D boxes (after point removal) ---
        area_frames_dict = self._compute_object_2d_bboxes(
            rig=rig,
            camera_to_world=override_cam_c2w,
            object_processor=object_processor,
            frame_idx_int=frame_idx_int,
            object_ids_to_remove=object_ids_to_remove,
            objects_to_add_info=objects_to_add_info,
            scale_factor=1.2,
        )

        return rgb_frames_dict, mask_frames_dict, area_frames_dict